- jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of `fun` without any FLOPs.
This utility function is useful for performing shape inference. Its input/output behavior is defined by:
```python
import jax

def eval_shape(fun, *args, **kwargs):
  # Reference semantics: run `fun` concretely, then keep only shape/dtype.
  out = fun(*args, **kwargs)
  shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.tree_util.tree_map(shape_dtype_struct, out)
```
But instead of applying `fun` directly, which might be expensive, `eval_shape` uses JAX's abstract interpretation machinery to evaluate the output shapes and dtypes without doing any FLOPs.
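The sketch below illustrates this on a toy function `f` (hypothetical, chosen only for illustration). It first compares `jax.eval_shape` against the reference definition above, renamed `eval_shape_reference` here so it does not shadow the real function, on small concrete inputs. It then calls `jax.eval_shape` on abstract inputs whose concrete counterparts would each need roughly 40 GB of memory; because only shapes and dtypes are traced, no such arrays should be allocated.

```python
import jax
import jax.numpy as jnp


def eval_shape_reference(fun, *args, **kwargs):
  # Hypothetical helper: the reference definition, which really runs `fun`
  # and therefore performs all of its FLOPs.
  out = fun(*args, **kwargs)
  shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.tree_util.tree_map(shape_dtype_struct, out)


def f(a, b):
  # Returns a tuple, so the output is a small pytree.
  return jnp.dot(a, b), jnp.sum(a)


# Small concrete inputs: both approaches report the same shapes/dtypes.
a = jnp.ones((4, 3), jnp.float32)
b = jnp.ones((3, 5), jnp.float32)
print(jax.eval_shape(f, a, b))
print(eval_shape_reference(f, a, b))

# Large abstract inputs: a concrete (100_000, 100_000) float32 array would
# take about 40 GB, but eval_shape only propagates shapes and dtypes.
big = jax.ShapeDtypeStruct((100_000, 100_000), jnp.float32)
prod, total = jax.eval_shape(f, big, big)
print(prod.shape, prod.dtype)    # (100000, 100000) float32
print(total.shape, total.dtype)  # () float32
```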
`eval_shape()` can also catch shape errors: it raises the same shape errors that evaluating `fun(*args, **kwargs)` would raise.
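A minimal sketch of this behavior, assuming a deliberately mismatched `jnp.dot` (the names `f` and `error_type` are hypothetical). The abstract call fails without creating any arrays, and calling `f` on real arrays of the same shapes fails as well. The exact exception class and message may vary across JAX versions, so the sketch only reports the class.

```python
import jax
import jax.numpy as jnp


def f(a, b):
  # Inner dimensions 3 and 4 do not match, so this matmul is ill-typed.
  return jnp.dot(a, b)


def error_type(thunk):
  # Run `thunk` and return the class of the exception it raises, if any.
  try:
    thunk()
  except Exception as e:  # broad on purpose: we only report the class
    return type(e)
  return None


a = jax.ShapeDtypeStruct((2, 3), jnp.float32)
b = jax.ShapeDtypeStruct((4, 5), jnp.float32)

# Abstract check: no arrays are created and no FLOPs are performed.
abstract_err = error_type(lambda: jax.eval_shape(f, a, b))

# Concrete check with real arrays of the same shapes and dtypes.
concrete_err = error_type(
    lambda: f(jnp.ones(a.shape, a.dtype), jnp.ones(b.shape, b.dtype)))

print(abstract_err, concrete_err)
```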
- Parameters:
  - fun (`Callable`) – The function whose output shape should be evaluated.
  - `*args` – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the `shape` and `dtype` attributes are accessed, one can use `jax.ShapeDtypeStruct` or another container that duck-types as ndarrays (note, however, that duck-typed objects cannot be namedtuples, because those are treated as standard Python containers).
  - `**kwargs` – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in `args`, array values need only be duck-typed to have `shape` and `dtype` attributes.
- Returns:
  out: a nested pytree containing `jax.ShapeDtypeStruct` objects as leaves.
- Return type:
  a pytree with the same structure as the output of `fun`.
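The sketch below exercises the argument rules above with hypothetical names: a `typing.NamedTuple` of parameters (treated as a pytree container, so its fields are the leaves), a duck-typed `FakeArray` class that exposes only `shape` and `dtype`, and a keyword argument. The toy `layer` function returns a dict, so the result should be a dict with the same keys whose leaves are `jax.ShapeDtypeStruct` objects.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Params(NamedTuple):
  # A namedtuple is a pytree container, so `w` and `b` are the leaves.
  w: Any
  b: Any


class FakeArray:
  # Hypothetical duck-typed stand-in for an ndarray: it only exposes the
  # `shape` and `dtype` attributes that eval_shape reads.
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = np.dtype(dtype)


def layer(params, x, *, scale):
  # Toy affine layer; returns a dict, so the output is a pytree.
  y = jnp.dot(x, params.w) + params.b
  return {"y": y * scale, "norm": jnp.linalg.norm(y, axis=-1)}


params = Params(
    w=jax.ShapeDtypeStruct((128, 64), jnp.float32),
    b=jax.ShapeDtypeStruct((64,), jnp.float32),
)
x = FakeArray((32, 128), "float32")            # positional, duck-typed
scale = jax.ShapeDtypeStruct((), jnp.float32)   # keyword argument

out = jax.eval_shape(layer, params, x, scale=scale)
print(out["y"].shape, out["norm"].shape)
```

Using a plain class for `FakeArray` (rather than a namedtuple) matters here: per the note above, a namedtuple would be flattened as a container instead of being read as an array.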
For example:

```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
```
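Because `eval_shape` accepts any callable, it should also work on transformed functions. As an illustration (not part of the original example), the sketch below reuses the same `f`, `A`, and `x` to infer the output shape of a `jax.vmap`-batched version of `f` and the shape of a gradient with respect to `A`. The batch size of 8 and the `loss` function are arbitrary, hypothetical choices.

```python
import jax
import jax.numpy as jnp

f = lambda A, x: jnp.tanh(jnp.dot(A, x))
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)

# Batch over a leading axis of `x` only; `A` is shared across the batch.
batched = jax.vmap(f, in_axes=(None, 0))
xb = jax.ShapeDtypeStruct((8, 3000, 1000), jnp.float32)
out_b = jax.eval_shape(batched, A, xb)

# The gradient of a scalar loss with respect to `A` has the shape of `A`.
loss = lambda A, x: jnp.sum(f(A, x))
grad_A = jax.eval_shape(jax.grad(loss), A, x)

print(out_b.shape, grad_A.shape)
```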