jax.debug.visualize_sharding
- jax.debug.visualize_sharding(shape, sharding, *, use_color=True, scale=1.0, min_width=9, max_width=80, color_map=None)
Visualizes a Sharding using the rich library.
- Parameters:
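shape (Sequence[int]) – shape of the array whose sharding is visualized.
sharding (Sharding) – the sharding to visualize.
use_color (bool) – whether to color the shards in the rendered output.
scale (float) – factor by which to scale the size of the rendered grid.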
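min_width (int) – minimum width, in characters, of the rendered output.
max_width (int) – maximum width, in characters, of the rendered output.
color_map (ColorMap | None) – colormap used to pick shard colors when use_color is True; None selects a default.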
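A minimal usage sketch, assuming a CPU-only JAX install with rich available (pip install rich). The XLA flag that simulates 8 CPU devices, the 4 x 2 mesh, the axis names "x" and "y", and the array shapes are illustrative choices, not part of this function's API; Mesh, NamedSharding, PartitionSpec and Sharding.shard_shape come from jax.sharding. Only shape and sharding are passed positionally; the remaining parameters are keyword-only.

```python
import os

# Assumption: CPU-only JAX. Ask XLA to expose 8 virtual CPU devices;
# this must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the 8 devices in a 4 x 2 mesh with (illustrative) axis names "x" and "y".
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("x", "y"))

# Split the rows of an 8 x 4 array over "x" and its columns over "y".
shape = (8, 4)
sharding = NamedSharding(mesh, P("x", "y"))

# Draw one block per shard, labelled with the device that holds it.
jax.debug.visualize_sharding(shape, sharding)

# The same sharding without color and at twice the default scale.
jax.debug.visualize_sharding(shape, sharding, use_color=False, scale=2.0)

# A fully replicated sharding (empty PartitionSpec) on a 1-D shape:
# every device holds the whole array, so it should appear as one block.
replicated = NamedSharding(mesh, P())
jax.debug.visualize_sharding((16,), replicated)

# Cross-check the picture: each device holds a 2 x 2 block of the 8 x 4 array.
print(sharding.shard_shape(shape))  # (2, 2)
```

Comparing the printed grid with sharding.shard_shape is a quick way to confirm that the layout you see matches the per-device block size JAX will actually allocate.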