A Principled ML Compiler Stack in 5,000 Lines of Python

Understanding the modern ML compiler stack is a thankless task. TVM is 500K lines of C++. The PyTorch stack piles Dynamo, Inductor, and Triton on top of each other. XLA, MLIR, Halide, Mojo. Where do you even start?
I pulled my hair out trying to come up with an overview, until I finally gave up. Screw it — I'd just build an LLM compiler from scratch and document the process. To keep it realistically reproducible by a reader, I added constraints: 5,000 LOC, a 2-week time limit, and no compiler or ML libraries, just pure Python and raw CUDA.
The result is a complete reference stack, end-to-end. A PyTorch graph flows through six intermediate representations, each one closer to the hardware than the last, via decomposition, fusion, tiling, scheduling, and finally CUDA codegen.
This first part takes a common RMSNorm layer from a transformer and walks it through the upper half of the pipeline: capture, decomposition into primitives, fusion into loop nests, and tile-level scheduling for the GPU. At the end I'll print the emitted CUDA kernel and benchmark a full transformer block against the production stack. The codegen mechanics and matmul optimization (the Tile IR section) are covered in depth in Part 2 (forthcoming).
To run the examples in this article and interact with the compiler, clone the repository on your local machine and create a virtual environment (GPU not required):
git clone https://github.com/cloudrift-ai/deplodock.git
cd deplodock && make setup
source venv/bin/activate
Vendor kernels are still hard to beat at full prefill on the FFN-width matmuls, which is why every production stack falls back to cuBLAS/cuDNN/CUTLASS on the heavy hitters and code-generates everything around them. I'm deliberately not doing that here because the goal is to show precisely what the compiler stack buys you on its own and what it doesn't. For what closing the remaining matmul gap actually takes — CTA tiling, shared-memory double-buffering, TMA-driven producer/consumer pipelines, SASS-level scheduling — see my earlier writeup, Beating cuBLAS on RTX 5090.
Pipeline
ML compilers are layered. Each layer has its own IR, an intermediate representation, i.e. a small language that describes the computation at that stage. The top-level IR starts as a sequence of PyTorch calls like torch.linear or torch.exp; the bottom-level IR is already CUDA source, ready to hand to nvcc.
This layering decomposes the problem into smaller pieces: each layer's transformation is a simple algorithm that adds just a few details to the language, moving the representation one step closer to the final target.
Deplodock uses six such languages: Torch IR → Tensor IR → Loop IR → Tile IR → Kernel IR → CUDA. What each one looks like for a single matmul:
| IR | What Happens | Naive Matrix Multiplication Representation |
|---|---|---|
| Torch IR | Capture the FX graph | C = matmul(A, B) |
| Tensor IR | Decompose Torch ops | A_bc = A[i, k, na] |
| Loop IR | Convert to loops and fuse | for i in 0..M: # free |
| Tile IR | Schedule kernels | Tile(axes=(i:M=THREAD, j:N=THREAD)): |
| Kernel IR | Framework-agnostic kernel | long long tid = blockIdx.x * blockDim.x + threadIdx.x |
| CUDA | Render CUDA kernels | __global__ void matmul(float* A, float* B, float* C) { |
Every ML compiler is some version of this pipeline. What differs is how each IR is designed and which passes run between them; that's where the engineering judgment lives. The rest of this article walks through one such design end-to-end, and every stage is printable on demand: deplodock compile … --ir tensor (or loop, tile, kernel, cuda) renders the same real computation at any layer of the stack.
For the rest of the article I'll use the RMSNorm layer as a running example. RMSNorm sits before every attention and FFN block in Llama, Qwen, and most modern transformers; we'll see its full decomposition into primitives in the Tensor IR section below.
Torch IR — Capturing PyTorch
Step one is to turn the PyTorch module into a graph. Tracing is the standard technique: it walks the module's forward code and records every operation, with concrete input shapes, into an FX graph.
exported = torch.export.export(module, example_inputs)
graph = convert_to_graph_ir(exported.graph_module)
The output is Torch IR, a 1:1 mirror of the FX graph using PyTorch's exact op set. For a whole transformer layer this is ~50–80 nodes. Tracing the RMSNorm module gives a boring graph with a single rms_norm op:
deplodock compile -c "nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir torch
# Graph: 3 nodes, 1 inputs, 1 outputs
inputs:
x: (1, 32, 2048) float32
constants:
p_weight: (2048,) float32
rms_norm = rmsnorm(x, p_weight, eps=1e-06) -> (1, 32, 2048) float32
outputs:
rms_norm: (1, 32, 2048) float32
If you try to trace an older model from Hugging Face like Qwen2.5-7B, you'll see that RMSNorm isn't a single op. The RMSNorm module and the corresponding Torch FX operation only landed in PyTorch 2.4.0; before that, RMSNorm was expressed as a sequence of primitive torch.functional calls.
Tensor IR — A Minimal Primitive Set
Torch IR is just one possible frontend. ONNX, TensorFlow, and JAX each ship their own catalog of hundreds of ops with their own semantics, broadcasting rules, and shape conventions. If the rest of the compiler consumed Torch IR directly, every downstream pass (fusion, scheduling, codegen) would need to understand every Torch op. Adding ONNX support later would mean teaching every pass a second opset. Even making the code work with different PyTorch versions would be a headache.
Thus, a new IR layer is introduced: Tensor IR. Tensor IR is the single canonical representation everything downstream works on. Frontend-specific rewrites decompose each high-level op into one of three primitive kinds:
- Elementwise applies a scalar function per output position, which covers add, mul, exp, rsqrt, silu, relu, and roughly fifteen others.
- Reduction collapses an axis via an associative binary op: sum, max, or prod.
- IndexMap covers every layout-only op: reshape, transpose, slice, unsqueeze, concat. Instead of five separate ops with five shape calculations, one IndexMapOp parameterized by a coordinate function: for each output coord, compute the corresponding input coord via an affine expression. Transpose is (i, j) → (j, i). Slice x[5:8] is (i,) → (i + 5,). Concat-of-two adds a predicate that selects which source. No compute. (A 1-D reference implementation of all three kinds follows this list.)
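To pin the semantics down, here's a 1-D reference implementation of all three kinds in plain Python (my illustration, not compiler code; the real ops generalize this to n-D coords):

import math

def elementwise(fn, *xs):
    # apply a scalar function at every output position (same-shape inputs)
    return [fn(*vals) for vals in zip(*xs)]

def reduction(op, x, init):
    # collapse the axis into an accumulator via an associative binary op
    acc = init
    for v in x:
        acc = op(acc, v)
    return acc

def index_map(coord_map, out_extent, x):
    # pure layout: output[i] = x[coord_map(i)], no arithmetic on values
    return [x[coord_map(i)] for i in range(out_extent)]

print(elementwise(lambda v: 1 / math.sqrt(v), [4.0, 16.0]))  # rsqrt: [0.5, 0.25]
print(reduction(lambda a, b: a + b, [1.0, 2.0, 3.0], 0.0))   # sum: 6.0
print(index_map(lambda i: i + 5, 3, list(range(16))))        # slice x[5:8]: [5, 6, 7]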
It's truly minimal. Even matrix multiplication via torch.nn.Linear is decomposed into a broadcasted mul feeding a sum:
deplodock compile -c "nn.Linear(3, 2, bias=False)(torch.randn(4, 3))" --ir tensor
# Graph: 7 nodes, 1 inputs, 1 outputs
inputs:
input: (4, 3) float32
constants:
p_weight: (2, 3) float32
linear_b_unsq_bc = p_weight[k, j] -> (4, 3, 2) float32
linear_a_unsq_bc = input[i, j] -> (4, 3, 2) float32
linear_ew = multiply(linear_a_unsq_bc, linear_b_unsq_bc) -> (4, 3, 2) float32
linear_reduce = sum(linear_ew, axis=-2) -> (4, 1, 2) float32
linear = linear_reduce[i, na, j] -> (4, 2) float32
outputs:
linear: (4, 2) float32
The M×K×N intermediate looks ruinous, but it's never materialized. Later stages merge the mul and sum into one kernel that streams partial products straight into the accumulator, which is exactly what a hand-written matmul does.
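To make "never materialized" concrete, here's a numpy illustration of both forms for the (4, 3) × (3, 2) linear above (my code, not compiler output):

import numpy as np

A = np.random.randn(4, 3).astype(np.float32)  # input (M=4, K=3)
W = np.random.randn(2, 3).astype(np.float32)  # p_weight (N=2, K=3)

# Tensor IR taken literally: materialize the (4, 3, 2) intermediate, then reduce.
linear_ew = A[:, :, None] * W.T[None, :, :]   # M*K*N floats
unfused = linear_ew.sum(axis=1)               # (4, 2)

# What later stages emit: stream each partial product into an accumulator.
fused = np.zeros((4, 2), dtype=np.float32)
for i in range(4):
    for j in range(2):
        acc = 0.0
        for k in range(3):
            acc += A[i, k] * W[j, k]          # no (4, 3, 2) buffer ever exists
        fused[i, j] = acc

assert np.allclose(unfused, fused, atol=1e-5)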
Decomposition is a per-op function: each supported Torch op maps to a sequence of Tensor IR ops. Running it on RMSNorm's Torch IR yields:
deplodock compile -c "nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir tensor
# Graph: 15 nodes, 1 inputs, 1 outputs
inputs:
x: (1, 32, 2048) float32
constants:
p_weight: (2048,) float32
rms_norm_eps: (1,) float32 = 1e-06
rms_norm_mean_count: (1,) float32 = 2048.0
p_weight_bc = p_weight[k] -> (1, 32, 2048) float32
rms_norm_sq = multiply(x, x) -> (1, 32, 2048) float32
rms_norm_eps_bc = rms_norm_eps[k] -> (1, 32, 1) float32
rms_norm_mean_count_bc = rms_norm_mean_count[k] -> (1, 32, 1) float32
rms_norm_mean_sum = sum(rms_norm_sq, axis=-1) -> (1, 32, 1) float32
rms_norm_mean = divide(rms_norm_mean_sum, rms_norm_mean_count_bc) -> (1, 32, 1) float32
rms_norm_add_eps = add(rms_norm_mean, rms_norm_eps_bc) -> (1, 32, 1) float32
rms_norm_rsq = rsqrt(rms_norm_add_eps) -> (1, 32, 1) float32
rms_norm_rsq_bc = rms_norm_rsq[i, j, na] -> (1, 32, 2048) float32
rms_norm_norm = multiply(x, rms_norm_rsq_bc) -> (1, 32, 2048) float32
rms_norm = multiply(rms_norm_norm, p_weight_bc) -> (1, 32, 2048) float32
outputs:
rms_norm: (1, 32, 2048) float32
For simplicity, we assume that all dimensions are statically known and work only with float32 tensors. For real LLM inference, the compiler would need to handle dynamic shapes and support efficient data types like FP8, FP16, and BF16.
IndexMap Composition
Decomposition often produces chains of pure-layout ops. Matmul's decomposition above has unsqueeze → broadcast back-to-back; attention typically produces transpose → reshape → slice.
A Tensor-IR-level pass that runs before lifting collapses any IndexMapOp → IndexMapOp chain into a single composed coord map:
compose(producer, consumer):
# consumer reads the producer at indices consumer.coord_map[k]
# producer reads its input at indices producer.coord_map[j]
# producer.coord_map is expressed over placeholders for its output coords;
# substitute the consumer's expressions into those placeholders:
for each output coord k:
composed[k] = producer.coord_map[k]{placeholder_j → consumer.coord_map[j]}
return IndexMapOp(coord_map=composed, source=producer.input)
For example, a transpose feeding a slice:
t = x.T # IndexMap: (i, j) → (j, i) input: x
s = t[5:8] # IndexMap: (i, j) → (i + 5, j) input: t
Fuse into IndexMap: (i, j) → (j, i + 5). Repeatedly applying this rule fuses all IndexMap chains into a single composed coordinate mapping.
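The same composition in executable form, with coord maps as plain Python functions (an illustration of the rule, not the compiler's representation):

transpose = lambda i, j: (j, i)      # t = x.T    : t[i, j] reads x[j, i]
slice_5_8 = lambda i, j: (i + 5, j)  # s = t[5:8] : s[i, j] reads t[i + 5, j]

def compose(producer, consumer):
    # the consumer maps s-coords to t-coords; the producer maps t-coords to x-coords
    return lambda *coords: producer(*consumer(*coords))

fused = compose(transpose, slice_5_8)
print(fused(0, 0))  # (0, 5): the composed map (i, j) → (j, i + 5)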
Loop IR — Fused Kernels as Loop Nests
Loop IR is one step closer to the final CUDA kernel. It has a single op type, LoopOp, which bundles the nested loops, the index arithmetic, and the per-element body that implement one kernel's computation. Think of it as the C function a programmer would write by hand: plain scalar code, the common denominator of most hardware, that doesn't know anything about GPUs yet.
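To make the listings below easier to read, here's a toy sketch of the Loop IR vocabulary (my dataclasses, with field names guessed from the listings; not deplodock's real definitions):

from dataclasses import dataclass

@dataclass
class Axis:
    name: str    # "a0"
    extent: int  # 2048
    kind: str    # "free" or "reduce"

@dataclass
class Loop:
    axis: Axis
    body: list   # Load / Assign / Accum / Write stmts, or nested Loops

@dataclass
class LoopOp:    # one kernel: loops + index arithmetic + per-element body
    body: list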
deplodock compile -c "nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir loop
=== 0: merged_n0 -> rms_norm ===
in0 = load rms_norm_mean_count[0]
in1 = load rms_norm_eps[0]
for a0 in 0..32: # free
for a1 in 0..2048: # reduce
in2 = load x[0, a0, a1]
v0 = multiply(in2, in2)
acc0 <- add(acc0, v0)
v1 = divide(acc0, in0)
v2 = add(v1, in1)
v3 = rsqrt(v2)
for a2 in 0..2048: # free
in3 = load x[0, a0, a2]
in4 = load p_weight[a2]
v4 = multiply(in3, v3)
v5 = multiply(v4, in4)
merged_n0[0, a0, a2] = v5
With each Tensor IR node lifted to its own trivial LoopOp, Loop IR's job is to collapse the pile of one-op kernels into as few fused kernels as possible.
Tensor IR → Loop IR is a two-step process:
- Lift each Tensor IR primitive into a one-op LoopOp.
- Merge adjacent LoopOps wherever it's legal (this is the fusion pass).
Lifting
After decomposition and IndexMap composition, every Tensor IR node is one of the three primitives (elementwise, reduction, indexing). We convert each of these primitives into a trivial single-op LoopOp.
For example, for an elementwise operation like torch.nn.ReLU() the lift rule is:
def lift_elementwise(node):
axes = [Axis(f"a{i}", extent=d) for i, d in enumerate(node.output.shape)]
    loads = [Load(f"in{i}", source=i, index=identity(node.inputs[i].shape, axes))
for i in range(len(node.inputs))]
inner = [*loads,
Assign("v", op=node.op, args=[ld.name for ld in loads]),
Write(output=0, index=[Var(a.name) for a in axes], value="v")]
# Wrap the body in one free-axis Loop per output dim, outermost first.
body = inner
for a in reversed(axes):
body = [Loop(axis=a, body=body)]
return LoopOp(body=body)
Generated Loop IR:
deplodock compile -c "F.relu(torch.randn(8))" --ir loop
=== 0: relu -> relu ===
for a0 in 0..8: # free
in0 = load x[a0]
v0 = relu(in0)
relu[a0] = v0
After all nodes lift, the graph contains only LoopOp nodes.
Merging
Before codegen, we want to fuse as many loops as possible. Every boundary between two unfused kernels is a round-trip through global memory: write the intermediate, read it back. On CPUs that's a cache-miss tax. On GPUs it's catastrophic: DRAM bandwidth is orders of magnitude below what the ALUs can consume. A CUDA engineer who sees unfused loops will slam your head on the keyboard.
Loop fusion is well-studied; it predates ML compilers by decades. Classical compilers perform it to improve cache locality on CPUs and vector units, and it's covered in every standard compiler textbook alongside loop interchange, tiling, and distribution (Aho, Lam, Sethi & Ullman's Compilers: Principles, Techniques, and Tools; Allen & Kennedy). The ML-compiler twist is mainly that getting fusion wrong is much more expensive on a GPU than on a CPU. What follows is a simplified, closed-form version of the algorithm from the polyhedral compilation literature; see the References for the full polyhedral schedulers (Pluto, Tensor Comprehensions).
The core idea: to fuse two loops, we need a mapping between the consumer's index variables and the producer's. Call it σ: a0, a1, ... → b0, b1, .... Once σ is known, we can rewrite the producer's body with the consumer's indices and splice it into the consumer loop, and the intermediate tensor disappears.
The examples below build intuition from the simplest case up to softmax.
Two Pointwise Kernels
The expression we're fusing:
x = torch.randn(8) # example input
mid = torch.neg(x)
out = torch.exp(mid) # (8,) -> (8,)
After lift (--passes dol is a shortcut for decomposition,optimization,lifting, i.e. all passes before fusion), each op has its own loop nest:
deplodock compile -c "torch.exp(torch.neg(torch.randn(8)))" --passes dol --ir loop
=== 0: neg -> neg ===
for a0 in 0..8: # free
in0 = load x[a0]
v0 = negative(in0)
neg[a0] = v0
=== 1: exp -> exp ===
for a0 in 0..8: # free
in0 = load neg[a0]
v0 = exp(in0)
exp[a0] = v0
The producer writes neg[a0]. The consumer reads neg[a0]. σ is an identity map. (Each LoopOp owns its own axis namespace; the fact that both kernels use a0 is incidental.)
Once σ is solved, the rest is mechanical: rewrite the producer's input indices through σ, concatenate its body into the consumer's, and forward the produced SSA value to wherever the consumer's read was. The intermediate buffer disappears entirely:
deplodock compile -c "torch.exp(torch.neg(torch.randn(8)))" --ir loop
=== 0: merged_exp -> exp ===
for a0 in 0..8: # free
in0 = load x[a0]
v0 = negative(in0)
v1 = exp(v0)
merged_exp[a0] = v1
Offset Slice
The expression we're fusing, the kind of window read rotary embeddings do:
x = torch.randn(16)
mid = torch.neg(x)
out = torch.exp(mid[5:8]) # (16,) -> (3,)
After lift, we get three kernels: the neg producer runs over all 16 elements, a trivial indexmap kernel implements the slice, and the exp consumer runs over 3:
deplodock compile -c "torch.exp(torch.neg(torch.randn(16))[5:8])" --passes dol --ir loop
=== 0: neg -> neg ===
for a0 in 0..16: # free
in0 = load x[a0]
v0 = negative(in0)
neg[a0] = v0
=== 1: n0 -> slice_1 ===
for a0 in 0..3: # free
in0 = load neg[(a0 + 5)]
n0[a0] = in0
=== 2: exp -> exp ===
for a0 in 0..3: # free
in0 = load n0[a0]
v0 = exp(in0)
exp[a0] = v0
The slice's σ is [a0] → [a0 + 5]. After fusion the producer's axis is gone entirely; the merged kernel iterates only over a0 ∈ [0, 3) and reads the source buffer through the composed offset:
deplodock compile -c "torch.exp(torch.neg(torch.randn(16))[5:8])" --ir loop
=== 0: merged_exp -> exp ===
for a0 in 0..3: # free
in0 = load x[(a0 + 5)]
v0 = negative(in0)
v1 = exp(v0)
merged_exp[a0] = v1
A slice feeding a pointwise op becomes a single loop nest that reads the sliced window directly from the source buffer, no intermediate.
The producer originally ran over a0 ∈ [0, 16), but the merged kernel iterates a0 ∈ [0, 3) only: because σ replaced the producer's axis with the consumer's, the producer body now executes just three times, at x[5], x[6], x[7]. The other 13 elements of what the intermediate buffer used to hold are never computed. Merge is doing dead-element elimination for free, via the consumer's iteration space.
This only works because the merge rule requires fan-out of 1: the grammar matches a LoopOp whose sole consumer is another LoopOp. If mid also fed a second kernel that needed the full 16 elements, chain-mode wouldn't fire, and the producer would stay intact: we'd need the full producer loop to materialize mid, plus a conditional write for the consumer's slice.
Reduction Feeding Elementwise
The expression we're fusing:
x = torch.randn(4, 8)
s = x.sum(dim=-1, keepdim=True) # (4, 8) -> (4, 1), one sum per row
out = torch.exp(s)
The reduce kernel carries an axis that never appears in its write; the axis is collapsed into an accumulator. The producer sums x[a0, a1] over a1; the consumer takes exp of each element of the reduced result:
deplodock compile -c "torch.exp(torch.randn(4,8).sum(-1,True))" --passes dol --ir loop
=== 0: sum_1 -> sum_1 ===
for a0 in 0..4: # free
for a1 in 0..8: # reduce
in0 = load x[a0, a1]
acc0 <- add(acc0, in0)
sum_1[a0, 0] = acc0
=== 1: exp -> exp ===
for a0 in 0..4: # free
in0 = load sum_1[a0, 0]
v0 = exp(in0)
exp[a0, 0] = v0
σ is the identity: σ[a0] = a0. But the producer has two axes, and a1 never appeared in its write, so σ has no binding for it. a1 can't just disappear; the producer's work depends on it. Since a1 has kind reduce, it's legal to carry it into the merged kernel as an inner sweep:
deplodock compile -c "torch.exp(torch.randn(4,8).sum(-1,True))" --ir loop
=== 0: merged_exp -> exp ===
for a0 in 0..4: # free
for a1 in 0..8: # reduce
in0 = load x[a0, a1]
acc0 <- add(acc0, in0)
v0 = exp(acc0)
merged_exp[a0, 0] = v0
The free vs reduce tag on each axis is how the algorithm tracks which producer-side axes can safely leak into the consumer's iteration space. A reduce axis just becomes an inner loop. A free axis would require replicating the consumer's work across a new dimension, which we refuse.
Softmax: Multi-Port Consumer + Reduce-Axis Aliasing
Softmax is the case that puts every piece of the algorithm under load:
mx = x.max(dim=-1, keepdim=True) # reduce #1 over cols
e = torch.exp(x - mx)
s = e.sum(dim=-1, keepdim=True) # reduce #2 over cols
out = e / s
In loop-form it looks as follows. Producer's and consumer's iteration and reduce axes match, so we'll name them a0, a1, a1_1 in both loops:
# producer: max + sub + exp
for a0 in 0..4:
for a1 in 0..8:
acc0 = -1e+30
for a1_1 in 0..8:
acc0 = max(acc0, x[a0, a1_1]) # reduce sweep #1
v_mx = copy(acc0)
v_s = sub(x[a0, a1], v_mx)
v_e = exp(v_s)
mid[a0, a1] = v_e
# consumer: sum + div
for a0 in 0..4:
for a1 in 0..8:
acc0 = 0.0
for a1_1 in 0..8:
acc0 = add(acc0, mid[a0, a1_1]) # reduce sweep #2
v_sm = copy(acc0)
v = div(mid[a0, a1], v_sm)
out[a0, a1] = v
All axes match, so σ is again the identity. But we can't collapse the two reduce sweeps into one: the consumer's sum needs the producer's max fully computed first (acc0 has to be finalized before sub(x, acc0) is meaningful). The algorithm keeps both sweeps back-to-back and splices the producer's sub/exp chain into the second. Feeding the whole thing through the real pipeline (F.softmax(x, dim=-1)) produces one kernel with three inner sweeps:
deplodock compile -c "F.softmax(torch.randn(4, 8), dim=-1)" --ir loop
=== 0: merged_n0 -> softmax ===
for a0 in 0..4: # free
for a1 in 0..8: # reduce
in0 = load x[a0, a1]
acc0 <- maximum(acc0, in0) # max sweep
for a1 in 0..8: # reduce
in1 = load x[a0, a1]
v0 = subtract(in1, acc0)
v1 = exp(v0)
acc1 <- add(acc1, v1) # sum sweep (uses finalized max)
for a2 in 0..8: # free
in2 = load x[a0, a2]
v2 = subtract(in2, acc0)
v3 = exp(v2)
v4 = divide(v3, acc1) # divide sweep
merged_n0[a0, a2] = v4
The x - max and exp subgraphs get inlined twice (once into the sum sweep, once into the divide sweep), but that's the whole point: recomputation is cheap, memory round-trips aren't.
The Full Algorithm
Three concepts the algorithm leans on, illustrated together:
for a0 in 0..4: # opens scope (a0,)
for a1 in 0..8: # reduce # opens scope (a0, a1)
in0 = load x[a0, a1] # scope (a0, a1) live {a0, a1} ssa_deps {}
acc0 <- add(acc0, in0) # scope (a0, a1) live {a0, a1} ssa_deps {in0, acc0}
v0 = mul(acc0, acc0) # scope (a0,) live {} ssa_deps {acc0}
out0[a0, 0] = v0 # scope (a0,) live {a0} ssa_deps {v0}
- Scope — the tuple of enclosing Loop axes at a stmt's position.
- Live axes — axis Vars appearing in the stmt's Expr subtree (after σ). They pin the shallowest scope the stmt can legally sit in. (A toy extraction sketch follows this list.)
- SSA deps — stmts whose SSA outputs the stmt reads. Drives the worklist: emitting a stmt schedules its deps.
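As promised, a toy live-axes extraction, over expression trees encoded as nested tuples (my encoding; deplodock walks real Expr nodes):

# ("load", "x", ("a0", ("+", "a1", 5))) encodes: read x[a0, a1 + 5]
def live_axes(expr):
    if isinstance(expr, str):
        # axis Vars look like a0, a1, ...; anything else is a name or op
        return {expr} if expr[:1] == "a" and expr[1:].isdigit() else set()
    if isinstance(expr, tuple):
        return set().union(*(live_axes(e) for e in expr))
    return set()  # numeric constants

print(live_axes(("load", "x", ("a0", ("+", "a1", 5)))))  # {'a0', 'a1'}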
The algorithm is a classic worklist: a pending queue seeded with the consumer's Writes, processed in reverse topological order so defs land before uses:
def merge(producer, consumer):
pending = [(w, {}) for w in consumer.body if isinstance(w, Write)]
while pending:
stmt, sigma = pending.pop()
match stmt:
case Load() if reads(stmt, producer):
# read_index[k] == σ(write_index[k]); solve for σ.
# Emit a trivial copy bridge; a later pass collapses these.
sigma = solve_sigma(stmt.index, producer.write.index)
emitted = emit(Copy(stmt.name, producer.write.value), current_scope)
case Accum(axis, op, value):
# Reduce axis survives: wrap the accum in an inner Loop,
# placed at the shallowest scope the value still depends on.
fresh = fresh_axis()
scope = scope_for_value(current_scope, value)
emitted = emit(Loop(fresh, Accum(op, value)), scope)
sigma = sigma | {axis: fresh}
case _: # plain Assign / Select / Load
# Rewrite through the accumulated σ + SSA renaming.
emitted = emit(rewrite(stmt, ssa_rename, sigma), current_scope)
for dep in ssa_deps(emitted):
pending.append((dep, sigma))
Four helpers deserve a closer look. They're where the actual work happens:
def emit(stmt, scope):
"""Descend the consumer body along ``scope``, creating empty ``Loop``
nodes as needed, and *prepend* stmt at the leaf. Prepend (not append)
because the worklist runs in reverse topological order, so prepending
gives defined-before-use ordering for free.
body = [Loop(a0, [Loop(a1, [in0=load ..., acc0 <- ...])])]
emit(new_stmt, scope=(a0,))
==> body = [Loop(a0, [new_stmt, Loop(a1, [...])])]
"""
def solve_sigma(read_idx, write_idx):
"""Pair componentwise: read[k] == σ(write[k]). Both sides are affine
in axis Vars; extract the assignment, fail the merge if any pair isn't
linear in the writer's axes.
read = [a0 + 5, a1] write = [b0, b1]
==> sigma = {b0: a0 + 5, b1: a1}
"""
def rewrite(stmt, ssa_rename, sigma):
"""One tree-walk, two substitutions: σ swaps producer axis Vars for
reader-side exprs; ssa_rename freshens producer SSA names to avoid
colliding with the consumer's vN namespace.
"""
def scope_for_value(current_scope, value):
"""Shallowest prefix of ``current_scope`` that binds every live axis
of ``value``. Walks outermost→innermost, stops once all live axes are
bound. Hoists scalar stmts out of reduce loops they no longer need.
current_scope = (a0, a1, a2)
value = x[a0, a2] live = {a0, a2} ==> (a0, a1, a2)
value = acc[a0] live = {a0} ==> (a0,)
value = 0.0 live = {} ==> ()
"""
The Blowup Guard
The algorithm guarantees correctness: if merge returns a body, running it gives the same values as running producer and consumer separately. It doesn't guarantee less compute. For that we need an additional post-rewrite check that rejects rewrites that increase compute.
The cleanest example is an MLP-style two-matmul chain: up(64→256) feeding down(256→64). With the blowup guard temporarily disabled in 001_merge_loop_ops.py, the merged kernel that deplodock compile emits is:
deplodock compile -c "
class MLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.up = torch.nn.Linear(64, 256, bias=False)
self.down = torch.nn.Linear(256, 64, bias=False)
def forward(self, x):
return self.down(self.up(x))
MLP()(torch.randn(8, 64))" --ir loop
=== 0: merged_n1 -> linear_1 ===
for a0 in 0..8: # free
for a1 in 0..64: # free
for a2 in 0..256: # reduce (down_proj's H)
for a3 in 0..64: # reduce (up_proj's K, replayed!)
in0 = load p_up_weight[a2, a3]
in1 = load x[a0, a3]
v0 = multiply(in1, in0)
acc0 <- add(acc0, v0)
in2 = load p_down_weight[a1, a2]
v1 = multiply(acc0, in2)
acc1 <- add(acc1, v1)
merged_n1[a0, a1] = acc1
The up-proj reduce (a3:64) sits inside the down-proj reduce (a2:256), which sits inside the down-proj output axes (a0:8, a1:64). The producer's reduce reads x[a0, a3] * W1[a2, a3] (live axes {a0, a2}), so scope_for_value can't hoist past a2. The K=64 reduce runs M*D*H = 8*64*256 = 131K times instead of M*H = 2K. Total work jumps to ~8.4M ops vs ~262K combined unfused, a 32× blowup.
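A sketch of what such a guard could look like, assuming a LoopOp can enumerate every Axis in its nest (the attribute name and the acceptance rule are my inventions):

from math import prod

def total_work(loop_op):
    # product of all axis extents ~ trip count of the innermost statements
    return prod(a.extent for a in loop_op.axes)

def blowup_guard(producer, consumer, merged):
    # reject the merge if the fused nest does more work than the parts combined
    return total_work(merged) <= total_work(producer) + total_work(consumer)

For the MLP above, the guard compares 8·64·256·64 ≈ 8.4M against 8·256·64 + 8·64·256 ≈ 262K and rejects the merge.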
A principled solution would be a global optimization instead of the greedy algorithm with a blow-up guard; that is exactly what the polyhedral schedulers in Pluto and Tensor Comprehensions do, formulating the entire fusion-and-tiling problem as one ILP over the iteration domain (see the References for details).
Tile IR — Parallel Programming Model
Tile IR is the GPU-aware IR. The full lowering machinery is the subject of Part 2 (forthcoming). Yet I promised the full pipeline, and I need to at least reduce the violent intent of any CUDA ninjas reading this article to an unhappy growl.
Tile IR is the first IR that knows it's targeting a GPU, but only in the abstract. It assumes three facts about the hardware:
- A grid of parallel threads.
- Threads are divided into blocks.
- Each block has a fast scratchpad (shared memory) one order of magnitude faster than global memory.
A Tile is a Loop nest with axes annotated THREAD or BLOCK, plus three pseudo-ops: StridedLoop (a thread walks a range with a stride), Combine (cross-thread merge of a partial accumulator), and Stage (slab of data hoisted into shared memory).
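In the same illustrative spirit as the Loop IR sketch earlier, the Tile IR additions might be encoded as follows (names and fields assumed, not deplodock's real classes):

from dataclasses import dataclass

@dataclass
class Tile:
    axes: list      # Axis entries annotated "THREAD" or "BLOCK"
    body: list

@dataclass
class StridedLoop:  # one thread sweeps a range with a stride
    axis: str
    start: str      # usually the thread index
    extent: int
    stride: int     # usually the block size
    body: list

@dataclass
class Combine:      # cross-thread merge of a partial accumulator
    acc: str
    op: str

@dataclass
class Stage:        # slab of data hoisted into shared memory
    source: str
    origin: tuple
    slab: tuple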
deplodock compile -c "torch.nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir tile
TileOp k_rms_norm_reduce
in0 = load rms_norm_mean_count[0]
in1 = load rms_norm_eps[0]
# 32 blocks × 256 threads = 8192 threads
Tile(axes=(a0:256=THREAD, a1:32=BLOCK)):
# hoist the row of x and the weight row into shared memory
x_smem = Stage(x, origin=(0, a1, 0), slab=(a2:2048@2)) async
p_weight_smem = Stage(p_weight, origin=(0), slab=(a3:2048@0)) async
# reduce — each thread walks 8 of 2048 indices, accumulates a partial
StridedLoop(a2 = a0; < 2048; += 256):
in2 = load x_smem[a2]
v0 = multiply(in2, in2)
acc0 <- add(acc0, v0)
# merge partials across the 256 threads
Combine(acc0, op=add)
# finalize the rsqrt scalar per row
v1 = divide(acc0, in0)
v2 = add(v1, in1)
v3 = rsqrt(v2)
# free — write the normalized row back
StridedLoop(a3 = a0; < 2048; += 256):
in3 = load x_smem[a3]
in4 = load p_weight_smem[a3]
v4 = multiply(in3, v3)
v5 = multiply(v4, in4)
rms_norm[0, a1, a3] = v5
8192 threads versus the 32 a naive lowering would have launched, and the 2048-wide reduce takes 8 strided iterations per thread instead of 2048 serial ones.
Lowering from Loop IR to this final shape runs as a stack of small rewrite rules. Each pass rewrites a Tile IR operation into one with a better hardware mapping, e.g. introducing input data staging or splitting a loop across blocks.
Kernel IR — Materializing the Schedule
Kernel IR turns each Tile IR scheduling decision into a concrete hardware primitive. An async Stage becomes a Smem declaration plus a cp.async fill loop with commit_group/wait_group fences. Combine becomes another Smem, a per-thread write of the partial, a Sync, a TreeHalve (the canonical halving loop), another Sync, and a broadcast load of slot 0. THREAD/BLOCK axes become threadIdx/blockIdx.
deplodock compile -c "torch.nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir kernel
kernel k_rms_norm_reduce
in0 = load rms_norm_mean_count[0]
in1 = load rms_norm_eps[0]
Tile(axes=(a0:256=THREAD, a1:32=BLOCK)):
Init(acc0, op=add)
# Stage(x) → shared-memory buffer, async cooperative fill
Sync
Smem x_smem[2048] (float)
StridedLoop(x_smem_flat = a0; < 2048; += 256):
cp.async x_smem[x_smem_flat] <- x[0, a1, (0 + x_smem_flat)]
cp.async.commit_group
cp.async.wait_group(0)
Sync
# Stage(p_weight) → shared-memory buffer, async cooperative fill
Sync
Smem p_weight_smem[2048] (float)
StridedLoop(p_weight_smem_flat = a0; < 2048; += 256):
cp.async p_weight_smem[p_weight_smem_flat] <- p_weight[(0 + p_weight_smem_flat)]
cp.async.commit_group
cp.async.wait_group(0)
Sync
# reduce — accumulate the per-thread partial sum of x²
StridedLoop(a2 = a0; < 2048; += 256):
in2 = load x_smem[a2]
v0 = multiply(in2, in2)
acc0 <- add(acc0, v0)
# Combine → tree-reduce partials in shared memory, broadcast slot 0
Smem acc0_smem[256] (float)
acc0_smem[a0] = acc0
Sync
TreeHalve(acc0_smem, op=add, length=256, tid=a0)
Sync
acc0_b = load acc0_smem[0]
# finalize the rsqrt scalar per row
v1 = divide(acc0_b, in0)
v2 = add(v1, in1)
v3 = rsqrt(v2)
# free — write the normalized row back
StridedLoop(a3 = a0; < 2048; += 256):
in3 = load x_smem[a3]
in4 = load p_weight_smem[a3]
v4 = multiply(in3, v3)
v5 = multiply(v4, in4)
rms_norm[0, a1, a3] = v5
The split between Tile IR (decisions) and Kernel IR (hardware) mirrors the algorithm/schedule split from Halide: the same compute can be scheduled multiple ways without rewriting the body, and a future autotuner could search over Tile IR rewrites without ever touching the codegen path.
CUDA — Emitting Source
Codegen is a one-to-one tree walk over Kernel IR. Smem → __shared__ float, Sync → __syncthreads(), TreeHalve → the halving loop, StridedLoop → a strided for, cp.async → inline PTX. Every load and store flattens its multi-index against the tensor's strides.
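In spirit, the walker is one match statement per node type; a toy flavor below (the node encodings are mine, and the real walker also tracks strides and dtypes):

def emit(node, indent=0):
    pad = "  " * indent
    match node:
        case ("Sync",):
            return f"{pad}__syncthreads();"
        case ("Smem", name, size):
            return f"{pad}__shared__ float {name}[{size}];"
        case ("StridedLoop", var, start, stop, step, body):
            inner = "\n".join(emit(s, indent + 1) for s in body)
            return (f"{pad}for (int {var} = {start}; {var} < {stop}; "
                    f"{var} += {step}) {{\n{inner}\n{pad}}}")
    raise NotImplementedError(node)

print(emit(("StridedLoop", "a2", "a0", 2048, 256, (("Sync",),))))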
deplodock compile --code "torch.nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir cuda
extern "C" __global__
__launch_bounds__(256) void k_rms_norm_reduce(const float* x, const float* p_weight, float* rms_norm) {
float in0 = 2048.0f;
float in1 = 1e-06f;
{
int a1 = blockIdx.x;
int a0 = threadIdx.x;
float acc0 = 0.0f;
__syncthreads();
__shared__ float x_smem[2048];
for (int x_smem_flat = a0; x_smem_flat < 2048; x_smem_flat += 256) {
unsigned int _smem_addr = __cvta_generic_to_shared(&x_smem[x_smem_flat]);
asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n" :: "r"(_smem_addr), "l"(&x[a1 * 2048 + x_smem_flat]) : "memory");
}
asm volatile("cp.async.commit_group;\n" ::: "memory");
asm volatile("cp.async.wait_group 0;\n" ::: "memory");
__syncthreads();
__syncthreads();
__shared__ float p_weight_smem[2048];
for (int p_weight_smem_flat = a0; p_weight_smem_flat < 2048; p_weight_smem_flat += 256) {
unsigned int _smem_addr = __cvta_generic_to_shared(&p_weight_smem[p_weight_smem_flat]);
asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n" :: "r"(_smem_addr), "l"(&p_weight[p_weight_smem_flat]) : "memory");
}
asm volatile("cp.async.commit_group;\n" ::: "memory");
asm volatile("cp.async.wait_group 0;\n" ::: "memory");
__syncthreads();
for (int a2 = a0; a2 < 2048; a2 += 256) {
float in2 = x_smem[a2];
float v0 = in2 * in2;
acc0 += v0;
}
__shared__ float acc0_smem[256];
acc0_smem[a0] = acc0;
__syncthreads();
for (int s = 128; s > 0; s >>= 1) {
if (a0 < s) {
acc0_smem[a0] = acc0_smem[a0] + acc0_smem[a0 + s];
}
__syncthreads();
}
__syncthreads();
float acc0_b = acc0_smem[0];
float v1 = acc0_b / in0;
float v2 = v1 + in1;
float v3 = rsqrtf(v2);
for (int a3 = a0; a3 < 2048; a3 += 256) {
float in3 = x_smem[a3];
float in4 = p_weight_smem[a3];
float v4 = in3 * v3;
float v5 = v4 * in4;
rms_norm[a1 * 2048 + a3] = v5;
}
}
}
That's the full pipeline output: a single CUDA kernel for the whole fused RMSNorm. With inline PTX for smem staging, the kernel is properly unreadable, just the way a CUDA ninja likes it.
Validation
Every benchmark in this section is fp32 end-to-end on all three backends: eager PyTorch, torch.compile, and Deplodock. Numerical correctness checked against eager PyTorch as max-abs-diff: exact at small shapes, ≤ 1e-5 at seq=512 (FP32 reduction-order drift).
Setup: NVIDIA RTX 5090 (sm_120, Blackwell), driver 580.126.09, CUDA 13.0, PyTorch 2.11.0 (+cu130).
Wins — Pointwise Fusion
The code-generation approach shines on long pointwise chains like GELU. The GELU tanh approximation is used as the activation in Gemma (via GeGLU), every modern vision transformer (ViT, DINOv2, SigLIP), Whisper, and CLIP. It is nine elementwise ops on a [1, 32, 18944] tensor (the FFN intermediate width of Qwen2.5-7B).
def gelu(x):
return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))
At a single prefill step (seq=32) where dispatch dominates:
deplodock run -bc "x=torch.randn(32,18944);0.5*x*(1+torch.tanh(0.797*(x+0.044*x*x*x)))"
Backend Latency (us) vs Eager
------------------------------------------------
Eager PyTorch 31 1.00x
torch.compile 24 1.25x
Deplodock 6 4.87x
At full prefill (seq=512) where compute is more noticeable:
deplodock run -bc "x=torch.randn(512,18944);0.5*x*(1+torch.tanh(0.797*(x+0.044*x*x*x)))"
Backend Latency (us) vs Eager
------------------------------------------------
Eager PyTorch 343 1.00x
torch.compile 41 8.38x
Deplodock 53 6.50x
Eager dispatches each op as its own kernel: nine launches, nine HBM round-trips for the same x. Deplodock fuses the whole expression into one LoopOp (three multiplies, an add, a multiply, a tanh, an add, two more multiplies) and lowers to one pointwise kernel:
deplodock compile -c "x=torch.randn(32,18944);0.5*x*(1+torch.tanh(0.797*(x+0.044*x*x*x)))"
extern "C" __global__
__launch_bounds__(256) void k_mul_5_pointwise(const float* x, float* mul_5) {
float in0 = 0.044f;
float in1 = 0.797f;
float in2 = 1.0f;
float in3 = 0.5f;
{
int a0 = blockIdx.x / 1184;
int a2 = blockIdx.x % 1184;
int a1 = threadIdx.x / 16;
int a3 = threadIdx.x % 16;
float in4 = x[(a0 * 16 + a1) * 18944 + (a2 * 16 + a3)];
float v0 = in4 * in0;
float v1 = v0 * in4;
float v2 = v1 * in4;
float v3 = in4 + v2;
float v4 = v3 * in1;
float v5 = tanhf(v4);
float v6 = v5 + in2;
float v7 = in4 * in3;
float v8 = v7 * v6;
mul_5[(a0 * 16 + a1) * 18944 + (a2 * 16 + a3)] = v8;
}
}
One HBM read of x, one HBM write of out, no intermediates. Eager pays for the round-trips between every op; torch.compile fuses the chain too (Inductor produces one Triton kernel that runs at roughly the same speed as the one above), but its dispatch path is thicker (Dynamo guards, Inductor runtime, Triton launcher) and at this kernel size the per-call overhead shows.
Ties — Softmax
The attention scores tensor is [batch, heads, seq, seq]; softmax over the last axis is a long-reduce kernel similar in shape to RMSNorm:
deplodock run --bench -c "torch.nn.Softmax(dim=-1)(torch.randn(1, 28, 2048, 2048))"
Backend Latency (us) vs Eager
------------------------------------------------
Eager PyTorch 688 1.00x
torch.compile 661 1.04x
Deplodock 703 0.98x
The same cooperative-reduce + staging recipe that handles RMSNorm covers softmax: one block per (batch, head, row), 256 threads cooperate on the 2048-wide reduce, and the row is staged into shared memory and reused across the max, exp-sum, and normalize sweeps.
Loss — Matmul
This is the part of the post where I admit things.
Full-block latency. One transformer block (attention + MLP + the two surrounding norms) compiled end-to-end:
| model | seq | eager | torch.compile | deplodock | vs eager |
|---|---|---|---|---|---|
| TinyLlama-1.1B | 32 | 323 µs | 234 µs | 574 µs | 0.56× |
| TinyLlama-1.1B | 128 | 587 µs | 473 µs | 917 µs | 0.64× |
| TinyLlama-1.1B | 512 | 1288 µs | 1187 µs | 3807 µs | 0.34× |
| Qwen2.5-7B | 32 | 1145 µs | 997 µs | 1649 µs | 0.69× |
| Qwen2.5-7B | 128 | 1598 µs | 1684 µs | 3075 µs | 0.52× |
| Qwen2.5-7B | 512 | 5259 µs | 5055 µs | 10575 µs | 0.50× |
(TinyLlama-1.1B: hidden=2048, ffn=5632, 32 heads, 4 KV heads, head_dim=64. Qwen2.5-7B: hidden=3584, ffn=18944, 28 heads, 4 KV heads, head_dim=128.)
The gap is the matmul. Every transformer block is dominated by its matmuls: the Q/K/V and output projections, the FFN's gate/up/down, and the attention's QKᵀ and attn·V together account for the overwhelming majority of the FLOPs. cuBLAS has been highly optimized over many years, and that's the bar deplodock is measured against here.
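A quick back-of-the-envelope makes the dominance concrete (my arithmetic, using the Qwen2.5-7B shapes from the parenthetical above; fp32, batch=1):

seq, h, ffn, heads, kv_heads, hd = 512, 3584, 18944, 28, 4, 128
mm = lambda m, k, n: 2 * m * k * n                        # multiply-add = 2 FLOPs
proj = 2 * mm(seq, h, h) + 2 * mm(seq, h, kv_heads * hd)  # q/o + k/v projections
attn = heads * (mm(seq, hd, seq) + mm(seq, seq, hd))      # QK^T and attn·V
mlp = 2 * mm(seq, h, ffn) + mm(seq, ffn, h)               # gate/up + down
norms = 2 * 4 * seq * h                                   # ~4 ops/element, 2 RMSNorms
total = proj + attn + mlp + norms
print(f"matmul share of FLOPs: {(proj + attn + mlp) / total:.3%}")  # ≈ 99.994%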
Matrix Multiplication Gap
The block table above is dominated by those matmuls and the O(seq²) attention path; the matmuls themselves at Qwen-realistic shapes (Linear(K → N) on [1, seq, K]):
| shape | seq | eager | torch.compile | deplodock | vs eager |
|---|---|---|---|---|---|
| Linear(3584, 3584) (q/o_proj) | 32 | 69 µs | 61 µs | 62 µs | 1.10× |
| Linear(3584, 3584) | 128 | 123 µs | 104 µs | 236 µs | 0.52× |
| Linear(3584, 3584) | 512 | 265 µs | 242 µs | 483 µs | 0.55× |
| Linear(3584, 512) (k/v_proj, GQA) | 512 | 83 µs | 78 µs | 135 µs | 0.62× |
| Linear(3584, 18944) (gate/up_proj) | 512 | 1388 µs | 1383 µs | 2030 µs | 0.68× |
| Linear(18944, 3584) (down_proj) | 512 | 1277 µs | 1234 µs | 2522 µs | 0.51× |
The lowering strategy I was able to land within the time limit sits between 50% of cuBLAS and slightly above it, depending on shape — it actually edges out cuBLAS on the q/o_proj at seq=32 and holds 50–70% on the prefill-width matmuls. Closing the remaining gap requires tensor cores and the kind of SASS-level scheduling cuBLAS pulls from CUTLASS templates.
What's Next
Can you run an entire LLM with this compiler? Two things are missing:
- Prefill/decode + KV-cache. Not really a deplodock concern; vLLM already owns that machinery, and a deplodock backend slot in vLLM is the natural integration point.
- Dynamic shapes. Required for arbitrary-length inference; the pipeline currently bakes concrete extents into every op for readability. A bigger lift than it sounds — almost every graph rewrite rule relies on shape inference in some form, so wiring it through the compiler is a hefty task. A future post.
Closing Notes
Final stats:
- 5,033 LOC of tracing, IR definitions, graph transformation, introspection, CUDA code generation, and dispatch logic (at the time of writing; excluding comments, imports, and re-exports).
- Core dependencies: numpy, cupy (raw CUDA dispatch), cppyy (JIT-compiles the C++ reference kernels used for correctness checks).
- End-to-end numerically correct against PyTorch eager on TinyLlama-1.1B and Qwen2.5-7B transformer blocks.
The core algorithms are small. Fusion is a page of pseudocode. Codegen is a tree walk. IR lowering is a dispatch table per op. You don't need to read 500K lines of code to understand how ML compilers work — let alone to build one.
If you made it this far, clone the compiler and run deplodock compile <hf_model> --ir <target_ir> on a model you'd like to dissect. Once you see the same algorithms working on real models, the production stacks (TVM, Inductor, XLA, MLIR) stop looking like black boxes.
References
Classical compilers
- Alfred V. Aho, Monica S. Lam, Ravi Sethi, Jeffrey D. Ullman. Compilers: Principles, Techniques, and Tools, 2nd edition ("the Dragon Book"). Addison-Wesley, 2006. Chapter 11 ("Optimizing for Parallelism and Locality") covers loop transformations (fusion, interchange, tiling) at textbook level; a gentle entry point to the classical compiler-era framing.
- Randy Allen and Ken Kennedy. Optimizing Compilers for Modern Architectures: A Dependence-Based Approach. Morgan Kaufmann, 2001. Chapters on loop transformations and dependence analysis give the classical dependence-preservation formulation that underlies every fusion legality rule.
- Gary A. Kildall. A Unified Approach to Global Program Optimization. POPL 1973. The paper that introduced iterative worklist-based dataflow analysis; the pattern our splicer reuses (with SSA emission in place of lattice-value propagation).
Polyhedral scheduling
- Uday Bondhugula, Albert Hartono, J. Ramanujam, P. Sadayappan. A Practical Automatic Polyhedral Parallelizer and Locality Optimizer. PLDI 2008. Pluto, the canonical ILP-based polyhedral scheduler. Our σ solver is a closed-form affine-only subset of what Pluto solves via integer linear programming.
- Nicolas Vasilache, Oleksandr Zinenko, Theodoros Theodoridis, Priya Goyal, Zachary DeVito, William S. Moses, Sven Verdoolaege, Andrew Adams, Albert Cohen. Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions. arXiv:1802.04730, 2018. The first major ML compiler built directly on the polyhedral model; the reference for applying scheduling-with-replication to transformer-style multi-accumulator patterns.
ML compiler stacks
- Jonathan Ragan-Kelley, Connelly Barnes, Andrew Adams, Sylvain Paris, Frédo Durand, Saman Amarasinghe. Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines. PLDI 2013. Not polyhedral, but the origin of the algorithm / schedule split that shapes the Tile IR / Kernel IR split in this compiler.
- Tianqi Chen et al. TVM: An Automated End-to-End Optimizing Compiler for Deep Learning. OSDI 2018. Auto-tuned extension of Halide-style scheduling; a reference for what a search-based scheduler over our Loop IR would look like.
- Philippe Tillet, H. T. Kung, David Cox. Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. MAPL 2019. The block-programming language Inductor lowers to. Sits at roughly the abstraction level of our Tile IR, but with a Pythonic surface and an LLVM backend instead of source emission.
- PyTorch contributors. TorchInductor: a PyTorch-native Compiler with Define-by-Run IR and Symbolic Shapes. PyTorch dev-discuss, 2022. The compiler torch.compile invokes; the primary baseline in the validation section. Lowers FX → Inductor IR → Triton, with shape symbolics throughout.
PyTorch and serving
- PyTorch contributors. torch.export. PyTorch documentation. The ahead-of-time tracing entry point we use to turn an nn.Module plus example inputs into a shape-resolved graph; the frontend of our compiler.
GPU codegen and attention
- Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. The streaming softmax + online-rescaling pattern that fuses softmax with the surrounding QKᵀ and attn·V matmuls; orthogonal to the per-op softmax measured here, but the natural next step for closing the attention-block gap.
- Dmitry Trifonov. Beating cuBLAS on RTX 5090. CloudRift blog. A walk-through of the hand-tuned SGEMM kernel this compiler would ideally generate: CTA tiling, shared-memory double-buffering, and TMA-driven producer/consumer pipelines. Demonstrates the shared-memory scheduling strategy deplodock's Loop IR doesn't yet emit.


