jax.numpy.trim_zeros#

jax.numpy.trim_zeros(filt, trim='fb', axis=None)[源代码]#

修剪输入数组的开头和/或结尾的零。

JAX 对 numpy.trim_zeros() 的实现。

参数:
  • filt (ArrayLike) – N 维输入数组。

  • trim (str) –

    字符串,可选,默认为 fb。指定从哪一端修剪输入。

    • f - 仅修剪开头的零。

    • b - 仅修剪结尾的零。

    • fb - 修剪开头和结尾的零。

  • axis (int | Sequence[int] | None) – 可选的用于修剪的轴或轴。如果未指定,则沿数组的所有轴进行修剪。

返回:

包含修剪后输入(与 filt 具有相同 dtype)的数组。

返回类型:

Array

示例

一维输入

>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
>>> jnp.trim_zeros(x, trim='f')
Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32)
>>> jnp.trim_zeros(x, trim='b')
Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)

二维输入

>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
>>> x
Array([[0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
>>> jnp.trim_zeros(x, trim='f')
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=0)
Array([[0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=1)
Array([[0., 0., 0.],
       [1., 1., 1.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)