4510-2
Tutorials
JAX Quickstart
The Autodiff Cookbook
Autobatching log-densities example
Training a Simple Neural Network, with Tensorflow Datasets Data Loading
Advanced JAX Tutorials
🔪 JAX - The Sharp Bits 🔪
Custom derivative rules for JAX-transformable Python functions
How JAX primitives work
Writing custom Jaxpr interpreters in JAX
Notes
Change Log
JAX Frequently Asked Questions (FAQ)
Understanding Jaxprs
Asynchronous dispatch
Concurrency
GPU memory allocation
Profiling JAX programs
Device Memory Profiling
Pytrees
Rank promotion warning
Type promotion semantics
Developer documentation
Building from source
Running the tests
Type checking
Update documentation
Internal APIs
API documentation
Public API: jax package
JAX
Docs
»
Index
Edit on GitHub
Index
_ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Z
_
__init__() (jax.core.ClosedJaxpr method)
(jax.core.Jaxpr method)
(jax.numpy.bool_ method)
(jax.numpy.character method)
(jax.numpy.complex128 method)
(jax.numpy.complex64 method)
(jax.numpy.complexfloating method)
(jax.numpy.dtype method)
(jax.numpy.flexible method)
(jax.numpy.float16 method)
(jax.numpy.float32 method)
(jax.numpy.float64 method)
(jax.numpy.floating method)
(jax.numpy.iinfo method)
(jax.numpy.inexact method)
(jax.numpy.int16 method)
(jax.numpy.int32 method)
(jax.numpy.int64 method)
(jax.numpy.int8 method)
(jax.numpy.integer method)
(jax.numpy.ndarray method)
(jax.numpy.number method)
(jax.numpy.object_ method)
(jax.numpy.signedinteger method)
(jax.numpy.uint16 method)
(jax.numpy.uint32 method)
(jax.numpy.uint64 method)
(jax.numpy.uint8 method)
(jax.numpy.unsignedinteger method)
(jax.profiler.TraceContext method)
A
abs() (in module jax.lax)
(in module jax.numpy)
absolute() (in module jax.numpy)
acos() (in module jax.lax)
adagrad() (in module jax.experimental.optimizers)
adam() (in module jax.experimental.optimizers)
(in module jax.experimental.optix)
adamax() (in module jax.experimental.optimizers)
add() (in module jax.lax)
(in module jax.numpy)
add_noise() (in module jax.experimental.optix)
AddNoiseState (class in jax.experimental.optix)
all() (in module jax.numpy)
all_gather() (in module jax.lax)
all_leaves() (in module jax.tree_util)
all_to_all() (in module jax.lax)
allclose() (in module jax.numpy)
alltrue() (in module jax.numpy)
amax() (in module jax.numpy)
amin() (in module jax.numpy)
angle() (in module jax.numpy)
any() (in module jax.numpy)
append() (in module jax.numpy)
apply_along_axis() (in module jax.numpy)
apply_every() (in module jax.experimental.optix)
apply_over_axes() (in module jax.numpy)
apply_round() (in module jax.random)
apply_updates() (in module jax.experimental.optix)
ApplyEvery (class in jax.experimental.optix)
arange() (in module jax.numpy)
arccos() (in module jax.numpy)
arccosh() (in module jax.numpy)
arcsin() (in module jax.numpy)
arcsinh() (in module jax.numpy)
arctan() (in module jax.numpy)
arctan2() (in module jax.numpy)
arctanh() (in module jax.numpy)
argmax() (in module jax.lax)
(in module jax.numpy)
argmin() (in module jax.lax)
(in module jax.numpy)
argsort() (in module jax.numpy)
argwhere() (in module jax.numpy)
around() (in module jax.numpy)
array() (in module jax.numpy)
array_equal() (in module jax.numpy)
array_equiv() (in module jax.numpy)
array_repr() (in module jax.numpy)
array_split() (in module jax.numpy)
array_str() (in module jax.numpy)
asarray() (in module jax.numpy)
asin() (in module jax.lax)
associative_scan() (in module jax.lax)
atan() (in module jax.lax)
atan2() (in module jax.lax)
atleast_1d() (in module jax.numpy)
atleast_2d() (in module jax.numpy)
atleast_3d() (in module jax.numpy)
average() (in module jax.numpy)
AvgPool() (in module jax.experimental.stax)
axis_index() (in module jax.lax)
B
bartlett() (in module jax.numpy)
batch_matmul() (in module jax.lax)
BatchNorm() (in module jax.experimental.stax)
bernoulli() (in module jax.random)
bessel_i0e() (in module jax.lax)
bessel_i1e() (in module jax.lax)
beta() (in module jax.random)
betainc() (in module jax.lax)
(in module jax.scipy.special)
bincount() (in module jax.numpy)
bitcast_convert_type() (in module jax.lax)
bits (jax.numpy.iinfo attribute)
bitwise_and() (in module jax.lax)
(in module jax.numpy)
bitwise_not() (in module jax.lax)
(in module jax.numpy)
bitwise_or() (in module jax.lax)
(in module jax.numpy)
bitwise_xor() (in module jax.lax)
(in module jax.numpy)
blackman() (in module jax.numpy)
block() (in module jax.numpy)
block_diag() (in module jax.scipy.linalg)
bool_ (class in jax.numpy)
broadcast() (in module jax.lax)
broadcast_arrays() (in module jax.numpy)
broadcast_in_dim() (in module jax.lax)
broadcast_to() (in module jax.numpy)
broadcasted_iota() (in module jax.lax)
build_tree() (in module jax.tree_util)
C
can_cast() (in module jax.numpy)
categorical() (in module jax.random)
cauchy() (in module jax.random)
cbrt() (in module jax.numpy)
cdf() (in module jax.scipy.stats.laplace)
(in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.norm)
cdouble (in module jax.numpy)
ceil() (in module jax.lax)
(in module jax.numpy)
celu() (in module jax.nn)
cg() (in module jax.scipy.sparse.linalg)
chain() (in module jax.experimental.optix)
character (class in jax.numpy)
checkpoint() (in module jax)
cho_factor() (in module jax.scipy.linalg)
cho_solve() (in module jax.scipy.linalg)
choice() (in module jax.random)
cholesky() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
choose() (in module jax.numpy)
clamp() (in module jax.lax)
clip() (in module jax.experimental.optix)
(in module jax.numpy)
clip_by_global_norm() (in module jax.experimental.optix)
clip_grads() (in module jax.experimental.optimizers)
ClipByGlobalNormState (class in jax.experimental.optix)
ClipState (class in jax.experimental.optix)
ClosedJaxpr (class in jax.core)
collapse() (in module jax.lax)
column_stack() (in module jax.numpy)
complex() (in module jax.lax)
complex128 (class in jax.numpy)
complex64 (class in jax.numpy)
complex_ (in module jax.numpy)
complexfloating (class in jax.numpy)
ComplexWarning
compress() (in module jax.numpy)
concatenate() (in module jax.lax)
(in module jax.numpy)
cond() (in module jax.lax)
(in module jax.numpy.linalg)
cond_range() (jax.experimental.loops.Scope method)
conj() (in module jax.lax)
(in module jax.numpy)
conjugate() (in module jax.numpy)
constant() (in module jax.experimental.optimizers)
Conv() (in module jax.experimental.stax)
conv() (in module jax.lax)
Conv1DTranspose() (in module jax.experimental.stax)
conv_general_dilated() (in module jax.lax)
conv_transpose() (in module jax.lax)
conv_with_general_padding() (in module jax.lax)
convert_element_type() (in module jax.lax)
convolve() (in module jax.numpy)
(in module jax.scipy.signal)
convolve2d() (in module jax.scipy.signal)
ConvTranspose() (in module jax.experimental.stax)
copysign() (in module jax.numpy)
corrcoef() (in module jax.numpy)
correlate() (in module jax.numpy)
(in module jax.scipy.signal)
correlate2d() (in module jax.scipy.signal)
cos() (in module jax.lax)
(in module jax.numpy)
cosh() (in module jax.lax)
(in module jax.numpy)
count() (jax.experimental.optix.AddNoiseState property)
(jax.experimental.optix.ApplyEvery property)
(jax.experimental.optix.ScaleByAdamState property)
(jax.experimental.optix.ScaleByScheduleState property)
count_nonzero() (in module jax.numpy)
cov() (in module jax.numpy)
cross() (in module jax.numpy)
csingle (in module jax.numpy)
cumprod() (in module jax.numpy)
cumproduct() (in module jax.numpy)
cumsum() (in module jax.numpy)
custom_jvp() (in module jax)
custom_linear_solve() (in module jax.lax)
custom_root() (in module jax.lax)
custom_vjp() (in module jax)
D
deg2rad() (in module jax.numpy)
degrees() (in module jax.numpy)
Dense() (in module jax.experimental.stax)
det (in module jax.numpy.linalg)
det() (in module jax.scipy.linalg)
device_count() (in module jax)
device_memory_profile() (in module jax.profiler)
device_put() (in module jax)
devices() (in module jax)
diag() (in module jax.numpy)
diag_indices() (in module jax.numpy)
diag_indices_from() (in module jax.numpy)
diagflat() (in module jax.numpy)
diagonal() (in module jax.numpy)
diff() (in module jax.numpy)
digamma() (in module jax.lax)
(in module jax.scipy.special)
digitize() (in module jax.numpy)
dirichlet() (in module jax.random)
disable_jit() (in module jax)
div() (in module jax.lax)
divide() (in module jax.numpy)
divmod() (in module jax.numpy)
dot() (in module jax.lax)
(in module jax.numpy)
dot_general() (in module jax.lax)
double (in module jax.numpy)
double_sided_maxwell() (in module jax.random)
Dropout() (in module jax.experimental.stax)
dsplit() (in module jax.numpy)
dstack() (in module jax.numpy)
dtype (class in jax.numpy)
dynamic_index_in_dim() (in module jax.lax)
dynamic_slice() (in module jax.lax)
dynamic_slice_in_dim() (in module jax.lax)
dynamic_update_index_in_dim() (in module jax.lax)
dynamic_update_slice_in_dim() (in module jax.lax)
E
ediff1d() (in module jax.numpy)
eig() (in module jax.numpy.linalg)
eigh() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
eigvals() (in module jax.numpy.linalg)
eigvalsh() (in module jax.numpy.linalg)
einsum() (in module jax.numpy)
einsum_path() (in module jax.numpy)
elementwise() (in module jax.experimental.stax)
elu() (in module jax.nn)
empty() (in module jax.numpy)
empty_like() (in module jax.numpy)
entr() (in module jax.scipy.special)
eq() (in module jax.lax)
equal() (in module jax.numpy)
erf() (in module jax.lax)
(in module jax.scipy.special)
erf_inv() (in module jax.lax)
erfc() (in module jax.lax)
(in module jax.scipy.special)
erfinv() (in module jax.scipy.special)
eval_shape() (in module jax)
exp() (in module jax.lax)
(in module jax.numpy)
exp2() (in module jax.numpy)
expand_dims() (in module jax.lax)
(in module jax.numpy)
expit (in module jax.scipy.special)
expm() (in module jax.scipy.linalg)
expm1() (in module jax.lax)
(in module jax.numpy)
expm_frechet() (in module jax.scipy.linalg)
exponential() (in module jax.random)
exponential_decay() (in module jax.experimental.optimizers)
extract() (in module jax.numpy)
eye() (in module jax.numpy)
F
fabs() (in module jax.numpy)
FanInConcat() (in module jax.experimental.stax)
FanOut() (in module jax.experimental.stax)
fft() (in module jax.lax)
(in module jax.numpy.fft)
fft2() (in module jax.numpy.fft)
fftfreq() (in module jax.numpy.fft)
fftn() (in module jax.numpy.fft)
fftshift() (in module jax.numpy.fft)
finfo() (in module jax.numpy)
fix() (in module jax.numpy)
flatnonzero() (in module jax.numpy)
flexible (class in jax.numpy)
flip() (in module jax.numpy)
fliplr() (in module jax.numpy)
flipud() (in module jax.numpy)
float16 (class in jax.numpy)
float32 (class in jax.numpy)
float64 (class in jax.numpy)
float_ (in module jax.numpy)
float_power() (in module jax.numpy)
floating (class in jax.numpy)
floor() (in module jax.lax)
(in module jax.numpy)
floor_divide() (in module jax.numpy)
fmax() (in module jax.numpy)
fmin() (in module jax.numpy)
fmod() (in module jax.numpy)
fold_in() (in module jax.random)
fori_loop() (in module jax.lax)
frexp() (in module jax.numpy)
from_dlpack() (in module jax.dlpack)
full() (in module jax.lax)
(in module jax.numpy)
full_like() (in module jax.lax)
(in module jax.numpy)
G
gamma() (in module jax.random)
gammainc() (in module jax.scipy.special)
gammaincc() (in module jax.scipy.special)
gammaln() (in module jax.scipy.special)
gather() (in module jax.lax)
gcd() (in module jax.numpy)
ge() (in module jax.lax)
gelu() (in module jax.nn)
GeneralConv() (in module jax.experimental.stax)
GeneralConvTranspose() (in module jax.experimental.stax)
geomspace() (in module jax.numpy)
global_norm() (in module jax.experimental.optix)
glorot_normal() (in module jax.nn.initializers)
glorot_uniform() (in module jax.nn.initializers)
glu() (in module jax.nn)
grad() (in module jax)
grad_acc() (jax.experimental.optix.ApplyEvery property)
gradient() (in module jax.numpy)
GradientTransformation (class in jax.experimental.optix)
greater() (in module jax.numpy)
greater_equal() (in module jax.numpy)
gt() (in module jax.lax)
gumbel() (in module jax.random)
H
hamming() (in module jax.numpy)
hanning() (in module jax.numpy)
hard_sigmoid() (in module jax.nn)
hard_silu() (in module jax.nn)
hard_swish() (in module jax.nn)
hard_tanh() (in module jax.nn)
he_normal() (in module jax.nn.initializers)
he_uniform() (in module jax.nn.initializers)
heaviside() (in module jax.numpy)
hessian() (in module jax)
hfft() (in module jax.numpy.fft)
histogram() (in module jax.numpy)
histogram_bin_edges() (in module jax.numpy)
histogramdd() (in module jax.numpy)
host_count() (in module jax)
host_id() (in module jax)
host_ids() (in module jax)
hsplit() (in module jax.numpy)
hstack() (in module jax.numpy)
hypot() (in module jax.numpy)
I
i0() (in module jax.numpy)
(in module jax.scipy.special)
i0e() (in module jax.scipy.special)
i1() (in module jax.scipy.special)
i1e() (in module jax.scipy.special)
id_print() (in module jax.experimental.host_callback)
id_tap() (in module jax.experimental.host_callback)
identity() (in module jax.numpy)
ifft() (in module jax.numpy.fft)
ifft2() (in module jax.numpy.fft)
ifftn() (in module jax.numpy.fft)
ifftshift() (in module jax.numpy.fft)
igamma() (in module jax.lax)
igammac() (in module jax.lax)
ihfft() (in module jax.numpy.fft)
iinfo (class in jax.numpy)
imag() (in module jax.lax)
(in module jax.numpy)
in1d() (in module jax.numpy)
index (in module jax.ops)
index_add() (in module jax.ops)
index_in_dim() (in module jax.lax)
index_max() (in module jax.ops)
index_min() (in module jax.ops)
index_mul() (in module jax.ops)
index_take() (in module jax.lax)
index_update() (in module jax.ops)
indices() (in module jax.numpy)
inexact (class in jax.numpy)
init() (jax.experimental.optix.GradientTransformation property)
init_fn() (jax.experimental.optimizers.Optimizer property)
InitUpdate (in module jax.experimental.optix)
inner() (in module jax.numpy)
int16 (class in jax.numpy)
int32 (class in jax.numpy)
int64 (class in jax.numpy)
int8 (class in jax.numpy)
int_ (in module jax.numpy)
integer (class in jax.numpy)
interp() (in module jax.numpy)
intersect1d() (in module jax.numpy)
inv() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
inverse_time_decay() (in module jax.experimental.optimizers)
invert() (in module jax.numpy)
iota() (in module jax.lax)
irfft() (in module jax.numpy.fft)
irfft2() (in module jax.numpy.fft)
irfftn() (in module jax.numpy.fft)
is_finite() (in module jax.lax)
isclose() (in module jax.numpy)
iscomplex() (in module jax.numpy)
iscomplexobj() (in module jax.numpy)
isf() (in module jax.scipy.stats.logistic)
isfinite() (in module jax.numpy)
isin() (in module jax.numpy)
isinf() (in module jax.numpy)
isnan() (in module jax.numpy)
isneginf() (in module jax.numpy)
isposinf() (in module jax.numpy)
isreal() (in module jax.numpy)
isrealobj() (in module jax.numpy)
isscalar() (in module jax.numpy)
issubdtype() (in module jax.numpy)
issubsctype() (in module jax.numpy)
iterable() (in module jax.numpy)
ix_() (in module jax.numpy)
J
jacfwd() (in module jax)
jacrev() (in module jax)
jax.core (module)
jax.dlpack (module)
jax.experimental (module)
jax.experimental.host_callback (module)
jax.experimental.loops (module)
jax.experimental.optimizers (module)
jax.experimental.optix (module)
jax.experimental.stax (module)
jax.image (module)
jax.lax (module)
jax.nn (module)
jax.nn.initializers (module)
jax.numpy (module)
jax.numpy.fft (module)
jax.numpy.linalg (module)
jax.ops (module)
jax.profiler (module)
jax.random (module)
jax.scipy.linalg (module)
jax.scipy.ndimage (module)
jax.scipy.signal (module)
jax.scipy.sparse.linalg (module)
jax.scipy.special (module)
jax.scipy.stats.beta (module)
jax.scipy.stats.expon (module)
jax.scipy.stats.gamma (module)
jax.scipy.stats.laplace (module)
jax.scipy.stats.logistic (module)
jax.scipy.stats.norm (module)
jax.scipy.stats.uniform (module)
jax.tree_util (module)
Jaxpr (class in jax.core)
jit() (in module jax)
JoinPoint (class in jax.experimental.optimizers)
jvp() (in module jax)
K
kaiser() (in module jax.numpy)
kron() (in module jax.numpy)
L
l2_norm() (in module jax.experimental.optimizers)
laplace() (in module jax.random)
lcm() (in module jax.numpy)
ldexp() (in module jax.numpy)
le() (in module jax.lax)
leaky_relu() (in module jax.nn)
lecun_normal() (in module jax.nn.initializers)
lecun_uniform() (in module jax.nn.initializers)
left_shift() (in module jax.numpy)
less() (in module jax.numpy)
less_equal() (in module jax.numpy)
lexsort() (in module jax.numpy)
lgamma() (in module jax.lax)
linearize() (in module jax)
linspace() (in module jax.numpy)
load() (in module jax.numpy)
local_device_count() (in module jax)
local_devices() (in module jax)
log() (in module jax.lax)
(in module jax.numpy)
log10() (in module jax.numpy)
log1p() (in module jax.lax)
(in module jax.numpy)
log2() (in module jax.numpy)
log_ndtr (in module jax.scipy.special)
log_sigmoid() (in module jax.nn)
log_softmax() (in module jax.nn)
logaddexp (in module jax.numpy)
logaddexp2 (in module jax.numpy)
logcdf() (in module jax.scipy.stats.norm)
logical_and() (in module jax.numpy)
logical_not() (in module jax.numpy)
logical_or() (in module jax.numpy)
logical_xor() (in module jax.numpy)
logistic() (in module jax.random)
logit (in module jax.scipy.special)
logpdf() (in module jax.scipy.stats.beta)
(in module jax.scipy.stats.expon)
(in module jax.scipy.stats.gamma)
(in module jax.scipy.stats.laplace)
(in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.norm)
(in module jax.scipy.stats.uniform)
logspace() (in module jax.numpy)
logsumexp() (in module jax.scipy.special)
lstsq() (in module jax.numpy.linalg)
lt() (in module jax.lax)
lu() (in module jax.scipy.linalg)
lu_factor() (in module jax.scipy.linalg)
lu_solve() (in module jax.scipy.linalg)
M
make_jaxpr() (in module jax)
make_schedule() (in module jax.experimental.optimizers)
map() (in module jax.lax)
map_coordinates() (in module jax.scipy.ndimage)
mask_indices() (in module jax.numpy)
matmul() (in module jax.numpy)
matrix_power() (in module jax.numpy.linalg)
matrix_rank() (in module jax.numpy.linalg)
max (jax.numpy.iinfo attribute)
max() (in module jax.lax)
(in module jax.numpy)
maximum() (in module jax.numpy)
MaxPool() (in module jax.experimental.stax)
maxwell() (in module jax.random)
mean() (in module jax.numpy)
median() (in module jax.numpy)
meshgrid() (in module jax.numpy)
min (jax.numpy.iinfo attribute)
min() (in module jax.lax)
(in module jax.numpy)
minimum() (in module jax.numpy)
mod() (in module jax.numpy)
modf() (in module jax.numpy)
momentum() (in module jax.experimental.optimizers)
moveaxis() (in module jax.numpy)
msort() (in module jax.numpy)
mu() (jax.experimental.optix.ScaleByAdamState property)
(jax.experimental.optix.ScaleByRStdDevState property)
mul() (in module jax.lax)
multi_dot() (in module jax.numpy.linalg)
multigammaln() (in module jax.scipy.special)
multiply() (in module jax.numpy)
multivariate_normal() (in module jax.random)
N
nan_to_num() (in module jax.numpy)
nanargmax() (in module jax.numpy)
nanargmin() (in module jax.numpy)
nancumprod() (in module jax.numpy)
nancumsum() (in module jax.numpy)
nanmax() (in module jax.numpy)
nanmean() (in module jax.numpy)
nanmedian() (in module jax.numpy)
nanmin() (in module jax.numpy)
nanpercentile() (in module jax.numpy)
nanprod() (in module jax.numpy)
nanquantile() (in module jax.numpy)
nanstd() (in module jax.numpy)
nansum() (in module jax.numpy)
nanvar() (in module jax.numpy)
ndarray (class in jax.numpy)
ndim() (in module jax.numpy)
ndtr() (in module jax.scipy.special)
ndtri() (in module jax.scipy.special)
ne() (in module jax.lax)
neg() (in module jax.lax)
negative() (in module jax.numpy)
nesterov() (in module jax.experimental.optimizers)
nextafter() (in module jax.lax)
(in module jax.numpy)
noisy_sgd() (in module jax.experimental.optix)
nonzero() (in module jax.numpy)
norm() (in module jax.numpy.linalg)
normal() (in module jax.nn.initializers)
(in module jax.random)
normalize() (in module jax.nn)
not_equal() (in module jax.numpy)
nu() (jax.experimental.optix.ScaleByAdamState property)
(jax.experimental.optix.ScaleByRmsState property)
(jax.experimental.optix.ScaleByRStdDevState property)
number (class in jax.numpy)
O
object_ (class in jax.numpy)
one_hot() (in module jax.nn)
ones() (in module jax.nn.initializers)
(in module jax.numpy)
ones_like() (in module jax.numpy)
Optimizer (class in jax.experimental.optimizers)
optimizer() (in module jax.experimental.optimizers)
OptimizerState (class in jax.experimental.optimizers)
outer() (in module jax.numpy)
outfeed_receiver() (in module jax.experimental.host_callback)
P
pack_optimizer_state() (in module jax.experimental.optimizers)
packbits() (in module jax.numpy)
packed_state() (jax.experimental.optimizers.OptimizerState property)
pad() (in module jax.lax)
(in module jax.numpy)
parallel() (in module jax.experimental.stax)
params_fn() (jax.experimental.optimizers.Optimizer property)
pareto() (in module jax.random)
Partial (class in jax.tree_util)
pdf() (in module jax.scipy.stats.beta)
(in module jax.scipy.stats.expon)
(in module jax.scipy.stats.gamma)
(in module jax.scipy.stats.laplace)
(in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.norm)
(in module jax.scipy.stats.uniform)
percentile() (in module jax.numpy)
permutation() (in module jax.random)
piecewise() (in module jax.numpy)
piecewise_constant() (in module jax.experimental.optimizers)
pinv (in module jax.numpy.linalg)
pmap() (in module jax)
pmax() (in module jax.lax)
pmean() (in module jax.lax)
pmin() (in module jax.lax)
poisson() (in module jax.random)
polyadd() (in module jax.numpy)
polyder() (in module jax.numpy)
polygamma() (in module jax.scipy.special)
polymul() (in module jax.numpy)
polynomial_decay() (in module jax.experimental.optimizers)
polysub() (in module jax.numpy)
polyval() (in module jax.numpy)
population_count() (in module jax.lax)
positive() (in module jax.numpy)
pow() (in module jax.lax)
power() (in module jax.numpy)
ppermute() (in module jax.lax)
ppf() (in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.norm)
PRNGKey() (in module jax.random)
prod() (in module jax.numpy)
product() (in module jax.numpy)
promote_types() (in module jax.numpy)
pshuffle() (in module jax.lax)
psum() (in module jax.lax)
pswapaxes() (in module jax.lax)
ptp() (in module jax.numpy)
Q
qr() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
quantile() (in module jax.numpy)
R
rad2deg() (in module jax.numpy)
rademacher() (in module jax.random)
radians() (in module jax.numpy)
randint() (in module jax.random)
range() (jax.experimental.loops.Scope method)
ravel() (in module jax.numpy)
ravel_multi_index() (in module jax.numpy)
real() (in module jax.lax)
(in module jax.numpy)
reciprocal() (in module jax.lax)
(in module jax.numpy)
reduce() (in module jax.lax)
reduce_window() (in module jax.lax)
register_pytree_node() (in module jax.tree_util)
register_pytree_node_class() (in module jax.tree_util)
relu (in module jax.nn)
relu6() (in module jax.nn)
rem() (in module jax.lax)
remainder() (in module jax.numpy)
repeat() (in module jax.numpy)
reshape() (in module jax.lax)
(in module jax.numpy)
resize() (in module jax.image)
result_type() (in module jax.numpy)
rev() (in module jax.lax)
rfft() (in module jax.numpy.fft)
rfft2() (in module jax.numpy.fft)
rfftfreq() (in module jax.numpy.fft)
rfftn() (in module jax.numpy.fft)
right_shift() (in module jax.numpy)
rint() (in module jax.numpy)
rmsprop() (in module jax.experimental.optimizers)
(in module jax.experimental.optix)
rmsprop_momentum() (in module jax.experimental.optimizers)
rng_key() (jax.experimental.optix.AddNoiseState property)
roll() (in module jax.numpy)
rollaxis() (in module jax.numpy)
rolled_loop_step() (in module jax.random)
roots() (in module jax.numpy)
rot90() (in module jax.numpy)
rotate_left() (in module jax.random)
rotate_list() (in module jax.random)
round() (in module jax.lax)
(in module jax.numpy)
row_stack() (in module jax.numpy)
rsqrt() (in module jax.lax)
S
save() (in module jax.numpy)
save_device_memory_profile() (in module jax.profiler)
savez() (in module jax.numpy)
scale() (in module jax.experimental.optix)
scale_and_translate() (in module jax.image)
scale_by_adam() (in module jax.experimental.optix)
scale_by_rms() (in module jax.experimental.optix)
scale_by_schedule() (in module jax.experimental.optix)
scale_by_stddev() (in module jax.experimental.optix)
ScaleByAdamState (class in jax.experimental.optix)
ScaleByRmsState (class in jax.experimental.optix)
ScaleByRStdDevState (class in jax.experimental.optix)
ScaleByScheduleState (class in jax.experimental.optix)
ScaleState (class in jax.experimental.optix)
scan() (in module jax.lax)
scatter() (in module jax.lax)
scatter_add() (in module jax.lax)
Scope (class in jax.experimental.loops)
searchsorted() (in module jax.numpy)
segment_sum() (in module jax.ops)
select() (in module jax.lax)
(in module jax.numpy)
selu() (in module jax.nn)
serial() (in module jax.experimental.stax)
set_printoptions() (in module jax.numpy)
sf() (in module jax.scipy.stats.logistic)
sgd() (in module jax.experimental.optimizers)
(in module jax.experimental.optix)
shape() (in module jax.numpy)
shape_dependent() (in module jax.experimental.stax)
shift_left() (in module jax.lax)
shift_right_arithmetic() (in module jax.lax)
shift_right_logical() (in module jax.lax)
shuffle() (in module jax.random)
sigmoid() (in module jax.nn)
sign() (in module jax.lax)
(in module jax.numpy)
signbit() (in module jax.numpy)
signedinteger (class in jax.numpy)
silu() (in module jax.nn)
sin() (in module jax.lax)
(in module jax.numpy)
sinc() (in module jax.numpy)
single (in module jax.numpy)
sinh() (in module jax.lax)
(in module jax.numpy)
size() (in module jax.numpy)
slice() (in module jax.lax)
slice_in_dim() (in module jax.lax)
slogdet (in module jax.numpy.linalg)
sm3() (in module jax.experimental.optimizers)
soft_sign() (in module jax.nn)
softmax() (in module jax.nn)
softplus() (in module jax.nn)
solve() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
solve_triangular() (in module jax.scipy.linalg)
sometrue() (in module jax.numpy)
sort() (in module jax.lax)
(in module jax.numpy)
sort_complex() (in module jax.numpy)
sort_key_val() (in module jax.lax)
split() (in module jax.numpy)
(in module jax.random)
sqrt() (in module jax.lax)
(in module jax.numpy)
square() (in module jax.lax)
(in module jax.numpy)
squeeze() (in module jax.lax)
(in module jax.numpy)
stack() (in module jax.numpy)
start_server() (in module jax.profiler)
start_subtrace() (jax.experimental.loops.Scope method)
std() (in module jax.numpy)
stop_gradient() (in module jax.lax)
sub() (in module jax.lax)
subtract() (in module jax.numpy)
subtree_defs() (jax.experimental.optimizers.OptimizerState property)
sum() (in module jax.numpy)
SumPool() (in module jax.experimental.stax)
svd() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
swapaxes() (in module jax.numpy)
swish() (in module jax.nn)
switch() (in module jax.lax)
T
t() (in module jax.random)
take() (in module jax.numpy)
take_along_axis() (in module jax.numpy)
tan() (in module jax.lax)
(in module jax.numpy)
tanh() (in module jax.numpy)
TapFunctionException
tensordot() (in module jax.numpy)
tensorinv() (in module jax.numpy.linalg)
tensorsolve() (in module jax.numpy.linalg)
threefry_2x32() (in module jax.random)
tie_in() (in module jax.lax)
tile() (in module jax.numpy)
to_dlpack() (in module jax.dlpack)
top_k() (in module jax.lax)
trace() (in module jax.experimental.optix)
(in module jax.numpy)
(jax.experimental.optix.TraceState property)
trace_function() (in module jax.profiler)
TraceContext (class in jax.profiler)
TraceState (class in jax.experimental.optix)
transpose() (in module jax.lax)
(in module jax.numpy)
trapz() (in module jax.numpy)
tree_all() (in module jax.tree_util)
tree_def() (jax.experimental.optimizers.OptimizerState property)
tree_flatten() (in module jax.tree_util)
tree_leaves() (in module jax.tree_util)
tree_map() (in module jax.tree_util)
tree_multimap() (in module jax.tree_util)
tree_reduce() (in module jax.tree_util)
tree_structure() (in module jax.tree_util)
tree_transpose() (in module jax.tree_util)
tree_unflatten() (in module jax.tree_util)
treedef_children() (in module jax.tree_util)
treedef_is_leaf() (in module jax.tree_util)
treedef_tuple() (in module jax.tree_util)
tri() (in module jax.numpy)
tril() (in module jax.numpy)
(in module jax.scipy.linalg)
tril_indices() (in module jax.numpy)
tril_indices_from() (in module jax.numpy)
trim_zeros() (in module jax.numpy)
triu() (in module jax.numpy)
(in module jax.scipy.linalg)
triu_indices() (in module jax.numpy)
triu_indices_from() (in module jax.numpy)
true_divide() (in module jax.numpy)
trunc() (in module jax.numpy)
truncated_normal() (in module jax.random)
U
uint16 (class in jax.numpy)
uint32 (class in jax.numpy)
uint64 (class in jax.numpy)
uint8 (class in jax.numpy)
uniform() (in module jax.nn.initializers)
(in module jax.random)
unique() (in module jax.numpy)
unpack_optimizer_state() (in module jax.experimental.optimizers)
unpackbits() (in module jax.numpy)
unravel_index() (in module jax.numpy)
unsignedinteger (class in jax.numpy)
unwrap() (in module jax.numpy)
update() (jax.experimental.optix.GradientTransformation property)
update_fn() (jax.experimental.optimizers.Optimizer property)
V
value_and_grad() (in module jax)
vander() (in module jax.numpy)
var() (in module jax.numpy)
variance_scaling() (in module jax.nn.initializers)
vdot() (in module jax.numpy)
vectorize() (in module jax.numpy), [1]
vjp() (in module jax)
vmap() (in module jax)
vsplit() (in module jax.numpy)
vstack() (in module jax.numpy)
W
weibull_min() (in module jax.random)
where() (in module jax.numpy)
while_loop() (in module jax.lax)
while_range() (jax.experimental.loops.Scope method)
X
xla_computation() (in module jax)
xlog1py() (in module jax.scipy.special)
xlogy() (in module jax.scipy.special)
Z
zeros() (in module jax.nn.initializers)
(in module jax.numpy)
zeros_like() (in module jax.numpy)
zeta() (in module jax.scipy.special)
Read the Docs
v: 4510-2
Versions
latest
stable
4510-2
test-docs
Downloads
html
On Read the Docs
Project Home
Builds
Free document hosting provided by Read the Docs.