jax.numpy.amax

jax.numpy.amax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Alias of jax.numpy.max().

Parameters:
a (ArrayLike)
axis (Axis)
out (None)
keepdims (bool)
initial (ArrayLike | None)
where (ArrayLike | None)

Return type:
Array
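The following is a minimal usage sketch, assuming a standard JAX installation. The array values are illustrative. It uses the parameters listed above: `axis` selects the reduced dimension, `keepdims` keeps it with length 1, and `initial` is included in the reduction. When `where` masks elements, `initial` is also passed, because JAX's max reduction has no identity element to fall back on. `out` is left at its default, since the signature only accepts `None`.

```python
import jax.numpy as jnp

x = jnp.array([[1, 5, 3],
               [4, 2, 6]])

overall = jnp.amax(x)                          # maximum over all elements
per_column = jnp.amax(x, axis=0)               # reduce over rows -> one value per column
per_row = jnp.amax(x, axis=1)                  # reduce over columns -> one value per row
kept = jnp.amax(x, axis=1, keepdims=True)      # reduced axis kept with length 1
floored = jnp.amax(x, initial=10)              # initial takes part in the reduction
masked = jnp.amax(x, where=x < 5, initial=-1)  # only elements < 5 are considered

# amax is an alias of max, so both should give identical results.
same = jnp.array_equal(jnp.amax(x, axis=0), jnp.max(x, axis=0))

print(overall, per_column, per_row, kept.shape, floored, masked, same)
```

Because `amax` only forwards to `max`, new code can call `jnp.max` directly. The alias is useful mainly when porting NumPy code that already uses `amax`.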