jax.lax.pad#
- jax.lax.pad(operand, padding_value, padding_config)[source]#
将低、高和/或内部填充应用于数组。
封装了 XLA 的 Pad 运算符。
- 参数:
- 返回:
具有填充值
padding_value
的operand
数组,根据padding_config
插入到每个维度中。- 返回类型:
示例
>>> from jax import lax >>> import jax.numpy as jnp
用零填充 1 维数组。我们将在前面指定两个零,在末尾指定三个零
>>> x = jnp.array([1, 2, 3, 4]) >>> lax.pad(x, 0, [(2, 3, 0)]) Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
用内部零填充 1 维数组;即在每个值之间插入一个零
>>> lax.pad(x, 0, [(0, 0, 1)]) Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
用值
-1
在前面和末尾填充 2 维数组,每个维度的填充大小为 2>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) Array([[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 1, 2, 3, -1, -1], [-1, -1, 4, 5, 6, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)