jax.lax.reshape

jax.lax.reshape(operand, new_sizes, dimensions=None)

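Wraps XLA’s Reshape operator. The result contains the elements of operand, read in row-major (C) order after the optional dimensions permutation, arranged into the shape new_sizes. The total number of elements must be unchanged.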
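A minimal sketch of a plain reshape, assuming a working JAX installation; the variable names x, y and z are illustrative only. It turns a length-6 vector into a 2×3 matrix and back, keeping the elements in row-major order.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6)            # shape (6,): [0, 1, 2, 3, 4, 5]

# Same six elements, laid out row-major in a (2, 3) array.
y = lax.reshape(x, (2, 3))   # [[0, 1, 2], [3, 4, 5]]

# Reshaping back to one dimension recovers the original order.
z = lax.reshape(y, (6,))

print(y.shape, z.shape)
```

Because only the element count must be preserved, any target shape whose sizes multiply to 6, such as (3, 2) or (1, 6), would also be accepted here.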

For inserting/removing dimensions of size 1, prefer using lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
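A small illustration of that recommendation, assuming the documented signatures lax.expand_dims(array, dimensions) and lax.squeeze(array, dimensions). Both routes below produce the same shapes, but only expand_dims / squeeze state which axis is being inserted or removed; reshape records just the target shape.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 4))

# Preferred: name the size-1 axis explicitly.
a = lax.expand_dims(x, (0,))      # (3, 4) -> (1, 3, 4), new axis at position 0
b = lax.squeeze(a, (0,))          # (1, 3, 4) -> (3, 4), axis 0 removed

# Same shapes via reshape, but the axis identity is not recorded.
a2 = lax.reshape(x, (1, 3, 4))
b2 = lax.reshape(a2, (3, 4))

print(a.shape, b.shape, a2.shape, b2.shape)
```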

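Parameters
operand: array to be reshaped.
new_sizes: sequence of integers giving the shape of the result. The product of new_sizes must equal the number of elements in operand.
dimensions: optional sequence of integers giving a permutation of the axes of operand, applied (as a transpose) before the reshape. If specified, its length must equal the number of dimensions of operand. Defaults to None, meaning no permutation.
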
Return type

Any
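A sketch of the dimensions argument, assuming it behaves as a transpose of operand's axes followed by an ordinary row-major reshape; the names y, flat_t and expected are illustrative. The second computation spells out that assumed two-step equivalent with lax.transpose.

```python
import jax.numpy as jnp
from jax import lax

y = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# Permute axes (1, 0) first, then flatten in row-major order.
flat_t = lax.reshape(y, (6,), dimensions=(1, 0))   # [0, 3, 1, 4, 2, 5]

# Explicit equivalent: transpose, then reshape without a permutation.
expected = lax.reshape(lax.transpose(y, (1, 0)), (6,))

print(flat_t, expected)
```

Passing dimensions=(0, 1), the identity permutation, would give the same result as omitting the argument.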