jax.numpy.array
- jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0)
Create an array.
LAX-backend implementation of numpy.array().
This function will create arrays on JAX's default device. For control of the device placement of data, see jax.device_put(). More information is available in the JAX FAQ section "Controlling data and computation placement on devices" (full FAQ at https://jax.readthedocs.io/en/latest/faq.html).
Original docstring below.
- Parameters
object (array_like) – An array, any object exposing the array interface, an object whose __array__ method returns an array, or any (nested) sequence. If object is a scalar, a 0-dimensional array containing object is returned.
dtype (data-type, optional) – The desired data-type for the array. If not given, NumPy will try to use a default dtype that can represent the values (by applying promotion rules when necessary).
copy (bool, optional) – If True (default), then the object is copied. Otherwise, a copy will only be made if __array__ returns a copy, if object is a nested sequence, or if a copy is needed to satisfy any of the other requirements (dtype, order, etc.).
order ({'K', 'A', 'C', 'F'}, optional) –
Specify the memory layout of the array. If object is not an array, the newly created array will be in C order (row major) unless 'F' is specified, in which case it will be in Fortran order (column major). If object is an array, the following holds:

| order | no copy   | copy=True                                           |
|-------|-----------|-----------------------------------------------------|
| 'K'   | unchanged | F & C order preserved, otherwise most similar order |
| 'A'   | unchanged | F order if input is F and not C, otherwise C order  |
| 'C'   | C order   | C order                                             |
| 'F'   | F order   | F order                                             |

When copy=False and a copy is made for other reasons, the result is the same as if copy=True, with some exceptions for 'A'; see the Notes section of the numpy.array documentation. The default order is 'K'.
ndmin (int, optional) – Specifies the minimum number of dimensions that the resulting array should have. Ones will be prepended to the shape as needed to meet this requirement.
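The parameter descriptions above can be exercised with a short sketch. It assumes JAX's default 32-bit mode, so inferred integer and floating dtypes are 32-bit unless x64 is enabled. It uses only the arguments shown in the signature. The `order` argument is left at its default `'K'`:

```python
import numpy as np
import jax.numpy as jnp

# A nested sequence: the dtype is inferred from the values.
ints = jnp.array([[1, 2], [3, 4]])   # integer dtype, shape (2, 2)
mixed = jnp.array([1, 2.5])          # int and float promote to a floating dtype

# An object exposing the array interface, e.g. a NumPy array.
from_numpy = jnp.array(np.arange(4))

# An explicit dtype overrides the inferred one.
as_float = jnp.array([1, 2, 3], dtype=jnp.float32)

# A scalar yields a 0-dimensional array.
scalar = jnp.array(7)

# ndmin prepends ones to the shape until the requested rank is reached.
padded = jnp.array([1, 2, 3], ndmin=3)

# copy=False is accepted. JAX arrays are immutable, so the values are the
# same whether or not a new buffer is allocated.
no_copy = jnp.array(ints, copy=False)

print(ints.dtype, mixed.dtype, as_float.dtype)
print(scalar.shape, padded.shape)
```

With x64 enabled, the inferred dtypes of `ints` and `mixed` widen to 64-bit. The shapes are unaffected.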
- Returns
out – An array object satisfying the specified requirements.
- Return type
ndarray
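The return type listed above comes from the NumPy docstring. As a hedged sketch, assuming a JAX release with the unified `jax.Array` type (0.4 or later), the object actually returned is a JAX array rather than a NumPy `ndarray`. It can be converted back to a host NumPy array with `np.asarray`:

```python
import numpy as np
import jax
import jax.numpy as jnp

out = jnp.array([[1, 2, 3]])

# The result is a JAX array, not a NumPy ndarray.
print(type(out))
print(isinstance(out, jax.Array))   # True
print(isinstance(out, np.ndarray))  # False

# np.asarray copies or views the data as a host NumPy array when needed.
host = np.asarray(out)
print(type(host), host.shape)
```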