Dynamic vs Static: Computation Graphs in PyTorch and JAX
Updated on November 26, 2025 · 7 minute read
How does your deep learning code “remember” the operations it just executed so it can automatically compute gradients?
In this deep dive, we unpack the anatomy of computation graphs and explain why PyTorch and JAX, two powerhouse libraries in today’s research and production stacks, take different approaches. By the end, you will understand the trade-offs, pick the right tool for your next project, and gather practical tips you can apply immediately.
Want structured, mentor-led practice? Apply to our next Data Science and AI Bootcamp and master both frameworks from first principles to deployment.
1 What is a Computation Graph?
A computation graph is a directed acyclic graph (DAG) where:
- Nodes represent values (tensors, scalars, RNG states)
- Edges represent operations that transform inputs into outputs
During training, we effectively build two intertwined graphs:
- Forward graph charts data flow from inputs to loss
- Backward graph created by the autodiff engine so gradients can propagate from the loss back to parameters
Because the graph is just data, libraries can:
- Reorder or fuse operations for speed
- Slice away unused branches (for example, in multi-task models)
- Serialize the graph to run on different hardware or runtimes
Why a DAG, not a general graph? Cycles imply a value that depends (directly or indirectly) on itself. Frameworks avoid that by unrolling loops or linking recurrent edges to previous time step nodes, so each iteration still forms a DAG.
Throughout the article, we will toggle between two mental models:
- Tape – record everything as we go and replay it backward.
- Trace – peek at the function's structure before running it, then hand a static graph to a compiler.
2 PyTorch: Dynamic Define-By-Run Graphs
PyTorch’s lineage values flexibility. Every time your Python lines execute, PyTorch creates a fresh graph in C++. The grad_fn chain you see when you print a tensor is the visible tip of that iceberg.
import torch
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = (a * b).sum()
print(c.grad_fn) # <SumBackward0 object at ...>
c.backward()
How the engine works
- Forward pass – each differentiable op subclasses torch.autograd.Function and pushes a node onto the autograd tape.
- Backward pass – on tensor.backward(), the C++ engine topologically sorts nodes in reverse and accumulates gradients on leaf tensors.
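To make the tape concrete, here is a minimal sketch of a custom differentiable op that follows the same Function pattern; the Square class and the toy tensors are illustrative, not part of PyTorch.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save what the backward pass will need on the context object.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, chained with the upstream gradient.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
out = Square.apply(x)
print(out.grad_fn)      # the node this op pushed onto the tape
out.sum().backward()
print(x.grad)           # equals 2 * x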
Perks of dynamism
| Benefit | Why it matters |
|---|---|
| Native Python control flow | Write logic like if random.random() < p: without special graph APIs. |
| Interactive debugging | Drop in a breakpoint or print(tensor.grad_fn) anywhere. |
| Hot-reload research loops | Tweak layer sizes or hyper-parameters and rerun immediately. |
Common pitfalls
- In-place operations (tensor += 1) may discard needed history and break backprop.
- Detached tensors when converting to NumPy and back without requires_grad.
- Memory blow-ups if long-running loops build graphs without torch.no_grad().
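A short, hedged demonstration of these pitfalls; the tensors and the commented-out backward call are purely illustrative.

import torch

# 1. In-place ops can destroy values the backward pass still needs.
a = torch.randn(3, requires_grad=True)
b = torch.sigmoid(a)        # sigmoid's backward reuses its own output...
b += 1                      # ...so this in-place add invalidates it
# b.sum().backward()        # would raise: "a variable needed for gradient
#                           #  computation has been modified by an inplace operation"

# 2. Round-tripping through NumPy silently detaches from the graph.
c = torch.from_numpy(a.detach().numpy())
print(c.requires_grad)      # False: gradients no longer flow back to a

# 3. Wrap evaluation-only loops in no_grad() so no graph accumulates.
with torch.no_grad():
    for _ in range(1_000):
        _ = (a * 2).sum()   # no autograd history is recorded here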
3 JAX: Staged, Static Graphs via Tracing
JAX marries Autograd-style differentiation with XLA and a functional mindset. You write pure functions; JAX executes them in three conceptual stages:
- Tracing – intercept Python execution and build a JAXPR (an SSA-style IR).
- Compilation – lower the JAXPR to HLO, fuse kernels, and emit a binary for the device.
- Execution – cache the binary; subsequent calls bypass Python almost entirely.
import jax
import jax.numpy as jnp
@jax.jit # stage out the entire function
def network(x, w):
    return jnp.tanh(x @ w)
print(jax.make_jaxpr(network)(jnp.ones((1, 4)), jnp.ones((4, 4))))
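To see the staging in action, one rough check is to time the first and second calls of the jitted network defined above; the absolute numbers depend entirely on your hardware and JAX version.

import time
import jax.numpy as jnp

x, w = jnp.ones((1, 4)), jnp.ones((4, 4))

t0 = time.perf_counter()
network(x, w).block_until_ready()   # first call: trace, compile, then run
print(f"first call : {time.perf_counter() - t0:.4f} s")

t0 = time.perf_counter()
network(x, w).block_until_ready()   # later calls reuse the cached executable
print(f"second call: {time.perf_counter() - t0:.4f} s")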
Trademark superpowers
| Feature | What you get |
|---|---|
| Ahead-of-time compilation | One graph that can target CPU, GPU, or TPU |
| First-class transformations | jax.grad, vmap, and pmap rewrite graphs algebraically |
| Fusion and layout control | XLA fuses elementwise ops and chooses efficient tiling and layout |
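As a hedged sketch of how these transformations compose, the toy loss below is ours, not a standard API: grad rewrites the graph for reverse-mode AD, vmap adds a batch axis, and jit stages the whole result out.

import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy scalar loss used only for illustration.
    return jnp.sum(jnp.tanh(x @ w) ** 2)

grad_fn = jax.grad(loss)                              # reverse-mode AD over w
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0))   # vectorise over a batch of x
fast = jax.jit(batched_grad)                          # then compile the whole pipeline

w = jnp.ones((4, 4))
xs = jnp.ones((8, 1, 4))                              # batch of 8 inputs, each (1, 4)
print(fast(w, xs).shape)                              # (8, 4, 4): one gradient per example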
Pain points
- Compile latency - seconds for larger models, especially on the first call
- Purity constraints - no side effects in a jitted region
- Debugging mindset - you think in graphs and traces rather than imperative line-by-line execution
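The purity constraint is easiest to see with a small sketch (the function and values are illustrative): a plain Python print fires only while tracing, whereas jax.debug.print is staged into the compiled graph and fires on every run.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing f")                   # runs only while tracing, not on every call
    jax.debug.print("x = {x}", x=x)      # runtime-safe alternative that survives jit
    return x * 2

f(jnp.ones(3))   # prints "tracing f" once, then the debug value
f(jnp.ones(3))   # cached executable: only the debug print fires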
4 Hands-On Comparison
Control flow
# PyTorch
import torch
import random
def coin_net(x, w1, w2):
    if random.random() < 0.5:
        return torch.relu(x @ w1)
    else:
        return torch.sigmoid(x @ w2)
# JAX
import jax
import jax.numpy as jnp
def coin_net(x, key, w1, w2):
    branch = jax.random.bernoulli(key)
    return jax.lax.cond(
        branch,
        lambda _: jax.nn.relu(x @ w1),
        lambda _: jax.nn.sigmoid(x @ w2),
        None,
    )
Micro-benchmark
| Batch size | PyTorch eager | PyTorch torch.compile | JAX @jit |
|---|---|---|---|
| 32 | 1.8 ms | 1.3 ms | 0.9 ms |
| 2 048 | 14 ms | 6.1 ms | 5.3 ms |
RTX 4090, CUDA 12, PyTorch 2.3, JAX 0.4.27. Compile time removed; treat these numbers as illustrative rather than definitive 2026 benchmarks.
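If you want to reproduce numbers like these, a minimal timing harness might look like the sketch below. It assumes you pass in a PyTorch callable and a jitted JAX callable, and it synchronises before stopping the clock because both frameworks dispatch GPU work asynchronously.

import time
import torch

def bench_torch(fn, *args, iters=100):
    fn(*args)                                   # warm-up
    if torch.cuda.is_available():
        torch.cuda.synchronize()                # CUDA kernels run asynchronously
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

def bench_jax(fn, *args, iters=100):
    fn(*args).block_until_ready()               # warm-up also triggers compilation
    t0 = time.perf_counter()
    out = None
    for _ in range(iters):
        out = fn(*args)
    out.block_until_ready()                     # drain the async dispatch queue
    return (time.perf_counter() - t0) / iters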
5 Graph Optimisation and Memory Tricks
Gradient checkpointing
# PyTorch
from torch.utils.checkpoint import checkpoint
out = checkpoint(block, x)
# JAX
import jax
out = jax.checkpoint(block)(x)
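The snippets above assume a block callable and an input x already exist. A minimal end-to-end PyTorch sketch might look like this; the layer sizes are arbitrary, and use_reentrant=False selects the non-reentrant mode recommended in recent releases.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
x = torch.randn(32, 128, requires_grad=True)

# Activations inside `block` are recomputed during backward instead of stored.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)   # torch.Size([32, 128])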
SPMD and model parallelism
- PyTorch: torchrun, FSDP, and tensor-parallelism plugins
- JAX: pmap, pjit, and GSPMD partition specs
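As a minimal SPMD sketch, jax.pmap replicates one program per local device; the shapes below are arbitrary, and on a single-device machine the device axis simply has length one.

import jax
import jax.numpy as jnp

n = jax.local_device_count()

@jax.pmap                      # one program copy per device (SPMD)
def step(w, x):
    return jnp.tanh(x @ w)

w = jnp.ones((n, 4, 4))        # leading axis = device axis
x = jnp.ones((n, 8, 4))
print(step(w, x).shape)        # (n, 8, 4), computed in parallel across devices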
Mixed precision
- PyTorch: use torch.cuda.amp.autocast() to run selected regions in lower precision.
- JAX: set dtype=jnp.float16 or bfloat16; XLA handles kernel selection, and bfloat16 usually needs no loss scaling.
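A hedged sketch of both idioms, assuming a CUDA device for the PyTorch half; a full training loop would also call scaler.step(optimizer) and scaler.update().

# PyTorch: autocast + GradScaler (assumes a CUDA device is available)
import torch
from torch import nn

model = nn.Linear(128, 64).cuda()
x = torch.randn(32, 128, device="cuda")
scaler = torch.cuda.amp.GradScaler()           # handles loss scaling for float16

with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()

# JAX: compute in bfloat16 by choosing the dtype of the arrays themselves
import jax.numpy as jnp

w = jnp.ones((128, 64), dtype=jnp.bfloat16)
xb = jnp.ones((32, 128), dtype=jnp.bfloat16)
print((xb @ w).dtype)                          # bfloat16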
6 Debugging Workflows You Will Actually Use
| Task | PyTorch | JAX |
|---|---|---|
| Graph visualisation | torchviz, TensorBoard | jax.make_jaxpr, TensorBoard via jax2tf |
| Kernel profiling | Nsight Systems, torch.profiler | XLA HLO dumps, Perfetto traces |
| Inline shape prints | print(t.shape) | jax.debug.print('{x}', x=x) |
| Breakpoints | Native pdb.set_trace() | jax.debug.breakpoint() inside jitted code |
Tip: Use torch.cuda.memory_summary() to watch device memory live, and the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to control how much memory JAX preallocates.
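A tiny sketch of the JAX column in practice (the normalise function is illustrative): jax.debug.print emits values at runtime even under jit, and jax.debug.breakpoint can pause execution at the same spot.

import jax
import jax.numpy as jnp

@jax.jit
def normalise(x):
    jax.debug.print("max before = {m}", m=jnp.max(x))   # prints at runtime, even under jit
    # jax.debug.breakpoint()                            # uncomment to pause here interactively
    return x / jnp.max(x)

print(normalise(jnp.arange(1.0, 5.0)))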
7 Advanced Use Cases
- Meta-learning: PyTorch higher-order autograd versus JAX jax.hessian.
- Probabilistic programming: Pyro versus NumPyro.
- Differentiable physics: PyTorch3D versus Brax.
- Large-scale reinforcement learning: Meta ReAgent (PyTorch) versus DeepMind Acme (JAX).
- Edge deployment: PyTorch Mobile versus jax2tf-exported models for TFLite and related runtimes.
8 Selecting the Right Tool (or Both)
| Project constraint | Choose | Rationale |
|---|---|---|
| Rapid iteration | PyTorch | Zero compile overhead and excellent debugging experience |
| Large transformer on TPU | JAX | XLA fusion and SPMD friendly APIs |
| ONNX centric pipeline | PyTorch | Mature ONNX exporter and ecosystem |
| Functional programming codebase | JAX | Fits naturally with pure functions and transformations |
| Mixed hardware fleet | Hybrid | Prototype quickly in PyTorch; deploy a JAX or TF stack where XLA shines |
9 Conclusion
Computation graphs are the invisible scaffolding of modern deep learning. PyTorch's dynamic tape feels like Python itself: immediate and malleable. JAX's static graphs feel like a compiler: stricter up front and often very fast once compiled.
Mastering both gives you leverage across research prototypes and production inference in 2026 and beyond.
Ready to bend graphs to your will? Apply to our next Data Science and AI Bootcamp and turn these concepts into portfolio-ready projects.
Appendix A Autograd Math Refresher
For a composite function $f(g(h(x)))$, the derivative is:

$$\frac{\partial f}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial x}.$$
Frameworks attach local Jacobian vector products (JVPs) or vector Jacobian products (VJPs) to each operation. The engine feeds upstream gradients through these closures, chaining them according to the graph structure.
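A small sketch of that chaining in JAX, with arbitrary choices h = sin, g = exp, f = tanh: the pullback returned by jax.vjp composes the per-op VJPs for us, and jax.grad gives the same answer.

import jax
import jax.numpy as jnp

def compose(x):
    # f(g(h(x))) with h = sin, g = exp, f = tanh (purely illustrative choices)
    return jnp.tanh(jnp.exp(jnp.sin(x)))

x = jnp.float32(0.3)
y, pullback = jax.vjp(compose, x)

# Pushing a cotangent of 1.0 back through the chain yields df/dg * dg/dh * dh/dx.
(grad,) = pullback(jnp.ones_like(y))
print(grad)
print(jax.grad(compose)(x))   # same value computed via jax.grad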
Appendix B – DIY Graph Inspectors
PyTorch
import torch

def dump_graph(t, visited=None):
    if visited is None:
        visited = set()

    def walk(fn, depth):
        # Walk backward nodes (grad_fn objects), printing each one once.
        if fn is None or fn in visited:
            return
        visited.add(fn)
        print('  ' * depth, fn)
        for nxt, _ in fn.next_functions:
            walk(nxt, depth + 1)

    walk(t.grad_fn, 0)

x = torch.randn(3, requires_grad=True)
dump_graph((x * x).sum())
JAX
import jax
import jax.numpy as jnp
@jax.jit
def network(x, w):
    return jnp.tanh(x @ w)

def count_ops(f, *args):
    jaxpr = jax.make_jaxpr(f)(*args)
    return len(jaxpr.jaxpr.eqns)
print('Ops:', count_ops(network, jnp.ones((1, 4)), jnp.ones((4, 4))))
Looking Ahead: Graphs Beyond 2026
- Kernel fusion DSLs - projects such as torch-mlir and StableHLO aim to make cross-framework graphs increasingly interchangeable.
- Composable compiler passes - plug-in passes to quantise, prune, or otherwise transform graphs with a few lines of Python.
- Neural compute fabric - edge and data-centre chips ingest ONNX or StableHLO directly, making graph exports a deployment contract.
- Graph-level privacy - differential-privacy passes inject calibrated noise per edge (for example, via Opacus and emerging JAX prototypes).
In 2026 and beyond, graph literacy is not optional; it is the lens through which the next decade of AI tooling will be designed.
Ready to get hands-on? Clone the companion GitHub repo, run the notebooks on a free Colab GPU, then challenge yourself to port the same model back and forth between PyTorch and JAX. The muscle memory pays off when mixed codebases land on your desk.