jax.experimental.checkify.check#
- jax.experimental.checkify.check(pred, msg, *fmt_args, debug=False, **fmt_kwargs)[源代码]#
检查一个谓词,如果谓词为 False,则添加带有 msg 的错误。
这是一个有副作用的操作,不能被 staged (jitted/scanned/…)。在 staging 带有检查的函数之前,使用
checkify()
来 checkify 它!- 参数:
pred (Bool) – 如果为 False,则添加 FailedCheckError 错误。
msg (str) – 如果添加错误,则为错误消息。 可以是格式化字符串。
debug (bool) – 是否打开调试模式。 如果为 True,则将在执行期间删除 check。 如果为 False,则必须使用 checkify.checkify 对 check 进行函数化。
fmt_args – msg 的位置和关键字格式化参数,例如:
check(.., "check failed on values {} and {named_arg}", x, named_arg=y)
请注意,这些参数可以是 traced 值,允许您将运行时值添加到错误消息。 请注意,跟踪这些运行时数组将增加您的内存使用量,即使没有发生错误。fmt_kwargs – msg 的位置和关键字格式化参数,例如:
check(.., "check failed on values {} and {named_arg}", x, named_arg=y)
请注意,这些参数可以是 traced 值,允许您将运行时值添加到错误消息。 请注意,跟踪这些运行时数组将增加您的内存使用量,即使没有发生错误。
- 返回类型:
无
例如
>>> import jax >>> import jax.numpy as jnp >>> from jax.experimental import checkify >>> def f(x): ... checkify.check(x>0, "{x} needs to be positive!", x=x) ... return 1/x >>> checked_f = checkify.checkify(f) >>> err, out = jax.jit(checked_f)(-3.) >>> err.throw() Traceback (most recent call last): ... jax._src.checkify.JaxRuntimeError: -3. needs to be positive!