jax.numpy.unpackbits

jax.numpy.unpackbits(a, axis=None, count=None, bitorder='big')

Unpacks elements of a uint8 array into a binary-valued output array.

LAX-backend implementation of numpy.unpackbits().

Original docstring below.

Each element of a represents a bit-field that should be unpacked into a binary-valued output array. The shape of the output array is either 1-D (if axis is None) or the same shape as the input array with unpacking done along the axis specified.
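As a minimal sketch of how axis affects the output shape, the snippet below unpacks a small (3, 1) uint8 array three ways. The array values and variable names are illustrative assumptions, not part of the original docstring.

```python
import jax.numpy as jnp

# Illustrative input: three bytes arranged as a (3, 1) column.
a = jnp.array([[2], [7], [23]], dtype=jnp.uint8)

# axis=None (the default): the input is flattened, so the result is 1-D.
flat = jnp.unpackbits(a)

# axis=1: each byte expands into 8 bits along the second dimension.
along_cols = jnp.unpackbits(a, axis=1)

# axis=0: the bits are laid out along the first dimension instead.
along_rows = jnp.unpackbits(a, axis=0)

print(flat.shape)        # (24,)
print(along_cols.shape)  # (3, 8)
print(along_rows.shape)  # (24, 1)
print(along_cols)
# [[0 0 0 0 0 0 1 0]
#  [0 0 0 0 0 1 1 1]
#  [0 0 0 1 0 1 1 1]]
```

In each case the unpacked axis grows by a factor of eight, and the other dimensions are unchanged.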

Parameters
  • a (ndarray, uint8 type) – Input array.

  • axis (int, optional) – The dimension over which bit-unpacking is done. None implies unpacking the flattened array.

  • count (int or None, optional) – The number of elements to unpack along axis, provided as a way of undoing the effect of packing a size that is not a multiple of eight. A non-negative number means to only unpack count bits. A negative number means to trim off that many bits from the end. None means to unpack the entire array (the default). Counts larger than the available number of bits will add zero padding to the output. Negative counts must not exceed the available number of bits.

  • bitorder ({'big', 'little'}, optional) – The order of the returned bits. ‘big’ mimics bin(val): 3 = 0b00000011 => [0, 0, 0, 0, 0, 0, 1, 1]. ‘little’ reverses the order to [1, 1, 0, 0, 0, 0, 0, 0]. Defaults to ‘big’.
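The bitorder and count parameters described above can be illustrated with a short sketch. It assumes jnp.packbits is available to produce the packed bytes. The 10-bit pattern and the variable names are made up for illustration.

```python
import jax.numpy as jnp

# bitorder: the same byte unpacked in both orders.
x = jnp.array([3], dtype=jnp.uint8)
big = jnp.unpackbits(x)                        # most significant bit first
little = jnp.unpackbits(x, bitorder='little')  # least significant bit first
print(big)     # [0 0 0 0 0 0 1 1]
print(little)  # [1 1 0 0 0 0 0 0]

# count: undo the zero padding added when packing a length that is
# not a multiple of eight. Ten bits pack into two bytes (16 bits).
bits = jnp.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=jnp.uint8)
packed = jnp.packbits(bits)
print(packed)  # [178 192]

# A non-negative count keeps only the first `count` bits.
restored = jnp.unpackbits(packed, count=10)

# A negative count trims that many bits from the end (16 - 6 = 10).
trimmed = jnp.unpackbits(packed, count=-6)

print(restored)  # [1 0 1 1 0 0 1 0 1 1]
print(trimmed)   # [1 0 1 1 0 0 1 0 1 1]
```

Both calls recover the original ten bits. Which form is more convenient depends on whether the original length or the amount of padding is known.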

Returns

unpacked – The elements are binary-valued (0 or 1).

Return type

ndarray, uint8 type