jax.eval_shape

jax.eval_shape(fun, *args, **kwargs)

Compute the shape and dtype of the output of fun without performing any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct:
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.

But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using eval_shape() can also catch shape errors: it raises the same shape errors that evaluating fun(*args, **kwargs) would raise.
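As a minimal sketch of the error-catching behavior, the snippet below passes abstract inputs whose inner dimensions do not match. The function f and the names A and x_bad are illustrative only, and the exact exception type is an assumption (a TypeError is expected from the dot-product shape check, but ValueError is also caught to be safe).

```python
import jax
import jax.numpy as jnp

def f(A, x):
    return jnp.tanh(jnp.dot(A, x))

# Abstract stand-ins for arrays: only shape and dtype are described,
# so no memory is allocated for the array contents.
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x_bad = jax.ShapeDtypeStruct((4000, 1000), jnp.float32)  # 4000 != 3000

try:
    jax.eval_shape(f, A, x_bad)
    raised = False
except (TypeError, ValueError) as err:
    # The mismatch is reported during abstract evaluation,
    # before any matrix product is actually computed.
    raised = True
    print(f"{type(err).__name__}: shape mismatch detected")
```

Because the check happens during abstract evaluation, this can serve as a cheap sanity test before running an expensive computation.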

Parameters
  • fun (Callable) – The function whose output shape should be evaluated.

  • *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.

  • **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = jnp.dtype(dtype)
...
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
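The same idea extends to pytree inputs and keyword arguments. The sketch below uses jax.ShapeDtypeStruct objects in place of a hand-written duck-typed class; the function predict, the parameter names "w" and "b", and the keyword argument scale are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs, *, scale):
    # A small dense layer followed by a scaled reduction over the last axis.
    hidden = jnp.tanh(jnp.dot(inputs, params["w"]) + params["b"])
    return {"hidden": hidden, "scaled_sum": scale * hidden.sum(axis=-1)}

# A dict (pytree) of abstract parameters; nothing is allocated for the data.
params = {
    "w": jax.ShapeDtypeStruct((3000, 1000), jnp.float32),
    "b": jax.ShapeDtypeStruct((1000,), jnp.float32),
}
inputs = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
scale = jax.ShapeDtypeStruct((), jnp.float32)  # abstract scalar

# The result mirrors the pytree structure returned by predict.
out = jax.eval_shape(predict, params, inputs, scale=scale)
print(out["hidden"].shape)      # (2000, 1000)
print(out["scaled_sum"].shape)  # (2000,)
```

The output is a dict with the same keys that predict returns, where each leaf carries only shape and dtype information.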