jax.nn.initializers.zeros#

jax.nn.initializers.zeros(key, shape, dtype=<class 'numpy.float64'>, out_sharding=None)[source]#

一个返回充满零的常量数组的初始化器。

key 参数被忽略。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
参数:
  • key (Array)

  • shape (core.Shape)

  • dtype (DTypeLikeInexact)

  • out_sharding (OutShardingType)

返回类型:

数组