jax.nn.one_hot

jax.nn.one_hot(x, num_classes, *, dtype=<class 'jax._src.numpy.lax_numpy.float64'>, axis=-1)

One-hot encodes the given indices.

Each index in the input x is encoded as a vector of zeros of length num_classes, with the element at that index set to one:

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)
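The encoding rule above can be reproduced with a broadcast comparison of each index against jnp.arange(num_classes). The helper one_hot_by_comparison below is hypothetical and exists only to illustrate the rule; it is a sketch, not necessarily how jax.nn.one_hot is implemented internally.

```python
import jax
import jax.numpy as jnp

def one_hot_by_comparison(x, num_classes, dtype=jnp.float32):
    # Hypothetical helper (not part of JAX): broadcast each index against
    # [0, 1, ..., num_classes - 1]; the matching position becomes 1, all others 0.
    x = jnp.asarray(x)
    classes = jnp.arange(num_classes)
    return (x[..., None] == classes).astype(dtype)

labels = jnp.array([0, 1, 2])
expected = jax.nn.one_hot(labels, 3)
sketch = one_hot_by_comparison(labels, 3)
print(jnp.array_equal(expected, sketch))  # True
```

Because the comparison never matches an index outside [0, num_classes), the same sketch also yields all-zero rows for out-of-range indices.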

Indices outside the range [0, num_classes) are encoded as all-zero vectors:

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
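Because out-of-range indices produce all-zero rows, the sum of each encoded row can act as a validity mask. Note that, unlike Python indexing, -1 is not wrapped around to the last class. The sketch below assumes -1 marks padding; that is a convention chosen for this example, not something jax.nn.one_hot requires.

```python
import jax
import jax.numpy as jnp

# Labels where -1 marks padding (an assumption for this example only).
labels = jnp.array([2, 0, -1, 1, -1])
encoded = jax.nn.one_hot(labels, 3)

# Valid labels give a row summing to 1.0; out-of-range labels give 0.0.
mask = encoded.sum(axis=-1)
print(mask)  # [1. 1. 0. 1. 0.]
```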

Parameters
  • x (Any) – A tensor of indices.

  • num_classes (int) – Number of classes in the one-hot dimension.

  • dtype (Any) – Optional; a floating-point dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • axis (Union[int, Hashable]) – The axis of the output at which the one-hot dimension of size num_classes is placed (default -1, the last axis).
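A short illustration of how axis and dtype change the result: the class dimension of size num_classes is placed at position axis of the output, and dtype sets its element type. The input values and the choice of float16 are arbitrary for this sketch.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0, 2], [1, 1]])  # shape (2, 2)

# Default axis=-1 appends the class dimension at the end: shape (2, 2, 3).
last = jax.nn.one_hot(x, 3)

# axis=0 places the class dimension first instead: shape (3, 2, 2).
first = jax.nn.one_hot(x, 3, axis=0)

# dtype selects the floating-point type of the output.
halves = jax.nn.one_hot(x, 3, dtype=jnp.float16)

print(last.shape, first.shape, halves.dtype)  # (2, 2, 3) (3, 2, 2) float16
```

Moving the leading axis of first to the end with jnp.moveaxis recovers last, so the choice of axis only changes the layout, not the encoded values.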

Return type

Any