jax.ShapeDtypeStruct


class jax.ShapeDtypeStruct(shape, dtype, named_shape=None, sharding=None)

A container for the shape, dtype, and other static attributes of an array.
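Because a ShapeDtypeStruct holds only metadata, it can stand in for an array whose data is not needed. As a minimal sketch (the `params` dict and the `to_spec` helper are hypothetical, for illustration only), the static attributes of every array in a pytree can be captured like this:

```python
import jax
import jax.numpy as jnp


def to_spec(a):
    """Hypothetical helper: keep only the shape and dtype of array `a`."""
    return jax.ShapeDtypeStruct(a.shape, a.dtype)


# A small, hypothetical pytree of parameters.
params = {"w": jnp.zeros((3, 4), jnp.float32), "b": jnp.zeros((4,), jnp.int32)}

# ShapeDtypeStruct is a pytree leaf, so tree_map yields one spec per array.
specs = jax.tree_util.tree_map(to_spec, params)

print(specs["w"].shape, specs["w"].dtype)  # (3, 4) float32
print(specs["b"].shape, specs["b"].dtype)  # (4,) int32
```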

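ShapeDtypeStruct is often used in conjunction with jax.eval_shape(), which accepts ShapeDtypeStruct objects in place of concrete arrays as inputs and returns ShapeDtypeStruct objects describing a function's outputs.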
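A minimal sketch of that pairing, assuming a toy function `f` (hypothetical): jax.eval_shape() traces `f` abstractly, so the output shape and dtype are computed without allocating input arrays or running the matrix multiply.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # Hypothetical toy layer: (8, 128) @ (128, 16) -> (8, 16)
    return jnp.tanh(x @ w)


# Describe the inputs by shape and dtype only; no data is created.
x_spec = jax.ShapeDtypeStruct((8, 128), jnp.float32)
w_spec = jax.ShapeDtypeStruct((128, 16), jnp.float32)

# eval_shape returns a ShapeDtypeStruct (or a pytree of them) for the output.
out = jax.eval_shape(f, x_spec, w_spec)
print(out.shape, out.dtype)  # (8, 16) float32
```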

Parameters:
  • shape – a sequence of integers representing an array shape

  • dtype – a dtype-like object

  • named_shape – (optional) a dictionary representing a named shape

  • sharding – (optional) a jax.Sharding object
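The sketch below exercises the shape, dtype, and sharding parameters (named_shape is left out). It assumes at least one local device is available for the `SingleDeviceSharding` from `jax.sharding`; other sharding objects should be passed the same way. Note that the shape sequence is stored as a tuple and the dtype-like argument is normalized, so different spellings of the same dtype compare equal.

```python
import jax
import jax.numpy as jnp
import numpy as np

# shape: any sequence of ints; dtype: any dtype-like (string, NumPy type, jnp type)
spec_a = jax.ShapeDtypeStruct([2, 3], "float32")
spec_b = jax.ShapeDtypeStruct((2, 3), np.float32)
spec_c = jax.ShapeDtypeStruct((2, 3), jnp.float32)

print(spec_a.shape, spec_a.dtype)  # (2, 3) float32
print(spec_a == spec_b == spec_c)  # True

# sharding: optional; here, place the (abstract) array on the first device
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
spec_d = jax.ShapeDtypeStruct((2, 3), jnp.float32, sharding=sharding)
print(spec_d.sharding == sharding)  # True
```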

__init__(shape, dtype, named_shape=None, sharding=None)

Methods

__init__(shape, dtype[, named_shape, sharding])

Attributes

shape

dtype

named_shape

sharding

layout

ndim

size
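A short illustration of the shape-related attributes, assuming a rank-3 float16 spec chosen only as an example: ndim is the number of dimensions (the length of shape) and size is the total element count (the product of the dimensions).

```python
import jax
import jax.numpy as jnp

spec = jax.ShapeDtypeStruct((2, 3, 4), jnp.float16)

print(spec.shape)  # (2, 3, 4)
print(spec.dtype)  # float16
print(spec.ndim)   # 3, i.e. len(shape)
print(spec.size)   # 24, i.e. 2 * 3 * 4
```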