jax.numpy.fromfunction

jax.numpy.fromfunction(function, shape, *, dtype=float, **kwargs)

Construct an array by executing a function over each coordinate.

LAX-backend implementation of numpy.fromfunction().

Original docstring below.

The resulting array has the value function(x, y, z) at coordinate (x, y, z).
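A minimal sketch of the coordinate-to-value mapping, assuming jax.numpy is imported as jnp. The lambda is an arbitrary illustration. With the default dtype, the coordinates, and therefore the values, are floating point.

```python
import jax.numpy as jnp

# Each entry (i, j) is computed from its own coordinates; here the value is
# the row-major flat index i * 3 + j of a 2x3 grid (illustrative choice).
table = jnp.fromfunction(lambda i, j: i * 3 + j, (2, 3))

print(table.shape)  # (2, 3)
print(table)
```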

Parameters:
  • function (callable) – The function is called with N parameters, where N is the rank of shape. Each parameter represents the coordinates of the array varying along a specific axis. For example, if shape were (2, 2), then the parameters would be array([[0, 0], [1, 1]]) and array([[0, 1], [0, 1]]).

  • shape ((N,) tuple of ints) – Shape of the output array, which also determines the shape of the coordinate arrays passed to function.

  • dtype (data-type, optional) – Data-type of the coordinate arrays passed to function. By default, dtype is float.
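A sketch of how dtype interacts with the return value. The comparison function is illustrative only: dtype sets the type of the coordinates passed in, while the type of the result is whatever function returns, here booleans from a comparison.

```python
import jax.numpy as jnp

# Integer coordinates (dtype=jnp.int32) keep the comparison exact.
mask = jnp.fromfunction(lambda i, j: i <= j, (3, 3), dtype=jnp.int32)

# The result dtype follows the function's output (a comparison yields
# booleans), not the dtype argument.
print(mask.dtype)  # bool
print(mask)        # upper-triangular mask, diagonal included
```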

Returns:

fromfunction – The result of the call to function is passed back directly, so the shape of fromfunction is determined entirely by function. In NumPy, if function returns a scalar value, the shape of fromfunction will not match the shape parameter.

Return type:

any
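For a function built only from elementwise operations, the result can be cross-checked against calling that function directly on coordinate grids from jnp.indices, which follow the layout described for the function parameter. radial is a hypothetical helper used only for this sketch. The restriction to elementwise code is deliberate: JAX may evaluate function through jax.vmap, in which case function receives scalar coordinates rather than whole grids, and code that inspects shapes or stacks its inputs could then behave differently than under NumPy.

```python
import jax.numpy as jnp

def radial(i, j):
    # Hypothetical helper: distance from the point (2, 2).
    # Elementwise only, so it works on scalar coordinates or full grids.
    return jnp.sqrt((i - 2.0) ** 2 + (j - 2.0) ** 2)

shape = (4, 5)
via_fromfunction = jnp.fromfunction(radial, shape, dtype=jnp.int32)

# Integer coordinate grids: rows varies along axis 0, cols along axis 1.
rows, cols = jnp.indices(shape)
direct = radial(rows, cols)

print(via_fromfunction.shape)                   # (4, 5)
print(jnp.allclose(via_fromfunction, direct))   # True
```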