jax.lax.expand_dims

jax.lax.expand_dims(array, dimensions)

Insert any number of size-1 dimensions into an array.

Parameters:
  array (jax.typing.ArrayLike): the input array.
  dimensions (Sequence[int]): the positions, in the output array, at which the new size-1 dimensions are inserted.

Return type: Array
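The short sketch below shows how the entries of `dimensions` map to output shapes. It assumes a standard JAX installation. The input array `x` is a made-up example. Each index in `dimensions` refers to an axis position in the result, not the input. Negative indices count back from the end of the result.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # example input, shape (2, 3)

# One new axis at output position 0: shape (1, 2, 3)
y = lax.expand_dims(x, (0,))

# Two new axes at output positions 0 and 2: shape (1, 2, 1, 3)
z = lax.expand_dims(x, (0, 2))

# A negative index counts from the end of the output: shape (2, 3, 1)
w = lax.expand_dims(x, (-1,))

print(y.shape, z.shape, w.shape)  # (1, 2, 3) (1, 2, 1, 3) (2, 3, 1)
```

Only the shape changes. The elements stay the same, so reshaping `z` back to `(2, 3)` recovers `x`.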