jax.eval_shape
jax.eval_shape(fun, *args, **kwargs)
Compute the shape and dtype of the output of `fun` without performing any FLOPs. This utility function is useful for performing shape inference. Its input/output behavior is defined by:
```python
import jax

# Reference semantics only: the real eval_shape never runs fun on concrete data.
def eval_shape(fun, *args, **kwargs):
    out = fun(*args, **kwargs)
    return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
    return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct:
    __slots__ = ["shape", "dtype"]

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
```
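The definition above can be checked against the real function. The sketch below uses a hypothetical function `f` that returns a dict pytree (chosen only for illustration) and compares the result of `jax.eval_shape` with the reference semantics built from `jax.ShapeDtypeStruct`. Note that the reference path really does run `f` on concrete data, so unlike `eval_shape` it performs the FLOPs.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function whose output is a pytree:
    # a dict holding one array and a tuple of two arrays.
    return {"sum": x.sum(axis=0), "pair": (x.T, x[:2] * 2.0)}

x = jnp.ones((4, 3), dtype=jnp.float32)

# Abstract evaluation: traces f to obtain output shapes/dtypes only.
abstract_out = jax.eval_shape(f, x)

# Reference semantics: actually run f, then keep only shape/dtype per leaf.
reference = jax.tree_util.tree_map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), f(x)
)

# Compare the tree structure and the (shape, dtype) of every leaf.
same_structure = (
    jax.tree_util.tree_structure(abstract_out)
    == jax.tree_util.tree_structure(reference)
)
same_leaves = all(
    (a.shape, a.dtype) == (r.shape, r.dtype)
    for a, r in zip(
        jax.tree_util.tree_leaves(abstract_out),
        jax.tree_util.tree_leaves(reference),
    )
)
print(same_structure, same_leaves)
print(abstract_out["pair"][0].shape)  # (3, 4)
```

Only `shape` and `dtype` are compared, since (as stated below) those are the only attributes of the output objects that the API guarantees.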
In particular, the output is a pytree of objects that have `shape` and `dtype` attributes, but nothing else about them is guaranteed by the API.

However, instead of applying `fun` directly, which might be expensive, `eval_shape` uses JAX's abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using `eval_shape()` can also catch shape errors: it raises the same shape errors that evaluating `fun(*args, **kwargs)` would raise.

Parameters:
- `fun` (Callable) – The function whose output shape should be evaluated.
- `*args` – A positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the `shape` and `dtype` attributes are accessed, the values need only duck-type arrays; they do not have to be real ndarrays. The duck-typed objects cannot be namedtuples, because namedtuples are treated as standard Python containers. See the example below.
- `**kwargs` – A keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As with `args`, array values need only be duck-typed to have `shape` and `dtype` attributes.
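The arguments described above may be arbitrary pytrees, may be passed by keyword, and only need `shape` and `dtype` attributes. A minimal sketch, assuming a hypothetical `dense_layer` function; it uses JAX's own `jax.ShapeDtypeStruct` as the duck-typed array stand-in, so no parameter data is ever allocated.

```python
import jax
import jax.numpy as jnp

def dense_layer(params, x, scale):
    # Hypothetical layer: params is a nested dict pytree of weights.
    h = jnp.dot(x, params["w"]) + params["b"]
    return jax.nn.relu(h) * scale

# Abstract stand-ins: only .shape and .dtype are needed, no data is allocated.
params = {
    "w": jax.ShapeDtypeStruct((128, 64), jnp.float32),
    "b": jax.ShapeDtypeStruct((64,), jnp.float32),
}
x = jax.ShapeDtypeStruct((32, 128), jnp.float32)
scale = jax.ShapeDtypeStruct((), jnp.float32)

# Positional pytree arguments plus one keyword argument.
out = jax.eval_shape(dense_layer, params, x, scale=scale)
print(out.shape, out.dtype)  # (32, 64) float32
```

This pattern is handy for sizing layers or buffers before committing memory to them.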
For example:
```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = jnp.dtype(dtype)
...
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
```
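Because `eval_shape` raises the same shape errors as evaluating `fun(*args, **kwargs)`, it can act as a cheap shape check. A minimal sketch, assuming a hypothetical helper `error_type` and deliberately mismatched inner dimensions for `jnp.dot`; it only checks that both paths fail with the same exception type, without assuming which type JAX uses.

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return jnp.dot(a, b)

def error_type(thunk):
    # Hypothetical helper: run thunk and return the type of any exception raised.
    try:
        thunk()
    except Exception as err:
        return type(err)
    return None

a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((4, 5), dtype=jnp.float32)  # inner dimensions 3 and 4 do not match

abstract_error = error_type(lambda: jax.eval_shape(matmul, a, b))
concrete_error = error_type(lambda: matmul(a, b))
print(abstract_error is not None, abstract_error is concrete_error)
```

With large inputs, the abstract check reports the mismatch before any expensive computation is attempted.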