jax.numpy.column_stack

jax.numpy.column_stack(tup)

Stack 1-D arrays as columns into a 2-D array.

LAX-backend implementation of numpy.column_stack(). The original NumPy docstring follows.

Take a sequence of 1-D arrays and stack them as columns to make a single 2-D array. 2-D arrays are stacked as-is, just like with hstack. 1-D arrays are turned into 2-D columns first.
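The sketch below illustrates the rule in the paragraph above by mixing one 1-D input with one 2-D input. It assumes `jax` is installed. The manual `reshape(-1, 1)` plus `jnp.hstack` comparison is only an illustration of the described behavior, not necessarily how the library implements it internally.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])            # 1-D, shape (3,)
m = jnp.array([[10, 20],
               [30, 40],
               [50, 60]])           # 2-D, shape (3, 2)

# The 1-D array is turned into a (3, 1) column; the 2-D array is used as-is.
out = jnp.column_stack((a, m))
print(out.shape)     # (3, 3)
print(out.tolist())  # [[1, 10, 20], [2, 30, 40], [3, 50, 60]]

# Equivalent construction for comparison: reshape 1-D inputs to columns, then hstack.
manual = jnp.hstack((a.reshape(-1, 1), m))
print(bool(jnp.array_equal(out, manual)))  # True
```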

Parameters

tup (sequence of 1-D or 2-D arrays) – Arrays to stack. All of them must have the same first dimension.
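As a quick check of the "same first dimension" requirement, this sketch passes two 1-D arrays of different lengths and expects an error. The exact exception type and message are not specified here and may differ between JAX versions, so the sketch catches both `TypeError` and `ValueError` as an assumption.

```python
import jax.numpy as jnp

a = jnp.arange(3)   # shape (3,)
b = jnp.arange(2)   # shape (2,): first dimension differs from a

try:
    jnp.column_stack((a, b))
    raised = False
except (TypeError, ValueError) as err:
    # Exception type and wording are version-dependent (assumption).
    raised = True
    print(type(err).__name__)

print(raised)  # True
```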

Returns

stacked – The array formed by stacking the given arrays.

Return type

2-D array

Examples

>>> import numpy as np
>>> a = np.array((1, 2, 3))
>>> b = np.array((2, 3, 4))
>>> np.column_stack((a, b))
array([[1, 2],
       [2, 3],
       [3, 4]])
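The example above is NumPy's original doctest. A minimal sketch of the same call through `jax.numpy` follows, assuming both `jax` and `numpy` are installed. It compares the result with NumPy's and shows that the function can be wrapped in `jax.jit`; the `stack_cols` name is purely illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.array((1, 2, 3))
b = jnp.array((2, 3, 4))

out = jnp.column_stack((a, b))
print(out.tolist())  # [[1, 2], [2, 3], [3, 4]]

# Same values as NumPy's column_stack on equivalent inputs.
ref = np.column_stack((np.array((1, 2, 3)), np.array((2, 3, 4))))
print(np.array_equal(np.asarray(out), ref))  # True

# column_stack can be traced and compiled with jax.jit like other jax.numpy functions.
stack_cols = jax.jit(lambda x, y: jnp.column_stack((x, y)))
print(stack_cols(a, b).shape)  # (3, 2)
```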