Transforms: grad and vmap
Transforms are compiler features, not library functions or macros. They are DAG-to-DAG
rewrites: the compiler takes a function, rewrites its RISC DAG, and produces a new function.
After expansion the program contains only RISC primitives. A transform must always be
applied to a function; a bare grad with no function argument is a parse error. Transforms
compose: the output of one can be passed to another. The authoritative source is
spec/06-transformations.md.
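For intuition only, here is the same idea in JAX. JAX is a different system, and there grad is a library function rather than a compiler feature, but its tracing makes the "only primitives remain" point visible. jax.make_jaxpr prints the program that jax.grad(f) produces. The function f is hypothetical. The listed operations are ordinary primitives such as sin and cos, because the derivative rules have already been expanded, and no grad operation is left.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function used only for illustration.
def f(x):
    return jnp.sin(x) * x

# Trace the transformed function; the result is a program of primitives.
closed = jax.make_jaxpr(jax.grad(f))(1.0)
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(names)  # primitive names only; no "grad" operation remains
```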
grad(f) is reverse-mode differentiation. It produces a new function that takes the same
arguments as f and returns the gradients of f's output with respect to them. The output of f
must be a scalar floating-point value; if it is not, reduce it to a scalar first. The
gradient of a single tensor parameter has the same type as that parameter, and a
multi-parameter function yields a flat tuple of gradients.
grad(loss_fn)
By default grad(f) differentiates with respect to every differentiable parameter. The
wrt argument restricts it to named parameters, and the result holds one entry per listed
parameter in the order listed. Apply the gradient function to get the values:
(dw, db) = grad(loss_fn, wrt=(w, b))(w, b)
grad returns gradients only, not the forward value alongside them. It composes with
itself for higher derivatives: when f has a single scalar parameter, grad(f) also returns a
scalar, so grad(grad(f)) is valid and gives the second derivative.
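The same contract can be seen in JAX's jax.grad. It is shown here as an analogy, not as this language's API. The loss function, the arrays, and cube are hypothetical. Two differences are worth noting. First, jax.grad differentiates only the first argument by default. Second, its argnums parameter selects parameters by position, whereas wrt selects them by name.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear-regression loss; returns a scalar, as grad requires.
def loss_fn(w, b, x, y):
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

w = jnp.array([1.0, -2.0])
b = jnp.array(0.5)
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, -1.0])

# One parameter: the gradient has the same shape and dtype as w.
dw_only = jax.grad(loss_fn)(w, b, x, y)

# argnums plays the role of wrt (by position, not name): a tuple in that order.
dw, db = jax.grad(loss_fn, argnums=(0, 1))(w, b, x, y)

# Higher derivatives: a scalar -> scalar function can be differentiated twice.
def cube(t):
    return t ** 3

d2 = jax.grad(jax.grad(cube))(2.0)  # second derivative is 6t, so 12.0 at t = 2
```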
vmap vectorizes a per-example function over a batch dimension. It is a DAG rewrite, not a
loop: every operation is lifted to run over the added axis. The axis argument is the integer
position at which the batch dimension is inserted; it defaults to 0.
def process(x: tensor[features, f32]) -> tensor[features, f32] = relu(x)
def batch_process(xs: tensor[batch, features, f32]) -> tensor[batch, features, f32] = xs |> vmap(process, axis=0)
Each tensor argument of the wrapped function gains the batch dimension; non-tensor arguments are shared across the batch. Reductions inside the wrapped function still reduce over their original named axis, so the batch dimension passes through untouched.
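The following rough JAX analogue is not this compiler's vmap. In jax.vmap, in_axes roughly corresponds to the axis argument, and in_axes=None marks an argument that is shared across the batch rather than mapped, much like the non-tensor arguments described above. The functions process and total and the arrays are hypothetical. The second part shows a reduction inside the mapped function that sums only over its own features axis, leaving one result per example.

```python
import jax
import jax.numpy as jnp

# Per-example function: x has shape (features,); scale is shared, not batched.
def process(x, scale):
    return jax.nn.relu(x) * scale

xs = jnp.array([[1.0, -1.0, 2.0],
                [-3.0, 4.0, 0.5]])  # shape (batch=2, features=3)
scale = 2.0

# in_axes=(0, None): map over axis 0 of xs, reuse scale for every example.
batched = jax.vmap(process, in_axes=(0, None))(xs, scale)

# A reduction inside the per-example function still reduces its own axis only.
def total(x):
    return jnp.sum(x)  # sums over features for one example

per_example_totals = jax.vmap(total)(xs)  # shape (batch,)
```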
Composing transforms
vmap(grad(f)) computes per-example gradients: each example in the batch gets its own
gradient, which is what per-example gradient clipping needs. grad(vmap(f)) is a different
computation: vmap(f) returns one value per example rather than a scalar, so once that batch
is summed, grad yields a single gradient of the batch total, not one per example. To
differentiate a vmapped function, reduce its result to a scalar first:
grad(fn (xs) -> sum(vmap(process, axis=0)(xs), 0))
When you are unsure whether a particular composition is supported, write a small program and let the compiler tell you. The supported surface is defined by the compiler and its tests.
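The difference between the two composition orders can be checked numerically in JAX. JAX is used here purely as an analogy, not as this compiler's behavior. With a hypothetical squared-error loss and a three-example batch, vmap(grad(loss)) returns one gradient per example. By contrast, grad of the summed vmapped loss returns a single gradient, which equals the sum of the per-example gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss: scalar output for one (x, y) pair.
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 1.0], [0.0, 4.0]])  # batch of 3 examples
ys = jnp.array([1.0, 0.0, 2.0])

# vmap(grad(f)): one gradient per example, shape (batch, 2); w is shared.
per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)

# grad of a vmapped function: reduce to a scalar (here a sum) before differentiating.
def batch_total(w):
    return jnp.sum(jax.vmap(loss, in_axes=(None, 0, 0))(w, xs, ys))

summed = jax.grad(batch_total)(w)  # shape (2,): a single gradient for the whole batch
```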