jax.numpy.trim_zeros#
- jax.numpy.trim_zeros(filt, trim='fb', axis=None)[源代码]#
修剪输入数组的开头和/或结尾的零。
JAX 对
numpy.trim_zeros()
的实现。- 参数:
- 返回:
包含修剪后输入(与
filt
具有相同 dtype)的数组。- 返回类型:
示例
一维输入
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) >>> jnp.trim_zeros(x) Array([2, 0, 1, 4, 3], dtype=int32) >>> jnp.trim_zeros(x, trim='f') Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32) >>> jnp.trim_zeros(x, trim='b') Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)
二维输入
>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1) >>> x Array([[0., 0., 0., 0., 0.], [0., 1., 1., 1., 0.], [0., 1., 1., 1., 0.], [0., 0., 0., 0., 0.]], dtype=float32) >>> jnp.trim_zeros(x) Array([[1., 1., 1.], [1., 1., 1.]], dtype=float32) >>> jnp.trim_zeros(x, trim='f') Array([[1., 1., 1., 0.], [1., 1., 1., 0.], [0., 0., 0., 0.]], dtype=float32) >>> jnp.trim_zeros(x, axis=0) Array([[0., 1., 1., 1., 0.], [0., 1., 1., 1., 0.]], dtype=float32) >>> jnp.trim_zeros(x, axis=1) Array([[0., 0., 0.], [1., 1., 1.], [1., 1., 1.], [0., 0., 0.]], dtype=float32)