jax.hessian

Contents

jax.hessian#

jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

Hessian of fun as a dense array.

Parameters:
  • fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type:

Callable

Returns:

A function with the same arguments as fun, that evaluates the Hessian of fun.

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian() is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.flatten_pytree().