jax.extend.core.Primitive

class jax.extend.core.Primitive(name)[源代码]
参数:

name (str)

__init__(name)[源代码]
参数:

name (str)

方法

__init__(name)

abstract_eval(*args, **params)

bind(*args, **params)

bind_with_trace(trace, args, params)

def_abstract_eval(abstract_eval)

def_bind_with_trace(bind_with_trace)

def_effectful_abstract_eval(...)

def_impl(impl)

get_bind_params(params)

impl(*args, **params)

属性

call_primitive

map_primitive

multiple_results

ref_primitive

name