jax.eval_shape

jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of `fun` without any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:
```python
import jax

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.tree_util.tree_map(shape_dtype_struct, out)
```
But instead of applying `fun` directly, which might be expensive, it uses JAX's abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

`eval_shape()` can also catch shape errors: it raises the same shape errors that evaluating `fun(*args, **kwargs)` would.
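A minimal sketch of both claims, assuming a JAX installation with default 32-bit settings: it compares `jax.eval_shape` with the reference definition above (renamed `reference_eval_shape` here so it does not shadow the real function) on small arrays, then feeds both routes a pair of operands whose inner dimensions disagree. `f`, `reference_eval_shape`, `shapes_and_dtypes` and `shape_error` are illustrative helpers, and catching `TypeError`/`ValueError` is an assumption about which exception the shape check raises.

```python
import jax
import jax.numpy as jnp


def reference_eval_shape(fun, *args, **kwargs):
    # The reference definition: runs fun for real, then keeps shape/dtype only.
    out = fun(*args, **kwargs)
    shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.tree_util.tree_map(shape_dtype_struct, out)


def shapes_and_dtypes(tree):
    # Reduce each ShapeDtypeStruct leaf to a comparable (shape, dtype) pair.
    return jax.tree_util.tree_map(lambda s: (s.shape, s.dtype), tree)


def f(A, x):
    return jnp.tanh(jnp.dot(A, x))


# 1) Compatible shapes: both routes agree on the output shape/dtype.
A = jnp.ones((4, 3), dtype=jnp.float32)
x = jnp.ones((3, 2), dtype=jnp.float32)
same = shapes_and_dtypes(jax.eval_shape(f, A, x)) == shapes_and_dtypes(
    reference_eval_shape(f, A, x)
)
print(same)  # True

# 2) Incompatible inner dimensions (3 vs 5): both routes raise a shape error.
x_bad = jnp.ones((5, 2), dtype=jnp.float32)


def shape_error(thunk):
    # Return the exception raised by thunk(), or None if it succeeds.
    try:
        thunk()
    except (TypeError, ValueError) as e:  # assumed exception types
        return e
    return None


err_abstract = shape_error(lambda: jax.eval_shape(f, A, x_bad))
err_concrete = shape_error(lambda: reference_eval_shape(f, A, x_bad))
print(type(err_abstract).__name__, type(err_concrete).__name__)
```

The abstract route reports the mismatch while tracing, so with large operands it fails just as early but without ever computing the product.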
Parameters:
- `fun` (`Callable`) – The function whose output shape should be evaluated.
- `*args` – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the `shape` and `dtype` attributes are accessed, one can use `jax.ShapeDtypeStruct` or another container that duck-types as ndarrays (note, however, that duck-typed objects cannot be namedtuples, because those are treated as standard Python containers).
- `**kwargs` – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in `args`, array values need only be duck-typed to have `shape` and `dtype` attributes.
Returns:
`out` – a nested PyTree containing `jax.ShapeDtypeStruct` objects as leaves.
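To illustrate the pytree handling on both sides, here is a small sketch in which `predict`, `params`, `x` and `scale` are hypothetical names with arbitrary shapes. It mixes `jax.ShapeDtypeStruct` leaves with a concrete array, passes one argument by keyword, and then treats the returned dict of `ShapeDtypeStruct` leaves like any other pytree.

```python
import jax
import jax.numpy as jnp


def predict(params, x, *, scale):
    # Hypothetical model: params is a dict pytree, scale is keyword-only.
    h = jnp.dot(x, params["w"]) + params["b"]
    return {"logits": h * scale, "norm": jnp.linalg.norm(h, axis=-1)}


# Inputs may mix ShapeDtypeStruct leaves with concrete arrays.
params = {
    "w": jax.ShapeDtypeStruct((784, 10), jnp.float32),
    "b": jax.ShapeDtypeStruct((10,), jnp.float32),
}
x = jnp.ones((32, 784), dtype=jnp.float32)
scale = jax.ShapeDtypeStruct((), jnp.float32)

out = jax.eval_shape(predict, params, x, scale=scale)
print(out["logits"].shape, out["norm"].shape)  # (32, 10) (32,)

# The result is a dict with ShapeDtypeStruct leaves, so tree utilities apply,
# e.g. allocating zero-filled output buffers of matching shape and dtype.
buffers = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), out)
```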
For example:

```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
```
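The same shape-inference idea applies to initializers whose outputs would be costly to materialize. The sketch below assumes a hypothetical two-layer initializer `init_params` with default float32 weights; under `jax.eval_shape` its random-number and allocation calls are only traced, so the parameter count is obtained from shapes alone.

```python
import math

import jax
import jax.numpy as jnp


def init_params(key):
    # Hypothetical two-layer MLP initializer, for illustration only.
    k1, k2 = jax.random.split(key)
    return {
        "dense1": {"w": jax.random.normal(k1, (784, 512)), "b": jnp.zeros(512)},
        "dense2": {"w": jax.random.normal(k2, (512, 10)), "b": jnp.zeros(10)},
    }


key = jax.random.PRNGKey(0)
# Traced abstractly: no random numbers are drawn and no weights are allocated.
abstract_params = jax.eval_shape(init_params, key)

leaves = jax.tree_util.tree_leaves(abstract_params)
n_params = sum(math.prod(s.shape) for s in leaves)
print(n_params)  # 407050
w1 = abstract_params["dense1"]["w"]
print(w1.shape, w1.dtype)  # (784, 512) float32
```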