
Model library

The model library lives under the School.Models prefix. Each model follows the same convention: a config record describing the architecture, a params record holding the weight tensors, an *_init_like(seed, cfg) function that builds seeded parameters, and an *_forward(...) function that runs the forward pass. Models that hold their parameters as a list of weight tensors also export *_params_to_list and *_params_from_list, which convert between the params record and that list at the optimizer boundary.

This chapter documents the models defined under src/models/.
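The to-list and from-list pair lets an optimizer treat a model's parameters as an ordered list of tensors without knowing the record's field names. Below is a minimal Python sketch of that round trip. It is not School code: the MlpParams dataclass, the flat-list tensors, and the helpers params_to_list, params_from_list, and sgd_on_list are hypothetical stand-ins, and the sketch assumes the list order is the record's declaration order.

```python
from dataclasses import dataclass, fields
from typing import List

@dataclass
class MlpParams:
    # Hypothetical stand-in for a School params record; tensors are flat float lists.
    fc1_w: List[float]
    fc1_b: List[float]
    fc2_w: List[float]
    fc2_b: List[float]

def params_to_list(params: MlpParams) -> List[List[float]]:
    # Emit tensors in declaration order so the optimizer sees a stable ordering.
    return [getattr(params, f.name) for f in fields(params)]

def params_from_list(tensors: List[List[float]]) -> MlpParams:
    # Inverse of params_to_list: rebuild the record from the same ordering.
    names = [f.name for f in fields(MlpParams)]
    if len(tensors) != len(names):
        raise ValueError(f"expected {len(names)} tensors, got {len(tensors)}")
    return MlpParams(**dict(zip(names, tensors)))

def sgd_on_list(tensors, grads, lr):
    # An optimizer written purely against the list form.
    return [[p - lr * g for p, g in zip(t, gt)] for t, gt in zip(tensors, grads)]

p = MlpParams([1.0, 2.0], [0.5], [3.0], [0.0])
grads = [[1.0, 1.0], [1.0], [1.0], [1.0]]
p_next = params_from_list(sgd_on_list(params_to_list(p), grads, lr=0.5))
```

The optimizer never touches field names, so the same sgd_on_list works for any record that provides the two conversion functions.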

School.Models.Mlp (src/models/mlp.ch) is a two-layer perceptron: a linear layer, a ReLU, and a second linear layer.

  • MlpConfig { in_dim, hidden_dim, out_dim, seed }.
  • MlpParams { fc1_w, fc1_b, fc2_w, fc2_b }.
  • mlp_init_like(seed, cfg) builds seeded parameters.
  • mlp_forward(x, params) runs the forward pass over a rank-2 input.
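To spell out the arithmetic, here is a minimal pure-Python sketch with the same shape of API: a seeded init and a linear, ReLU, linear forward over a rank-2 (batch, in_dim) input. It is not School's implementation. The weight layout (fc1_w as in_dim x hidden_dim, applied as x @ W + b), the uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)] with zero biases, and seeding from the seed argument alone are all assumptions.

```python
import math
import random
from dataclasses import dataclass
from typing import List

Matrix = List[List[float]]

@dataclass
class MlpConfig:
    in_dim: int
    hidden_dim: int
    out_dim: int
    seed: int

@dataclass
class MlpParams:
    fc1_w: Matrix        # in_dim x hidden_dim (assumed layout)
    fc1_b: List[float]   # hidden_dim
    fc2_w: Matrix        # hidden_dim x out_dim
    fc2_b: List[float]   # out_dim

def mlp_init_like(seed: int, cfg: MlpConfig) -> MlpParams:
    # Same seed -> same parameters. This sketch ignores cfg.seed and uses the argument.
    rng = random.Random(seed)

    def dense(n_in: int, n_out: int):
        bound = 1.0 / math.sqrt(n_in)
        w = [[rng.uniform(-bound, bound) for _ in range(n_out)] for _ in range(n_in)]
        return w, [0.0] * n_out

    fc1_w, fc1_b = dense(cfg.in_dim, cfg.hidden_dim)
    fc2_w, fc2_b = dense(cfg.hidden_dim, cfg.out_dim)
    return MlpParams(fc1_w, fc1_b, fc2_w, fc2_b)

def linear(x: Matrix, w: Matrix, b: List[float]) -> Matrix:
    # Row-wise x @ w + b for a rank-2 input.
    return [[b[j] + sum(row[k] * w[k][j] for k in range(len(row))) for j in range(len(b))]
            for row in x]

def mlp_forward(x: Matrix, params: MlpParams) -> Matrix:
    hidden = [[max(0.0, v) for v in row] for row in linear(x, params.fc1_w, params.fc1_b)]
    return linear(hidden, params.fc2_w, params.fc2_b)

cfg = MlpConfig(in_dim=4, hidden_dim=8, out_dim=3, seed=42)
params = mlp_init_like(0, cfg)
x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 0.0, 2.0]]   # batch of 2
logits = mlp_forward(x, params)                       # 2 x 3
```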

The MLP also carries its own training surface, which is the reference for how School composes a forward pass, a loss, and grad into a step:

  • mlp_sq_loss(...) is a pure-tensor sum-of-squares loss over the flat weight and bias tensors. It is the function that gets differentiated.
  • mlp_sq_loss_value(params, x, y) is the ADT-level wrapper that reports the loss without differentiating.
  • mlp_sgd_step(params, x, y, lr) calls grad(mlp_sq_loss)(...) to get the per-parameter gradients, subtracts lr times each gradient from the corresponding one of the four parameter tensors, and returns the updated parameters together with the pre-step loss.

The loss is written over flat tensors, with the params record unpacked in a wrapper, because grad differentiates a scalar-returning function of tensors. The caller destructures the parameter ADT and hands the flat tensors to the pure loss.
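The same split can be sketched in pure Python: a pure loss over four flat tensors that returns a scalar, an ADT-level wrapper that destructures the record, and an SGD step that differentiates the pure loss. This is an illustration, not School's code. In particular, grad here is a hypothetical stand-in built from central finite differences rather than automatic differentiation, and the flat row-major weight layout is assumed.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

Flat = List[float]

@dataclass
class MlpParams:
    # Hypothetical mirror of MlpParams with flat, row-major tensors.
    fc1_w: Flat  # in_dim * hidden_dim
    fc1_b: Flat  # hidden_dim
    fc2_w: Flat  # hidden_dim * out_dim
    fc2_b: Flat  # out_dim

def linear(x: Flat, w: Flat, b: Flat) -> Flat:
    # x is a flat (batch, n_in) block; dims are recovered from the bias length.
    n_out = len(b)
    n_in = len(w) // n_out
    out = []
    for r in range(len(x) // n_in):
        for j in range(n_out):
            s = b[j]
            for k in range(n_in):
                s += x[r * n_in + k] * w[k * n_out + j]
            out.append(s)
    return out

def sq_loss(fc1_w, fc1_b, fc2_w, fc2_b, x, y) -> float:
    # Pure function of flat tensors returning a scalar: the thing grad differentiates.
    h = [max(0.0, v) for v in linear(x, fc1_w, fc1_b)]
    pred = linear(h, fc2_w, fc2_b)
    return sum((p - t) ** 2 for p, t in zip(pred, y))

def grad(f: Callable[..., float], argnums: Sequence[int], eps: float = 1e-6):
    # Hypothetical stand-in for School's grad: central finite differences,
    # returning one gradient list per differentiated argument.
    def g(*args):
        grads = []
        for i in argnums:
            gi = []
            for k in range(len(args[i])):
                plus, minus = list(args), list(args)
                plus[i], minus[i] = list(args[i]), list(args[i])
                plus[i][k] += eps
                minus[i][k] -= eps
                gi.append((f(*plus) - f(*minus)) / (2 * eps))
            grads.append(gi)
        return grads
    return g

def sq_loss_value(params: MlpParams, x: Flat, y: Flat) -> float:
    # ADT-level wrapper: destructure the record, then call the pure loss.
    return sq_loss(params.fc1_w, params.fc1_b, params.fc2_w, params.fc2_b, x, y)

def sgd_step(params: MlpParams, x: Flat, y: Flat, lr: float) -> Tuple[MlpParams, float]:
    flat = [params.fc1_w, params.fc1_b, params.fc2_w, params.fc2_b]
    loss = sq_loss(*flat, x, y)
    grads = grad(sq_loss, argnums=(0, 1, 2, 3))(*flat, x, y)
    # Subtract lr * gradient from each of the four tensors.
    nxt = [[p - lr * g for p, g in zip(t, gt)] for t, gt in zip(flat, grads)]
    return MlpParams(*nxt), loss

params = MlpParams([0.5, -0.3, 0.2, 0.8], [0.1, 0.1], [0.7, -0.4], [0.0])
x, y = [1.0, 2.0], [1.0]   # one sample, in_dim=2, out_dim=1
next_params, pre_loss = sgd_step(params, x, y, lr=0.01)
```

Because the pure loss sees only flat tensors, grad never needs to know about the record type; only the wrapper and the step do.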

A forward pass (examples/p5/mlp_demo.ch):

import School.Models.Mlp (MlpConfig, mlp_init_like, mlp_forward)
...
cfg = MlpConfig { in_dim: cast(4, int64), hidden_dim: cast(8, int64), out_dim: cast(3, int64), seed: cast(42, int64) }
params = mlp_init_like(cast(0, int64), cfg)
logits = mlp_forward(x, params)

A full training run over the MLP is in Getting started.

School.Models.Gpt_Oss (src/models/gpt_oss.ch) is a decoder-only transformer. It is a pre-norm dense decoder built from the migrated transformer surface: a token embedding (gather), n_blocks decoder blocks, and an untied language-model head. The block uses RMS normalization, rotary position embedding (so there is no learned position table), grouped-query attention with learned per-head attention sinks, and a feed-forward block.

  • GptOssConfig { vocab, max_seq, d_model, n_heads, n_kv_heads, n_blocks, d_ff, seed }.
  • GptOssParams holds the token embedding, the language-model head, the list of BlockParams, and the final norm.
  • gpt_oss_init_like(seed, cfg) builds seeded parameters, and gpt_oss_forward(ids, params, cfg) runs the forward pass over token ids.
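The attention sinks are the least familiar part of the block. Below is a minimal pure-Python sketch of one query position attending over cached keys and values. It assumes the gpt-oss-style formulation in which each head's learned sink logit joins the softmax normalizer but has no value vector, and it assumes the usual grouped-query mapping where consecutive query heads share one KV head (n_heads a multiple of n_kv_heads). The function names are hypothetical; RoPE and RMS normalization are omitted.

```python
import math
from typing import List

Vec = List[float]

def softmax_with_sink(scores: Vec, sink: float) -> Vec:
    # The sink logit joins the normalizer but contributes no value vector,
    # so the weights over real positions can sum to less than 1.
    m = max(max(scores), sink)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]

def gqa_attend(q: List[Vec], k: List[List[Vec]], v: List[List[Vec]], sinks: Vec) -> List[Vec]:
    """One query position attending over T already-causal cached positions.

    q:     [n_heads][d_head]
    k, v:  [n_kv_heads][T][d_head]
    sinks: [n_heads] learned per-head sink logits
    """
    n_heads, n_kv = len(q), len(k)
    assert n_heads % n_kv == 0, "n_heads must be a multiple of n_kv_heads"
    group = n_heads // n_kv
    out = []
    for h in range(n_heads):
        kv = h // group  # consecutive query heads share one KV head
        d = len(q[h])
        scores = [sum(a * b for a, b in zip(q[h], key)) / math.sqrt(d) for key in k[kv]]
        w = softmax_with_sink(scores, sinks[h])
        out.append([sum(w[t] * v[kv][t][i] for t in range(len(w))) for i in range(d)])
    return out

# 4 query heads sharing 2 KV heads, d_head = 2, T = 3 cached positions.
q = [[0.0, 0.0] for _ in range(4)]
k = [[[1.0, 0.0] for _ in range(3)] for _ in range(2)]
v = [[[float(g + 1), float(g + 1)] for _ in range(3)] for g in range(2)]
out = gqa_attend(q, k, v, sinks=[0.0] * 4)
```

With all scores and the sink at zero, each real position gets weight 1/4 and the sink absorbs the remaining quarter, so each head returns 3/4 of its KV head's value.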

School.Models.Bert_Tiny (src/models/bert_tiny.ch) is a bidirectional encoder transformer with a masked-language-model head. It sums token, token-type, and learned positional embeddings, applies an embedding-level layer normalization, runs n_blocks post-norm encoder blocks, and ends in an MLM head (a dense layer, GELU, layer normalization, and a decoder tied to the token-embedding table).

  • BertTinyConfig { vocab, max_seq, d_model, n_heads, n_blocks, d_ff, type_vocab, seed }.
  • BertTinyParams holds the three embedding tables, the embedding norm, the list of BertTinyBlockParams, and the MLM head parameters.
  • bert_tiny_init_like(seed, cfg) builds seeded parameters, and bert_tiny_forward(ids, type_ids, params, cfg) runs the forward pass over token ids and token-type ids.
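The embedding path and the tied decoder are where BERT-tiny differs most from the decoder above. Below is a minimal pure-Python sketch of both. It assumes per-position layer normalization with eps = 1e-12 (the value the original BERT uses, assumed here for School) and a decoder that multiplies by the transposed token table and adds a bias. It skips the dense layer, GELU, and layer norm that precede the decoder in the MLM head, and all names are hypothetical.

```python
import math
from typing import List

Vec = List[float]
Table = List[Vec]

def layer_norm(x: Vec, gamma: Vec, beta: Vec, eps: float = 1e-12) -> Vec:
    mean = sum(x) / len(x)
    var = sum((a - mean) ** 2 for a in x) / len(x)
    return [g * (a - mean) / math.sqrt(var + eps) + b for a, g, b in zip(x, gamma, beta)]

def bert_embed(ids, type_ids, tok: Table, typ: Table, pos: Table, gamma: Vec, beta: Vec) -> List[Vec]:
    # Sum token, token-type, and learned position rows, then normalize each position.
    out = []
    for p, (i, t) in enumerate(zip(ids, type_ids)):
        summed = [a + b + c for a, b, c in zip(tok[i], typ[t], pos[p])]
        out.append(layer_norm(summed, gamma, beta))
    return out

def tied_decoder(h: Vec, tok: Table, bias: Vec) -> Vec:
    # MLM logits reuse the token-embedding table as the output projection (h @ tok^T).
    return [sum(a * b for a, b in zip(h, row)) + c for row, c in zip(tok, bias)]

# vocab = 3, d_model = 2, type_vocab = 2, max_seq = 4
tok = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
typ = [[0.0, 0.0], [0.5, -0.5]]
pos = [[0.1, 0.0], [0.0, 0.1], [0.2, 0.2], [0.0, 0.0]]
h = bert_embed([2, 0], [0, 1], tok, typ, pos, gamma=[1.0, 1.0], beta=[0.0, 0.0])
logits = tied_decoder(h[0], tok, bias=[0.0, 0.0, 0.0])   # one score per vocab entry
```

Tying the decoder to the token table means the head adds only a bias vector of its own at this step, rather than a second vocab-sized matrix.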

School.Models.Simple_Vision_Cnn (src/models/simple_vision_cnn.ch) is a LeNet-5 convolutional classifier over a fixed 32x32 input: two convolution-and-pool stages followed by three fully connected layers.

  • SimpleVisionCnnConfig { in_channels, image_h, image_w, num_classes, seed }.
  • SimpleVisionCnnParams holds the two convolution weight and bias pairs and the three fully connected layers.
  • simple_vision_cnn_init_like(seed, cfg) builds seeded parameters, and simple_vision_cnn_forward(x, params) runs the forward pass over an image input.
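One reason a LeNet-5 input size is fixed is that the flattened feature count feeding the first fully connected layer depends on the spatial size left after both convolution-and-pool stages. A pure-Python shape trace makes this concrete. It assumes the classic LeNet-5 sizes (valid 5x5 convolutions with 6 and then 16 output channels, 2x2 stride-2 pooling, dense widths 120 and 84), which the source does not spell out for School's implementation; the helper names are hypothetical.

```python
from typing import List, Tuple

def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Standard output-size formula for convolution and pooling windows.
    return (size + 2 * padding - kernel) // stride + 1

def lenet5_shapes(in_channels: int, h: int, w: int, num_classes: int) -> List[Tuple[str, Tuple[int, ...]]]:
    shapes = [("input", (1, in_channels, h, w))]
    c = in_channels
    for name, out_c in (("stage1", 6), ("stage2", 16)):
        h, w = conv_out(h, 5), conv_out(w, 5)          # 5x5 valid convolution
        shapes.append((name + " conv", (1, out_c, h, w)))
        h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)    # 2x2 stride-2 pool
        shapes.append((name + " pool", (1, out_c, h, w)))
        c = out_c
    shapes.append(("flatten", (1, c * h * w)))
    for name, width in (("fc1", 120), ("fc2", 84), ("fc3", num_classes)):
        shapes.append((name, (1, width)))
    return shapes

for name, shape in lenet5_shapes(1, 32, 32, 10):
    print(f"{name:12s} {shape}")
```

Under these assumptions the spatial size goes 32, 28, 14, 10, 5, so fc1 sees 16 * 5 * 5 = 400 features; any other input size would change that width.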

A forward pass over a 1x1x32x32 input (examples/p5/simple_vision_cnn_demo.ch):

import School.Models.Simple_Vision_Cnn (SimpleVisionCnnConfig, simple_vision_cnn_init_like, simple_vision_cnn_forward)
...
cfg = SimpleVisionCnnConfig { in_channels: cast(1, int64), image_h: cast(32, int64), image_w: cast(32, int64), num_classes: cast(10, int64), seed: cast(42, int64) }
params = simple_vision_cnn_init_like(cast(0, int64), cfg)
logits = simple_vision_cnn_forward(x, params)

School.Models.Resnet_V2 (src/models/resnet_v2.ch) is a pre-activation (v2) residual network: a convolution stem, two pre-activation residual blocks at a fixed width of 16 channels with 3x3 kernels, a final batch-norm tail, and a linear head. It runs at a fixed input shape (batch 1, in_channels channels, 8x8 spatial); only in_channels and num_classes are polymorphic.

  • ResnetV2Config { in_channels, image_h, image_w, num_classes, seed }.
  • ResnetV2Params holds the stem, the per-block batch-norm and convolution parameters, the final batch norm, and the head.
  • resnet_v2_init_like(seed, cfg) builds seeded parameters, and resnet_v2_forward(x, params) runs the forward pass.
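The v2 ordering puts batch norm and ReLU before each convolution and leaves the shortcut as a pure identity. That is also why a final batch-norm tail is needed before the head: nothing else normalizes the sum coming out of the last block. Below is a deliberately simplified pure-Python sketch of that ordering: one channel, a 1-D feature map, a 3-tap convolution standing in for the 3x3 one, and training-mode batch statistics with gamma = 1 and beta = 0. All names and simplifications are assumptions for illustration.

```python
import math
from typing import List

Vec = List[float]

def batch_norm(x: Vec, gamma: float, beta: float, eps: float = 1e-5) -> Vec:
    # Single channel: normalize over all positions using the current statistics.
    mean = sum(x) / len(x)
    var = sum((a - mean) ** 2 for a in x) / len(x)
    return [gamma * (a - mean) / math.sqrt(var + eps) + beta for a in x]

def relu(x: Vec) -> Vec:
    return [max(0.0, a) for a in x]

def conv3(x: Vec, k: Vec) -> Vec:
    # 3-tap cross-correlation with zero "same" padding (stand-in for a 3x3 conv).
    padded = [0.0] + x + [0.0]
    return [sum(k[j] * padded[i + j] for j in range(3)) for i in range(len(x))]

def preact_block(x: Vec, k1: Vec, k2: Vec) -> Vec:
    # v2 ordering: BN -> ReLU -> conv, twice, then add the untouched input.
    h = conv3(relu(batch_norm(x, 1.0, 0.0)), k1)
    h = conv3(relu(batch_norm(h, 1.0, 0.0)), k2)
    return [a + b for a, b in zip(x, h)]

x = [1.0, -2.0, 0.5, 3.0]
y = preact_block(x, [0.1, 0.2, 0.1], [0.0, 0.5, 0.0])
```

Zeroing the second kernel reduces the block to the identity, which is the property the pre-activation design is built around.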

School.Models.Conv_Next (src/models/conv_next.ch) is a small isotropic ConvNeXt-style classifier. A 2x2 stride-2 patchify stem projects the input channels to cdim and halves the resolution, then two ConvNeXt blocks run at fixed channels and resolution. Each block is a depthwise 7x7 convolution, a channels-last layer normalization, a 1x1 pointwise expansion, GELU, a 1x1 pointwise projection, a per-channel learnable scale, stochastic depth, and a residual add.

  • ConvNextConfig { in_channels, image_size, cdim, num_classes, seed }.
  • ConvNextParams holds the stem, two ConvNextBlockParams, and the head norm and linear head.
  • conv_next_init_like(seed, cfg) builds seeded parameters, and conv_next_forward(x, params) runs the forward pass.
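Two pieces of the ConvNeXt description reduce to a few lines: the patchify stem's output size, and the residual that combines the per-channel scale with stochastic depth. Below is a minimal pure-Python sketch of both, with the block's depthwise convolution, normalization, and pointwise path abstracted as a branch callable. It assumes per-sample stochastic depth with inverted scaling (a kept branch is divided by 1 - drop_prob during training, and the branch is always kept at evaluation), a common formulation that the source does not confirm for School. All names are hypothetical.

```python
import random
from typing import Callable, List

Vec = List[float]

def patchify_size(image_size: int, patch: int = 2) -> int:
    # A patch x patch, stride-patch stem tiles the image without overlap.
    return image_size // patch

def layer_scale_residual(x: Vec, branch: Callable[[Vec], Vec], gamma: Vec,
                         drop_prob: float, training: bool, rng: random.Random) -> Vec:
    """x + drop_path(gamma * branch(x)) for one sample, channels-last."""
    y = [g * v for g, v in zip(gamma, branch(x))]   # per-channel learnable scale
    if training and drop_prob > 0.0:
        if rng.random() < drop_prob:
            return list(x)                           # whole branch dropped
        y = [v / (1.0 - drop_prob) for v in y]       # rescale the kept branch
    return [a + b for a, b in zip(x, y)]

rng = random.Random(0)
double = lambda v: [2.0 * a for a in v]              # stand-in for the block's branch
eval_out = layer_scale_residual([1.0, 2.0], double, [0.5, 0.5], 0.5, False, rng)
train_out = layer_scale_residual([1.0, 2.0], double, [0.5, 0.5], 0.5, True, rng)
```

The rescaling keeps the expected training output equal to the evaluation output, so no correction is needed at inference.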

School.Models.Unet (src/models/unet.ch) is a 2-level U-Net over NCHW inputs: an encoder, a bottleneck, and a decoder. It composes the im2col convolution, group normalization, nearest-neighbor upsampling, strided maxpool downsampling, and channel-axis skip concatenation. Each block is the U-Net double 3x3 convolution: two consecutive sub-blocks, each a convolution followed by group normalization and a ReLU.

  • UnetConfig { in_channels, out_channels, base_channels, num_groups, eps, seed }.
  • UnetParams holds the encoder, bottleneck, and decoder convolution weights and the final 1x1 projection.
  • unet_init_like(seed, cfg) builds seeded parameters, and unet_forward(x, params, cfg) runs the forward pass.
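The resampling and skip steps are what fix the U-Net's shape contract. Below is a minimal pure-Python sketch on one sample of an NCHW tensor (C x H x W nested lists), assuming 2x2 stride-2 max pooling and 2x nearest-neighbor upsampling; the names are hypothetical and the convolutions are omitted. Under those assumptions a 2-level U-Net needs H and W divisible by 4, because each pooling level must halve the map exactly for the upsampled decoder map to line up with its skip.

```python
from typing import List

Map = List[List[float]]   # H x W
Feat = List[Map]          # C x H x W (one sample of an NCHW tensor)

def maxpool2(x: Feat) -> Feat:
    # 2x2 window, stride 2: halves H and W.
    return [[[max(c[2 * i][2 * j], c[2 * i][2 * j + 1], c[2 * i + 1][2 * j], c[2 * i + 1][2 * j + 1])
              for j in range(len(c[0]) // 2)] for i in range(len(c) // 2)] for c in x]

def upsample_nearest2(x: Feat) -> Feat:
    # Each pixel is copied into a 2x2 block: doubles H and W.
    return [[[c[i // 2][j // 2] for j in range(2 * len(c[0]))] for i in range(2 * len(c))] for c in x]

def concat_channels(a: Feat, b: Feat) -> Feat:
    # Skip connection: spatial dims must match; channel counts add.
    assert len(a[0]) == len(b[0]) and len(a[0][0]) == len(b[0][0]), "spatial mismatch"
    return a + b

def shape(x: Feat):
    return (len(x), len(x[0]), len(x[0][0]))

enc = [[[float(i * 4 + j) for j in range(4)] for i in range(4)] for _ in range(2)]  # C=2, 4x4
down = maxpool2(enc)               # 2 x 2 x 2
up = upsample_nearest2(down)       # 2 x 4 x 4
merged = concat_channels(enc, up)  # 4 x 4 x 4
```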