
Layers

Layers live under the School.Nn prefix. Each layer is a plain function with an explicit tensor signature. Layers that carry learnable parameters pair a forward function with a parameter record (an ADT) and, where applicable, a config record; purely functional layers such as the activations are a single forward function.

This chapter is a reference for the layer surface defined under src/nn/. Tensor element types are written f32 for float and int64 for integer indices, and p denotes a precision variable (a layer polymorphic over its element precision).

School.Nn.Linear.forward is an affine layer: x @ w + b.

forward: tensor[a, b, p] -> tensor[b, c, p] -> tensor[c, p] -> tensor[a, c, p]

The bias is expanded over the batch axis inside the function, so the caller passes a rank-1 bias. Defined in src/nn/linear.ch.
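
As an illustration of the affine map and the bias broadcast described above, here is a minimal pure-Python sketch. It is not the code in src/nn/linear.ch: the name linear_forward and the nested-list tensor representation are assumptions made for this example.

```python
def linear_forward(x, w, b):
    """Affine map x @ w + b for x: [a][b], w: [b][c], b: [c] (nested lists)."""
    rows, inner, cols = len(x), len(w), len(b)
    out = []
    for i in range(rows):
        row = []
        for j in range(cols):
            acc = b[j]  # the rank-1 bias is reused for every batch row
            for k in range(inner):
                acc += x[i][k] * w[k][j]
            row.append(acc)
        out.append(row)
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # shape [3, 2]
w = [[1.0, 0.0, -1.0], [0.5, 2.0, 1.0]]   # shape [2, 3]
b = [0.1, 0.2, 0.3]                       # shape [3]
y = linear_forward(x, w, b)               # shape [3, 3]
```

The caller never builds a [3, 3] bias; the same three bias values are added to each output row.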

The elementwise activations take a tensor of any rank and return the same shape.

  • School.Nn.Relu.relu_forward (src/nn/relu.ch): rectified linear unit.
  • School.Nn.Gelu.gelu_forward (src/nn/gelu.ch): Gaussian error linear unit, computed with the tanh approximation.
  • School.Nn.Silu.silu_forward (src/nn/silu.ch): sigmoid linear unit (x * sigmoid(x)).

relu_forward(x: tensor[..r, f32]) -> tensor[..r, f32]

Each module also exports rank-fixed aliases (relu_forward_2d, relu_forward_3d, relu_forward_4d, and the matching gelu/silu forms) for call sites that pass a concrete rank.
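
The three activations can be written out as scalar formulas. The sketch below uses Python's math module; the function names and the map_nested helper are hypothetical and only stand in for the rank-polymorphic behavior described above. The GELU constant 0.044715 is the one used by the standard tanh approximation; whether the library uses exactly this constant is an assumption.

```python
import math


def relu(x):
    return max(0.0, x)


def gelu_tanh(x):
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))


def silu(x):
    return x / (1.0 + math.exp(-x))  # same as x * sigmoid(x)


def map_nested(f, t):
    """Apply f elementwise to a nested list of any depth, preserving its shape."""
    if isinstance(t, list):
        return [map_nested(f, item) for item in t]
    return f(t)


rank2 = map_nested(relu, [[-1.0, 2.0], [3.0, -4.0]])
```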

School.Nn.ScalarOps (src/nn/scalarops.ch) provides the small tensor builders the rest of the framework leans on: tensor_full_1d, tensor_full_like, tensor_zeros_like, tensor_ones_like, tensor_scalar_mul, tensor_scalar_add, and tensor_square.

  • School.Nn.Im2col.conv2d_im2col_forward (src/nn/im2col.ch) is an im2col-based 2D convolution. It takes an NCHW input, a [c_out, c_in, kh, kw] weight, a [c_out] bias, and the height and width strides:

    conv2d_im2col_forward(
        x: tensor[a, c_in, h, w, f32],
        weight: tensor[c_out, c_in, kh, kw, f32],
        bias: tensor[c_out, f32],
        sh: int64, sw: int64)
  • School.Nn.Depthwise.depthwise_conv2d_forward (src/nn/depthwise.ch) is a depthwise 2D convolution: one [kh, kw] kernel per input channel, with weight shape [c, kh, kw] and bias [c].
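
To make the im2col idea concrete, the following pure-Python sketch unrolls each input patch into a column and reduces the convolution to dot products against flattened filters. It assumes no padding (only windows fully inside the input), which the text above does not state, and the name conv2d_im2col is hypothetical.

```python
def conv2d_im2col(x, weight, bias, sh, sw):
    """x: [n][c_in][h][w], weight: [c_out][c_in][kh][kw], bias: [c_out]."""
    c_in, h, w = len(x[0]), len(x[0][0]), len(x[0][0][0])
    c_out, kh, kw = len(weight), len(weight[0][0]), len(weight[0][0][0])
    out_h = (h - kh) // sh + 1  # assumption: no padding
    out_w = (w - kw) // sw + 1
    # Flatten each filter into one row of length c_in * kh * kw.
    w_rows = [[weight[o][c][i][j]
               for c in range(c_in) for i in range(kh) for j in range(kw)]
              for o in range(c_out)]
    out = []
    for sample in x:
        # im2col: one column per output position, holding the unrolled patch.
        cols = [[sample[c][oy * sh + i][ox * sw + j]
                 for c in range(c_in) for i in range(kh) for j in range(kw)]
                for oy in range(out_h) for ox in range(out_w)]
        maps = []
        for o in range(c_out):
            flat = [sum(wv * pv for wv, pv in zip(w_rows[o], col)) + bias[o]
                    for col in cols]
            # Fold the flat list of outputs back into an [out_h][out_w] map.
            maps.append([flat[r * out_w:(r + 1) * out_w] for r in range(out_h)])
        out.append(maps)
    return out


image = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]  # [1, 1, 3, 3]
kernel = [[[[1.0, 1.0], [1.0, 1.0]]]]                            # [1, 1, 2, 2]
result = conv2d_im2col(image, kernel, [0.5], 1, 1)               # [1, 1, 2, 2]
```

A depthwise convolution follows the same pattern, except that each channel is convolved with its own [kh, kw] kernel and no sum over c_in is taken.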

  • School.Nn.Embedding.forward (src/nn/embedding.ch) is a lookup: it gathers rows of an embedding table by integer ids.

    forward(ids: tensor[batch, seq, int64], table: tensor[vocab, hidden, p])
    -> tensor[batch, seq, hidden, p]
  • School.Nn.PosEmbed (src/nn/posembed.ch) is a learned positional embedding. pos_embed_init_like builds a PosEmbedParams table from a PosEmbedConfig (max_seq, hidden, seed), and pos_embed_forward adds the position rows for the first seq_len positions to the token embeddings.

  • School.Nn.Rope (src/nn/rope.ch) is rotary position embedding. rope_cos_table and rope_sin_table build the cosine and sine tables for a given sequence length, head dimension, and base, and rope_apply rotates a [seq, d, f32] tensor.
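
The rotary embedding can be sketched in a few lines of pure Python. Several details here are assumptions, since the text above does not fix them: the tables have shape [seq, d/2], the frequencies are base^(-2i/d), and the rotation acts on interleaved pairs (2i, 2i+1). The function names mirror the description but are not the library's signatures.

```python
import math


def rope_tables(seq_len, d, base=10000.0):
    """Return (cos, sin) tables of shape [seq_len][d // 2]."""
    freqs = [base ** (-2.0 * i / d) for i in range(d // 2)]
    cos = [[math.cos(p * f) for f in freqs] for p in range(seq_len)]
    sin = [[math.sin(p * f) for f in freqs] for p in range(seq_len)]
    return cos, sin


def rope_apply(x, cos, sin):
    """Rotate each interleaved pair of x: [seq][d] by its position-dependent angle."""
    out = []
    for p, row in enumerate(x):
        new = []
        for i in range(len(row) // 2):
            a, b = row[2 * i], row[2 * i + 1]
            new.append(a * cos[p][i] - b * sin[p][i])
            new.append(a * sin[p][i] + b * cos[p][i])
        out.append(new)
    return out


cos_t, sin_t = rope_tables(2, 4)
x = [[1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
rotated = rope_apply(x, cos_t, sin_t)
```

Position 0 is left unchanged, and because each pair undergoes a plain 2D rotation, the norm of every row is preserved.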

The positional embedding example combines an embedding lookup with a learned position table (examples/p5/pos_embed_demo.ch):

import School.Nn.Embedding (forward)
import School.Nn.PosEmbed (PosEmbedConfig, pos_embed_init_like, pos_embed_forward)
...
tok_embeds = forward(ids, tok_table)
cfg = PosEmbedConfig { max_seq: max_seq, hidden: hidden, seed: cast(42, int64) }
pos_params = pos_embed_init_like(cast(0, int64), cfg)
out = pos_embed_forward(tok_embeds, pos_params, cast(3, int64))
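
For readers who want to see the data flow of the demo above outside the framework, here is a pure-Python sketch of a row-gather lookup followed by a position add. The function names echo the ones above, but the bodies, the nested-list tensors, and the hand-written position table are assumptions for illustration; in particular the real position table is built from a seed by pos_embed_init_like.

```python
def embedding_forward(ids, table):
    """ids: [batch][seq] ints, table: [vocab][hidden] -> [batch][seq][hidden]."""
    return [[list(table[i]) for i in row] for row in ids]


def pos_embed_forward(tok, pos_table, seq_len):
    """Add pos_table[t] to every token embedding at position t < seq_len."""
    return [[[v + p for v, p in zip(sample[t], pos_table[t])]
             for t in range(seq_len)]
            for sample in tok]


tok_table = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]          # vocab = 3, hidden = 2
pos_table = [[0.5, 0.5], [0.25, 0.25], [0.125, 0.125]]    # max_seq = 3, hidden = 2
ids = [[2, 1, 0]]                                         # batch = 1, seq = 3
tok_embeds = embedding_forward(ids, tok_table)
out = pos_embed_forward(tok_embeds, pos_table, 3)
```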

  • School.Nn.LayerNorm (src/nn/layernorm.ch): layer normalization. layernorm_init_like(d) builds a LayerNormParams (gain and bias of length d), and layernorm_forward normalizes over the feature axis given an eps. Rank-fixed forms layernorm_forward_2d, _3d, _4d, and a channels-last form layernorm_channels_last_forward for NCHW inputs are also provided.
  • School.Nn.RmsNorm (src/nn/rmsnorm.ch): RMS normalization. forward and forward_2d scale each row by its root-mean-square and a per-feature gain.
  • School.Nn.Group_Norm (src/nn/group_norm.ch): group normalization over NCHW inputs. groupnorm_init_like builds the params from a GroupNormConfig, and groupnorm_forward normalizes within channel groups.
  • School.Nn.BatchNorm (src/nn/batchnorm.ch): batch normalization for rank-2 (batchnorm1d_forward) and NCHW (batchnorm2d_forward) inputs. The plain forward functions use the running statistics carried in the params record. Training-mode forms (batchnorm1d_forward_train, batchnorm2d_forward_train) compute batch statistics and return the normalized output together with updated params (the running-statistics exponential moving average).
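
The difference between layer normalization and RMS normalization is easiest to see on a single feature row. The sketch below is a hedged illustration, not the library code: the function names are hypothetical, the variance is the biased (divide by d) form, and the eps argument in rmsnorm_row is an assumption, since the text above does not say whether RmsNorm takes one.

```python
import math


def layernorm_row(x, gain, bias, eps):
    """Subtract the mean, divide by the standard deviation, then scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    inv = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv + b for v, g, b in zip(x, gain, bias)]


def rmsnorm_row(x, gain, eps):
    """Divide by the root-mean-square only; no mean subtraction and no bias."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gain)]


ln = layernorm_row([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 0.0)
rn = rmsnorm_row([3.0, 4.0], [1.0, 1.0], 0.0)
```

With unit gain, the layer-norm output has zero mean and unit variance, while the RMS-norm output has unit mean square but keeps its original mean direction.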

The eval-mode batch-norm forward normalizes with the stored running mean and variance (examples/p1/batchnorm_eval_demo.ch):

import School.Nn.BatchNorm (BatchNorm1dConfig, BatchNorm1dParams, batchnorm1d_forward)
...
params = BatchNorm1dParams { gamma: gamma, beta: beta, running_mean: running_mean, running_var: running_var }
cfg = BatchNorm1dConfig { num_features: cast(2, int64), eps: cast(0.0000001, f32), momentum: cast(0.1, f32) }
y = batchnorm1d_forward(x, params, cfg)
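
The arithmetic behind the eval-mode call above, and the running-statistics update performed by the training-mode forms, can be sketched as follows. The function names are hypothetical, and the moving-average convention (new = (1 - momentum) * running + momentum * batch) is an assumption; the text above only says that an exponential moving average is used.

```python
import math


def batchnorm1d_eval(x, gamma, beta, running_mean, running_var, eps):
    """x: [batch][features]; normalize each feature with the stored statistics."""
    return [[g * (v - m) / math.sqrt(s + eps) + b
             for v, g, b, m, s in zip(row, gamma, beta, running_mean, running_var)]
            for row in x]


def ema_update(running, batch_stat, momentum):
    # Assumed convention: momentum weights the new batch statistic.
    return [(1.0 - momentum) * r + momentum * s
            for r, s in zip(running, batch_stat)]


x = [[1.0, 2.0], [3.0, 6.0]]
y = batchnorm1d_eval(x, [1.0, 1.0], [0.0, 0.0], [2.0, 4.0], [1.0, 4.0], 0.0)
new_mean = ema_update([0.0, 0.0], [10.0, 20.0], 0.1)
```

Note that eval mode never looks at the batch itself for statistics, so the output for one row does not depend on the other rows.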

School.Nn.Attention (src/nn/attention.ch) provides scaled dot-product attention and its variants:

  • scaled_dot_product_attention(q, k, v, scale): the core SDPA over [s, d] query, key, and value tensors with an [s, s] scale.
  • causal_scaled_dot_product_attention(q, k, v, scale, mask): SDPA with an additive causal mask.
  • causal_sdpa_with_sink(q, k, v, scale, mask, sink): causal SDPA with a learned per-head attention sink term.
  • multi_head_attention and grouped_query_attention: head-wise SDPA entry points.
  • gqa_broadcast_kv(kv_pool, group_map): broadcasts a pool of key/value heads to query heads by gather, for grouped-query attention.
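
The gather performed for grouped-query attention can be pictured with plain lists. This sketch assumes that group_map holds, for each query head, the index of the key/value head it shares; the actual tensor layout in src/nn/attention.ch is not described above.

```python
def gqa_broadcast_kv(kv_pool, group_map):
    """kv_pool: [n_kv_heads][...], group_map: [n_q_heads] ints -> [n_q_heads][...]."""
    return [kv_pool[g] for g in group_map]


# Four query heads share two key/value heads, two query heads per group.
kv_pool = [[1.0, 1.5], [2.0, 2.5]]
per_query_head = gqa_broadcast_kv(kv_pool, [0, 0, 1, 1])
```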

School.Nn.CausalMask.causal_mask(seq_len) (src/nn/causalmask.ch) builds the additive [s, s] lower-triangular mask: zero on and below the diagonal, a large negative value above it.

A minimal attention forward pass (examples/p0/tiny_attention.ch):

import School.Nn.Attention (scaled_dot_product_attention)
...
out = scaled_dot_product_attention(q, k, v, scale)

The causal variant uses the mask (examples/p5/causal_mask_demo.ch):

import School.Nn.CausalMask (causal_mask)
import School.Nn.Attention (causal_scaled_dot_product_attention)
...
mask = causal_mask(cast(2, int64))
out = causal_scaled_dot_product_attention(q, k, v, scale, mask)
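
The two calls above can be traced with a small pure-Python sketch of the additive mask and the attention computation. Several points are assumptions: the constant -1e9 stands in for the unspecified "large negative value", the [s, s] scale is applied elementwise to the score matrix, and the names sdpa and softmax are hypothetical.

```python
import math

NEG = -1e9  # stand-in for the library's large negative value (not specified)


def causal_mask(seq_len):
    """Zero on and below the diagonal, NEG above it."""
    return [[0.0 if j <= i else NEG for j in range(seq_len)]
            for i in range(seq_len)]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in row]
    z = sum(e)
    return [v / z for v in e]


def sdpa(q, k, v, scale, mask=None):
    """q, k, v: [s][d]; scale and mask: [s][s]."""
    s, d = len(q), len(q[0])
    out = []
    for i in range(s):
        scores = [sum(q[i][t] * k[j][t] for t in range(d)) * scale[i][j]
                  for j in range(s)]
        if mask is not None:
            scores = [sc + mk for sc, mk in zip(scores, mask[i])]
        w = softmax(scores)
        out.append([sum(w[j] * v[j][c] for j in range(s))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
scale = [[1.0 / math.sqrt(2.0)] * 2 for _ in range(2)]
mask = causal_mask(2)
out = sdpa(q, k, v, scale, mask)
```

Because position 0 can only attend to itself under the mask, the first output row equals the first value row; the second row is a weighted mix of both.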

  • School.Nn.Pool (src/nn/pool.ch): 1D pooling, avgpool1d_forward and maxpool1d_forward, each taking a kernel and a stride.
  • School.Nn.Pool2d (src/nn/pool2d.ch): 2D pooling over NCHW inputs. The config records AvgPool2dConfig and MaxPool2dConfig carry the window, the stride, and a PaddingMode of Valid or Same. avgpool2d_forward and maxpool2d_forward are the windowed forms; adaptive_avgpool2d_forward and adaptive_maxpool2d_forward pool to a requested output height and width.
  • School.Nn.Pad (src/nn/pad.ch): padding over rank-1 through rank-4 tensors (pad_forward, pad_forward_2d, pad_forward_3d, pad_forward_4d) with a PadMode selecting constant, replicate, or reflect padding.
  • School.Nn.Upsample.nearest_upsample2d (src/nn/upsample.ch): nearest-neighbor 2D upsampling by an integer factor over NCHW inputs.

A 2x2 average pool with stride 2 and Valid padding (examples/p1/avgpool2d_demo.ch):

import School.Nn.Pool2d (PaddingMode, Valid, AvgPool2dConfig, avgpool2d_forward)
...
cfg = AvgPool2dConfig { window_h: cast(2, int64), window_w: cast(2, int64), stride_h: cast(2, int64), stride_w: cast(2, int64), padding: Valid }
y = avgpool2d_forward(x, cfg)
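
What that configuration computes can be checked by hand with a small pure-Python sketch. The function name avgpool2d_valid is hypothetical; it implements only the Valid case, where windows must lie fully inside the input and no padding is added.

```python
def avgpool2d_valid(x, window_h, window_w, stride_h, stride_w):
    """x: [n][c][h][w]; average each window that fits entirely inside the input."""
    h, w = len(x[0][0]), len(x[0][0][0])
    out_h = (h - window_h) // stride_h + 1
    out_w = (w - window_w) // stride_w + 1
    area = window_h * window_w
    return [[[[sum(plane[oy * stride_h + i][ox * stride_w + j]
                   for i in range(window_h) for j in range(window_w)) / area
               for ox in range(out_w)]
              for oy in range(out_h)]
             for plane in sample]
            for sample in x]


# One sample, one channel, 4x4 input holding 1..16 in row-major order.
x = [[[[1.0, 2.0, 3.0, 4.0],
       [5.0, 6.0, 7.0, 8.0],
       [9.0, 10.0, 11.0, 12.0],
       [13.0, 14.0, 15.0, 16.0]]]]
y = avgpool2d_valid(x, 2, 2, 2, 2)  # shape [1, 1, 2, 2]
```

With a 2x2 window and stride 2 the windows do not overlap, so each output element is the mean of one quadrant of the input.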

  • School.Nn.Dropout (src/nn/dropout.ch): inverted dropout. dropout_init(seed) builds the DropoutState, and dropout_apply (with rank-fixed _2d, _3d, _4d forms) takes the input, a rate, a Mode, and the state. In eval mode it is the identity; in train mode it drops and rescales.
  • School.Nn.Stochastic_Depth (src/nn/stochastic_depth.ch): drop-path for residual blocks. stochastic_depth_init(seed) builds the state and stochastic_depth_apply_4d drops a whole NCHW sample with the given rate in train mode, passing it through in eval mode.
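
The "drops and rescales" behavior of inverted dropout can be sketched in pure Python. This is an illustration under assumptions: random.Random plays the role of the DropoutState, a boolean replaces the Mode, and the function operates on a flat list rather than a tensor.

```python
import random


def dropout_apply(x, rate, train, rng):
    """Inverted dropout on a flat list; rate must be in [0, 1)."""
    if not train or rate == 0.0:
        return list(x)  # eval mode is the identity
    keep = 1.0 - rate
    # Survivors are divided by keep so the expected value matches eval mode.
    return [v / keep if rng.random() < keep else 0.0 for v in x]


rng = random.Random(0)  # stand-in for a seeded state
x = [1.0] * 1000
train_out = dropout_apply(x, 0.5, True, rng)
eval_out = dropout_apply(x, 0.5, False, rng)
```

Rescaling at training time is what lets eval mode skip any correction: each element is either 0 or v / (1 - rate), so its expectation is v.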

School.Nn.Compose (src/nn/compose.ch) provides residual combinators: residual(x, inner_p, f) adds the input to an inner function applied to it, and residual_4d(x, fx) is the plain residual add for NCHW tensors.

School.Module.Seq (src/module/seq.ch) provides shape-preserving block combinators: sequential2 and sequential3 chain two or three functions, branch2 sums two parallel branches, and residual_with_norm applies a normalization and an inner function around a residual connection.
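
The combinators described above are small higher-order functions. The sketch below shows one plausible shape for them over flat lists; it is not the library's API. In particular, the curried form differs from the residual(x, inner_p, f) signature, and the pre-norm ordering x + f(norm(x)) in residual_with_norm is an assumption.

```python
def sequential2(f, g):
    """Chain two shape-preserving functions: x -> g(f(x))."""
    return lambda x: g(f(x))


def branch2(f, g):
    """Sum two parallel branches applied to the same input."""
    return lambda x: [a + b for a, b in zip(f(x), g(x))]


def residual(f):
    """Add the input to the inner function's output: x -> x + f(x)."""
    return lambda x: [a + b for a, b in zip(x, f(x))]


def residual_with_norm(norm, f):
    # Assumed ordering (pre-norm): x + f(norm(x)).
    return lambda x: [a + b for a, b in zip(x, f(norm(x)))]


double = lambda x: [2.0 * v for v in x]
halve = lambda x: [0.5 * v for v in x]

block = residual_with_norm(halve, double)  # x + 2 * (0.5 * x) = 2x
```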

School.Nn.Generate (src/nn/generate.ch) is autoregressive decoding for transformer models. generate runs greedy decoding for a fixed number of tokens given a step model and a context, and generate_with runs configurable decoding under a GenerateConfig (max_tokens, temperature, top_k, top_p). greedy_next_tokens and sample_next_tokens expose the per-step token selection. The model is supplied as a callback that maps the current ids and an optional KVCache to logits and an updated cache.
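
The callback contract described above can be illustrated with a greedy decoding loop in pure Python. The names generate, greedy_next_token, and toy_step are used here only for illustration and do not reproduce the library's signatures; the toy model and its call-counting "cache" are invented for the example.

```python
def greedy_next_token(logits):
    """Index of the largest logit (first one wins on ties)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def generate(step, context, max_tokens):
    """step(ids, cache) -> (logits for the next token, updated cache)."""
    ids, cache = list(context), None  # None models the absent optional cache
    for _ in range(max_tokens):
        logits, cache = step(ids, cache)
        ids.append(greedy_next_token(logits))
    return ids


def toy_step(ids, cache):
    """Toy model: always predicts (last id + 1) mod 5; the cache counts calls."""
    vocab = 5
    target = (ids[-1] + 1) % vocab
    logits = [1.0 if i == target else 0.0 for i in range(vocab)]
    return logits, (cache or 0) + 1


tokens = generate(toy_step, [3], 4)
```

Sampling-based decoding would replace greedy_next_token with a step that applies temperature, top_k, and top_p before drawing a token; the loop and the cache threading stay the same.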