jax.lax.pad#
- jax.lax.pad(operand, padding_value, padding_config)[source]#
对数组应用低填充、高填充和/或内部填充。
包装 XLA 的 Pad 运算符。
- 参数:
- 返回值:
根据
padding_config
在每个维度中插入填充值padding_value
的operand
数组。- 返回类型:
示例
>>> from jax import lax >>> import jax.numpy as jnp
用零填充 1 维数组,我们将在前面指定两个零,在末尾指定三个零
>>> x = jnp.array([1, 2, 3, 4]) >>> lax.pad(x, 0, [(2, 3, 0)]) Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
用内部零填充 1 维数组;即在每个值之间插入一个零
>>> lax.pad(x, 0, [(0, 0, 1)]) Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
用值
-1
在前面和末尾填充 2 维数组,每个维度的填充大小为 2>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) Array([[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 1, 2, 3, -1, -1], [-1, -1, 4, 5, 6, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)