jax.numpy.fromfunction
- jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)
Construct an array by executing a function over each coordinate.
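A minimal usage sketch of this behavior: each element of the output is the function's value at that element's coordinate. The lambda `10 * i + j` and the shape `(2, 3)` are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# Element (i, j) of the result is 10 * i + j, so row index and column
# index can be read directly off each value.
a = jnp.fromfunction(lambda i, j: 10 * i + j, (2, 3), dtype=jnp.int32)

print(a.shape)     # (2, 3)
print(a.tolist())  # [[0, 1, 2], [10, 11, 12]]
```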
LAX-backend implementation of numpy.fromfunction(). Original docstring below.
The resulting array therefore has a value fn(x, y, z) at coordinate (x, y, z).
- Parameters:
function (callable) – The function is called with N parameters, where N is the rank of shape. Each parameter represents the coordinates of the array varying along a specific axis. For example, if shape were (2, 2), then the parameters would be array([[0, 0], [1, 1]]) and array([[0, 1], [0, 1]]).
shape ((N,) tuple of ints) – Shape of the output array, which also determines the shape of the coordinate arrays passed to function.
dtype (data-type, optional) – Data-type of the coordinate arrays passed to function. By default, dtype is float.
- Returns:
fromfunction – The result of the call to function is passed back directly. Therefore the shape of fromfunction is completely determined by function. If function returns a scalar value, the shape of fromfunction would not match the shape parameter.
- Return type:
any
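The sketch below exercises the documented parameters and return behavior; the shapes and lambdas are arbitrary illustrative choices. Returning the coordinate arguments as a tuple gives back the two coordinate grids listed for shape (2, 2), the default float dtype makes arithmetic on the coordinates floating point, and a comparison shows that the output dtype comes from function rather than from dtype.

```python
import jax.numpy as jnp

# The function's return value is passed back as-is: returning the two
# coordinate arguments yields a tuple of the two coordinate grids.
rows, cols = jnp.fromfunction(lambda i, j: (i, j), (2, 2), dtype=jnp.int32)
print(rows.tolist())  # [[0, 0], [1, 1]]
print(cols.tolist())  # [[0, 1], [0, 1]]

# With the default dtype (float), the coordinates are floating point,
# so the product i * j is a floating-point array.
prod = jnp.fromfunction(lambda i, j: i * j, (2, 2))
print(prod.dtype)     # a floating-point dtype (float32 unless x64 is enabled)

# The function alone decides the output dtype: even with integer
# coordinates, a comparison produces a boolean identity pattern.
eye = jnp.fromfunction(lambda i, j: i == j, (3, 3), dtype=jnp.int32)
print(eye.dtype)      # bool
```

Passing an integer dtype such as jnp.int32 keeps the coordinate grids integral; with the default float they would hold 0.0 and 1.0 instead.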