jax.numpy.pad#
- jax.numpy.pad(array, pad_width, mode='constant', **kwargs)[源代码]#
向数组添加填充。
JAX 实现的
numpy.pad()
。- 参数:
array (ArrayLike) – 要填充的数组。
pad_width (PadValueLike[int | Array | np.ndarray]) –
指定数组每个维度的填充宽度。可以分别为数组之前和之后指定填充宽度。选项包括
int
或(int,)
:在每个数组维度前后填充相同数量的值。(before, after)
:在每个数组之前填充before
个元素,之后填充after
个元素((before_1, after_1), (before_2, after_2), ... (before_N, after_N))
:为每个数组维度指定不同的before
和after
值。
mode (str | Callable[..., Any]) –
字符串或可调用对象。支持的填充模式包括
'constant'
(默认):用常量值填充,默认为零。'empty'
:用空值(即零)填充'edge'
:用数组的边缘值填充。'wrap'
:通过包裹数组进行填充。'linear_ramp'
:用线性渐变填充到指定的end_values
。'maximum'
:用最大值填充。'mean'
:用平均值填充。'median'
:用中位数填充。'minimum'
:用最小值填充。'reflect'
:通过反射填充。'symmetric'
:通过对称反射填充。<callable>
:可调用函数。请参见下面的注释。
constant_values – 用于
mode = 'constant'
。指定用于填充的常量值。stat_length – 用于
mode in ['maximum', 'mean', 'median', 'minimum']
。一个整数或元组,指定计算统计量时要使用的边缘值的数量。end_values – 用于
mode = 'linear_ramp'
。指定将填充值渐变到的结束值。reflect_type – 用于
mode in ['reflect', 'symmetric']
。指定是使用偶数还是奇数反射。
- 返回:
array
的填充副本。- 返回类型:
注意事项
当
mode
可调用时,它应具有以下签名def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array: ...
此处
row
是沿轴iaxis
填充数组的 1D 切片,填充值用零填充。pad_width
是一个元组,指定(before, after)
填充大小,而kwargs
是传递给jax.numpy.pad()
函数的任何其他关键字参数。请注意,虽然在 NumPy 中,函数应该就地修改
row
,但在 JAX 中,函数应该返回修改后的row
。在 JAX 中,自定义填充函数将使用jax.vmap()
转换映射到填充轴上。另请参阅
jax.numpy.resize()
:调整数组大小jax.numpy.tile()
:通过平铺较小的数组来创建更大的数组。jax.numpy.repeat()
:通过重复较小数组的值来创建更大的数组。
示例
用零填充一维数组
>>> x = jnp.array([10, 20, 30, 40]) >>> jnp.pad(x, 2) Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) >>> jnp.pad(x, (2, 4)) Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32)
用指定的值填充一维数组
>>> jnp.pad(x, 2, constant_values=99) Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
用平均数组值填充一维数组
>>> jnp.pad(x, 2, mode='mean') Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
用反射值填充一维数组
>>> jnp.pad(x, 2, mode='reflect') Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
用每个维度中不同的填充来填充二维数组
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.pad(x, ((1, 2), (3, 0))) Array([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 2, 3], [0, 0, 0, 4, 5, 6], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=int32)
使用自定义填充函数填充一维数组
>>> def custom_pad(row, pad_width, iaxis, kwargs): ... # row represents a 1D slice of the zero-padded array. ... before, after = pad_width ... before_value = kwargs.get('before_value', 0) ... after_value = kwargs.get('after_value', 0) ... row = row.at[:before].set(before_value) ... return row.at[len(row) - after:].set(after_value) >>> x = jnp.array([2, 3, 4]) >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32)