jax.nn.squareplus#

jax.nn.squareplus(x, b=4)[source]#

Squareplus 激活函数。

计算元素级函数

squareplus(x)=x+x2+b2

https://arxiv.org/abs/2112.11687 中所述。

参数:
  • x (ArrayLike) – 输入数组

  • b (ArrayLike) – 平滑度参数

返回类型:

Array