
Dynamic vs Static: Computation Graphs in PyTorch and JAX

Updated on November 26, 2025 · 7 minute read

Figure: Diagram comparing a dynamic PyTorch computation graph with a JAX JIT workflow, showing forward and backward passes and staged compilation steps.

How does your deep learning code “remember” the operations it just executed so it can automatically compute gradients?

In this deep dive, we unpack the anatomy of computation graphs and explain why PyTorch and JAX, two powerhouse libraries in today’s research and production stacks, take different approaches. By the end, you will understand the trade-offs, pick the right tool for your next project, and gather practical tips you can apply immediately.

Want structured, mentor-led practice? Apply to our next Data Science and AI Bootcamp and master both frameworks from first principles to deployment.

1 What is a Computation Graph?

A computation graph is a directed acyclic graph (DAG) where:

  • Nodes represent values (tensors, scalars, RNG states)
  • Edges represent operations that transform inputs into outputs

During training, we effectively build two intertwined graphs:

  1. Forward graph - charts data flow from inputs to the loss
  2. Backward graph - created by the autodiff engine so gradients can propagate from the loss back to the parameters

Because the graph is just data, libraries can:

  • Reorder or fuse operations for speed
  • Slice away unused branches (for example, in multi-task models)
  • Serialize the graph to run on different hardware or runtimes

Why a DAG, not a general graph? Cycles imply a value that depends (directly or indirectly) on itself. Frameworks avoid that by unrolling loops or linking recurrent edges to previous time step nodes, so each iteration still forms a DAG.
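To make the "graph is just data" idea concrete, here is a tiny framework-free sketch (the Node class and topo_order helper below are illustrative, not part of PyTorch or JAX): the forward graph for loss = sum(a * b) stored as plain Python objects, plus the topological ordering that a backward pass, a fuser, or a serializer would walk.

class Node:
    def __init__(self, op, inputs=()):
        self.op = op          # operation that produced this value
        self.inputs = inputs  # edges pointing at the nodes it consumed

a = Node('leaf')
b = Node('leaf')
mul = Node('mul', (a, b))
loss = Node('sum', (mul,))

def topo_order(node, seen=None, order=None):
    # Depth-first post-order visit of a DAG gives a valid topological order
    seen = set() if seen is None else seen
    order = [] if order is None else order
    if node in seen:
        return order
    seen.add(node)
    for parent in node.inputs:
        topo_order(parent, seen, order)
    order.append(node)
    return order

print([n.op for n in topo_order(loss)])  # ['leaf', 'leaf', 'mul', 'sum']
# Reversing that order is exactly how a backward pass visits the graph.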

Throughout the article, we will toggle between two mental models:

  • Tape - record everything as we go and replay it backward.
  • Trace - peek at the function’s structure before running it, then hand a static graph to a compiler.

2 PyTorch: Dynamic Define-By-Run Graphs

PyTorch’s lineage prizes flexibility: every time your Python code executes, the C++ autograd engine records a fresh graph on the fly. The grad_fn chain you see when you print a tensor is the visible tip of that iceberg.

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

c = (a * b).sum()
print(c.grad_fn)  # <SumBackward0 object at ...>

c.backward()  # populates a.grad and b.grad on the leaf tensors

How the engine works

  • Forward pass – each differentiable op subclasses torch.autograd.Function and pushes a node onto the autograd tape (a hand-written example follows this list).
  • Backward pass – on tensor.backward(), the C++ engine topologically sorts nodes in reverse and accumulates gradients on leaf tensors.
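To see what "subclasses torch.autograd.Function" looks like from Python, here is a hand-written op (a minimal sketch for illustration; the built-in ops do the equivalent in C++): the forward saves what the backward needs, and the backward returns one gradient per input.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)       # stash inputs the backward pass will need
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x        # local derivative 2x, chained with the upstream grad

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()       # records a node for Square on the autograd tape
print(torch.allclose(x.grad, 2 * x))   # True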

Perks of dynamism

Benefit | Why it matters
Native Python control flow | Write logic like if random.random() < p: without special graph APIs.
Interactive debugging | Drop in a breakpoint or print(tensor.grad_fn) anywhere.
Hot-reload research loops | Tweak layer sizes or hyper-parameters and rerun immediately.

Common pitfalls

  • In-place operations (tensor += 1) may discard needed history and break backprop (reproduced in the snippet after this list).
  • Detached tensors - converting to NumPy and back drops the graph, so the round-tripped tensor no longer tracks gradients.
  • Memory blow-ups if long-running loops build graphs without torch.no_grad().
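Two of these pitfalls are easy to reproduce in a few lines (an illustrative sketch; shapes and values are arbitrary):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.exp(x)    # exp saves its output for the backward pass
y += 1              # in-place edit bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print('in-place pitfall:', err)   # "... modified by an inplace operation"

# Memory guard: nothing is recorded on the tape inside torch.no_grad()
with torch.no_grad():
    for _ in range(1000):
        x = torch.tanh(x)             # reassign freely, no growing graph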

3 JAX: Staged, Static Graphs via Tracing

JAX marries Autograd-style differentiation with XLA and a functional mindset. You write pure functions; JAX executes them in three conceptual stages:

  1. Tracing - intercept the Python call and build a jaxpr (an SSA-style IR).
  2. Compilation - lower the jaxpr to HLO, fuse kernels, and emit a binary for the device.
  3. Execution - cache the binary; subsequent calls bypass Python almost entirely.

import jax
import jax.numpy as jnp

@jax.jit  # stage out the entire function
def network(x, w):
   return jnp.tanh(x @ w)

print(jax.make_jaxpr(network)(jnp.ones((1, 4)), jnp.ones((4, 4)))) 

Trademark superpowers

Feature | What you get
Ahead-of-time compilation | One graph that can target CPU, GPU, or TPU
First-class transformations | jax.grad, vmap, and pmap rewrite graphs algebraically
Fusion and layout control | XLA fuses elementwise ops and chooses efficient tiling and layout
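The first-class transformations row is the one that changes how you write code: transformations compose, because each one rewrites the traced graph and hands the result to the next. A minimal sketch (the loss function and shapes below are arbitrary):

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum(jnp.tanh(x @ w))

grad_fn = jax.grad(loss)                        # d loss / d w for a single example
batched = jax.vmap(grad_fn, in_axes=(None, 0))  # same graph, mapped over a batch of x

w = jnp.ones((4, 2))
xs = jnp.ones((8, 3, 4))                        # batch of 8 inputs, each of shape (3, 4)
print(batched(w, xs).shape)                     # (8, 4, 2): one gradient per example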

Pain points

  • Compile latency - seconds for larger models, especially on the first call
  • Purity constraints - no side effects in a jitted region (demonstrated below)
  • Debugging mindset - you think in graphs and traces rather than imperative line-by-line execution
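The purity constraint is easiest to internalise with a small experiment: Python-level side effects run while JAX traces, not on every call (the function below is purely illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print('tracing!')     # Python side effect: runs during tracing only
    return x * 2

f(jnp.ones(3))            # prints 'tracing!' once, then compiles and runs
f(jnp.ones(3))            # cached binary: no Python print at all
f(jnp.ones(4))            # new shape triggers a retrace, so it prints again
# For runtime values inside jitted code, use jax.debug.print instead.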

4 Hands-On Comparison

Control flow

# PyTorch
import torch
import random

def coin_net(x, w1, w2):
    if random.random() < 0.5:
        return torch.relu(x @ w1)
    else:
        return torch.sigmoid(x @ w2)

# JAX
import jax
import jax.numpy as jnp

def coin_net(x, key, w1, w2):
    branch = jax.random.bernoulli(key)  # scalar boolean decided on device
    return jax.lax.cond(
        branch,
        lambda _: jax.nn.relu(x @ w1),
        lambda _: jax.nn.sigmoid(x @ w2),
        None,  # operand forwarded to whichever branch runs
    )

Micro-benchmark

Batch size | PyTorch eager | PyTorch torch.compile | JAX @jit
32 | 1.8 ms | 1.3 ms | 0.9 ms
2 048 | 14 ms | 6.1 ms | 5.3 ms

RTX 4090, CUDA 12, PyTorch 2.3, JAX 0.4.27. Compile time removed; treat these numbers as illustrative rather than definitive 2026 benchmarks.
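If you want numbers for your own hardware, a rough harness looks like the sketch below (illustrative: the toy model, sizes, and iteration count are arbitrary, the warm-up call keeps compile time out of the measurement, and on GPU you would pass a CUDA synchronisation as the finalizer for the PyTorch case):

import time
import torch
import jax
import jax.numpy as jnp

def bench(fn, *args, iters=100, finalize=lambda out: out):
    finalize(fn(*args))                    # warm-up: absorbs tracing/compilation
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    finalize(out)                          # wait for asynchronous device work
    return (time.perf_counter() - start) / iters * 1e3   # ms per call

x_t, w_t = torch.randn(32, 512), torch.randn(512, 512)
print('torch eager:', bench(lambda: torch.tanh(x_t @ w_t)), 'ms')

net = jax.jit(lambda x, w: jnp.tanh(x @ w))
x_j, w_j = jnp.ones((32, 512)), jnp.ones((512, 512))
print('jax @jit:   ', bench(net, x_j, w_j,
                            finalize=lambda out: out.block_until_ready()), 'ms')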

5 Graph Optimisation and Memory Tricks

Gradient checkpointing

# PyTorch
from torch.utils.checkpoint import checkpoint

out = checkpoint(block, x)  # block's activations are recomputed during the backward pass
# JAX
import jax

out = jax.checkpoint(block)(x)  # same idea: rematerialise block's intermediates

SPMD and model parallelism

  • PyTorch: torchrun, FSDP, and tensor-parallelism plugins
  • JAX: pmap, pjit, and GSPMD partition specs (see the pmap sketch below)
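On the JAX side, the lowest-friction entry point is jax.pmap, which replicates a function across local devices and runs it SPMD-style (a minimal sketch; on a CPU-only machine local_device_count() is typically 1, so the leading axis holds a single shard):

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

@jax.pmap                        # one replica of the program per local device
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((n_dev, 128, 128))  # one shard per device along the leading axis
print(step(x))                   # one partial result per device, shape (n_dev,)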

Mixed precision

  • PyTorch: wrap selected regions in torch.cuda.amp.autocast() to run them in lower precision (sketched below).
  • JAX: cast parameters and activations to jnp.float16 or jnp.bfloat16; XLA picks the kernels, while loss scaling is typically handled manually or with a helper library such as jmp.
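On the PyTorch side, the canonical pattern pairs autocast with a gradient scaler (a sketch under the assumption that a CUDA device is available; the model, optimiser, and data are placeholders):

import torch

model = torch.nn.Linear(512, 10).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()   # rescales the loss so fp16 grads do not underflow

x = torch.randn(32, 512, device='cuda')
y = torch.randint(0, 10, (32,), device='cuda')

for _ in range(10):
    opt.zero_grad()
    with torch.cuda.amp.autocast():    # forward pass runs selected ops in lower precision
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.step(opt)                   # unscales grads, skips the step on overflow
    scaler.update()                    # adjusts the scale factor for the next step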

6 Debugging Workflows You Will Actually Use

Task | PyTorch | JAX
Graph visualisation | torchviz, TensorBoard | jax.make_jaxpr output, TensorBoard via jax2tf
Kernel profiling | Nsight Systems, torch.profiler | XLA HLO dumps, Perfetto traces
Inline shape prints | print(t.shape) | jax.debug.print('{x}', x=x)
Breakpoints | Native pdb.set_trace() | jax.debug.breakpoint() during tracing

Tip: On the PyTorch side, torch.cuda.memory_summary() gives a live snapshot of device memory; on the JAX side, the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable controls how much device memory the runtime preallocates.
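For the kernel-profiling row, torch.profiler needs only a few lines (a sketch; the matmul workload is arbitrary, and you would add ProfilerActivity.CUDA when profiling a GPU):

import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1024, 1024)
w = torch.randn(1024, 1024)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        torch.tanh(x @ w)

# Per-operator aggregates, sorted by total CPU time
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))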

7 Advanced Use Cases

  1. Meta-learning: PyTorch higher-order autograd versus JAX jax.hessian.
  2. Probabilistic programming: Pyro versus NumPyro.
  3. Differentiable physics: PyTorch3D versus Brax.
  4. Large-scale reinforcement learning: Meta ReAgent (PyTorch) versus DeepMind Acme (JAX).
  5. Edge deployment: PyTorch Mobile versus jax2tf exported models for TFLite and related runtimes.

8 Selecting the Right Tool (or Both)

Project constraint | Choose | Rationale
Rapid iteration | PyTorch | Zero compile overhead and excellent debugging experience
Large transformer on TPU | JAX | XLA fusion and SPMD-friendly APIs
ONNX-centric pipeline | PyTorch | Mature ONNX exporter and ecosystem
Functional programming codebase | JAX | Fits naturally with pure functions and transformations
Mixed hardware fleet | Hybrid | Prototype quickly in PyTorch; deploy a JAX or TF stack where XLA shines

9 Conclusion

Computation graphs are the invisible scaffolding of modern deep learning. PyTorch’s dynamic tape feels like Python itself: immediate and malleable. JAX’s static graphs feel like a compiler: stricter up front and often very fast once compiled.

Mastering both gives you leverage across research prototypes and production inference in 2026 and beyond.

Ready to bend graphs to your will? Apply to our next Data Science and AI Bootcamp and turn these concepts into portfolio-ready projects.

Appendix A Autograd Math Refresher

For a composite function \( f(g(h(x))) \), the derivative is

\[ \frac{\partial f}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial x}. \]

Frameworks attach local Jacobian-vector products (JVPs) or vector-Jacobian products (VJPs) to each operation. The engine feeds upstream gradients through these closures, chaining them according to the graph structure.
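You can watch that chaining happen with jax.vjp, which returns the forward value plus a closure that maps an upstream cotangent to a downstream one (the three functions below are arbitrary stand-ins for f, g, and h):

import jax
import jax.numpy as jnp

h, g, f = jnp.sin, jnp.exp, jnp.tanh
x = jnp.array(0.7, dtype=jnp.float32)

# Forward pass: record one VJP closure per operation
hx, vjp_h = jax.vjp(h, x)
gx, vjp_g = jax.vjp(g, hx)
fx, vjp_f = jax.vjp(f, gx)

# Backward pass: thread the cotangent through the closures in reverse order
(ct,) = vjp_f(jnp.ones_like(fx))   # seed with d loss / d f = 1
(ct,) = vjp_g(ct)
(ct,) = vjp_h(ct)

print(ct)                                 # chained VJPs
print(jax.grad(lambda x: f(g(h(x))))(x))  # matches jax.grad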

Appendix B DIY Graph Inspectors

PyTorch

def dump_graph(t, visited=None, depth=0):
    """Print the autograd graph hanging off a tensor (or a grad_fn node)."""
    if visited is None:
        visited = set()
    fn = t.grad_fn if hasattr(t, 'grad_fn') else t  # accept tensors or graph nodes
    if fn is not None and fn not in visited:
        visited.add(fn)
        print('  ' * depth, fn)
        for nxt, _ in fn.next_functions:
            if nxt is not None:
                dump_graph(nxt, visited, depth + 1)

# Usage: dump_graph(loss) after a forward pass

JAX

import jax
import jax.numpy as jnp

@jax.jit
def network(x, w):
    return jnp.tanh(x @ w)

def count_ops(f, *args):
    jaxpr = jax.make_jaxpr(f)(*args)
    return len(jaxpr.jaxpr.eqns)

print('Ops:', count_ops(network, jnp.ones((1, 4)), jnp.ones((4, 4))))

Looking Ahead: Graphs Beyond 2026

  • Kernel fusion DSLs - projects such as torch-mlir and StableHLO aim to make cross-framework graphs increasingly interchangeable.
  • Composable compiler passes - plug-in passes to quantise, prune, or otherwise transform graphs with a few lines of Python.
  • Neural compute fabric - edge and data-centre chips ingest ONNX or StableHLO directly, making graph exports a deployment contract.
  • Graph-level privacy - differential-privacy passes inject calibrated noise per edge (for example, via Opacus and emerging JAX prototypes).

In 2026 and beyond, graph literacy is not optional; it is the lens through which the next decade of AI tooling will be designed.

Ready to get hands-on? Clone the companion GitHub repo, run the notebooks on a free Colab GPU, then challenge yourself to port the same model back and forth between PyTorch and JAX. The muscle memory pays off when mixed codebases land on your desk.
