Training loop and data

This chapter covers the data layer (School.Data), the training loop and training metrics (School.Train), and hyperparameter search (School.Hpo).

School.Data.Dataset (src/data/dataset.ch) holds tabular data and the column transforms.

  • Dataset { x, y, rows } is a flat dataset of a feature tensor, an integer label tensor, and a row count. dataset_from_lists(xs, ys, rows) builds one from flat lists.
  • FloatCol and IntCol wrap a float or integer column.
  • normalize(col, lo, hi) rescales a float column into [lo, hi].
  • standardize(col) rescales a float column to zero mean and unit variance.
  • one_hot(labels, num_classes) expands an IntCol of labels into a one-hot FloatCol.

One-hot encoding three labels into three classes (examples/p3/onehot_demo.ch):

import School.Data.Dataset (IntCol, one_hot, FloatCol)
...
labels = IntCol { values: to_tensor([cast(0, int64), cast(2, int64), cast(1, int64)]) }
encoded = one_hot(labels, cast(3, int64))
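
The three column transforms are simple enough to sketch outside the library. The Python below is an illustration only, not the School.Data.Dataset source: it works on plain lists instead of FloatCol and IntCol, it assumes standardize uses the population variance, and it assumes a constant column maps to lo (for normalize) or to zeros (for standardize). Those edge-case choices are assumptions.

```python
from statistics import fmean, pstdev


def normalize(values, lo, hi):
    """Min-max rescale a list of floats into [lo, hi]."""
    vmin, vmax = min(values), max(values)
    if vmax == vmin:
        # Assumption: a constant column maps to lo; the library may differ.
        return [float(lo) for _ in values]
    scale = (hi - lo) / (vmax - vmin)
    return [lo + (v - vmin) * scale for v in values]


def standardize(values):
    """Rescale to zero mean and unit variance (population variance assumed)."""
    mean = fmean(values)
    std = pstdev(values)
    if std == 0.0:
        # Assumption: a constant column maps to all zeros.
        return [0.0 for _ in values]
    return [(v - mean) / std for v in values]


def one_hot(labels, num_classes):
    """Expand integer labels into rows with a single 1.0 per row."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# Same labels as the demo above: 0, 2, 1 into three classes.
encoded = one_hot([0, 2, 1], 3)
# encoded == [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```

Each output row has exactly one non-zero entry, at the index of its label.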

School.Data.Loader (src/data/loader.ch) turns a dataset into shuffled mini-batches.

  • dataloader_init(features, labels, rows, feat_dim, batch_size, shuffle, seed, drop_last) builds a DataLoader over a rank-2 feature tensor and an integer label tensor.
  • dataloader_next(loader) returns the advanced loader, the next Batch { x, y }, and a flag that is true while batches remain.
  • dataloader_reset(loader, seed) reshuffles for a new epoch from a fresh seed.
  • split(n, ratios, seed) partitions n row indices into the given ratios (for train, validation, and test splits).

The loader is deterministic: the same inputs and seed produce the same batch order.
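
Determinism of this kind usually comes from deriving the shuffle order from the seed alone. The Python sketch below illustrates that idea with index batches. It is not the School.Data.Loader implementation: the function name epoch_batches is hypothetical, and using Python's random.Random as the generator is an assumption made for the illustration.

```python
import random


def epoch_batches(num_rows, batch_size, shuffle, seed, drop_last):
    """Return one epoch of row-index batches, fully determined by the arguments."""
    order = list(range(num_rows))
    if shuffle:
        # A private generator seeded here keeps the order independent of
        # any global random state.
        random.Random(seed).shuffle(order)
    batches = []
    for start in range(0, num_rows, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the short trailing batch
        batches.append(batch)
    return batches


first_run = epoch_batches(5, 2, True, 0, False)
second_run = epoch_batches(5, 2, True, 0, False)
same_order = first_run == second_run  # True: same inputs, same seed
```

With five rows and a batch size of two, drop_last decides whether the final one-row batch is emitted or discarded.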

Parsing a small in-source table, splitting it, and reading the first batch (examples/p3/csv_to_dataloader.ch):

import School.Data.Dataset (Dataset, dataset_from_lists)
import School.Data.Loader (Batch, DataLoader, dataloader_init, dataloader_next, split)
...
ds = dataset_from_lists(fixture_features(), fixture_labels(), rows)
ratios = to_tensor([cast(0.5, f32), cast(0.25, f32), cast(0.25, f32)])
parts = split(rows, &ratios, cast(0, int64))
loader = dataloader_init(features, labels, rows, feat_dim, cast(2, int64), true, cast(0, int64), false)
first = dataloader_next(loader)
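
For readers who want to see what a ratio split does to the indices, here is a minimal Python sketch. It is not the library's split: the name split_indices is hypothetical, and both the seeded shuffle and the rounding rule (truncate each part, give the remainder to the last part) are assumptions.

```python
import random


def split_indices(n, ratios, seed):
    """Partition the indices 0..n-1 into len(ratios) disjoint parts."""
    order = list(range(n))
    random.Random(seed).shuffle(order)  # assumption: indices are shuffled by seed
    parts, start = [], 0
    for i, ratio in enumerate(ratios):
        # Assumption: the last part absorbs any rounding remainder so that
        # every index lands in exactly one part.
        end = n if i == len(ratios) - 1 else start + int(n * ratio)
        parts.append(order[start:end])
        start = end
    return parts


train, val, test = split_indices(8, [0.5, 0.25, 0.25], 0)
sizes = [len(train), len(val), len(test)]  # [4, 2, 2]
```

The ratios match the example above (0.5, 0.25, 0.25); with eight rows they give parts of four, two, and two indices.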

School.Train.Loop (src/train/loop.ch) provides the loop state and the epoch combinator.

  • TrainState { step, accumulated_loss } accumulates loss across steps. train_state_init() builds the zero state; train_loop_step(state, loss) and train_history_log(state, loss) each fold a step's loss in; history_avg_loss(state) returns the running mean loss.
  • loop(initial_params, initial_state, loader, step, num_epochs, seed) is the pure epoch-by-batch combinator. It walks each epoch's batches, threading the parameters and state through a user-supplied step, and reshuffles between epochs. Given the same inputs it produces the same result.
  • TrainConfig { num_epochs, log_every_n_steps, accumulate_grad_batches } carries the loop configuration. The gradient-accumulation helpers (GradAccumState, accum_state_init, grad_accumulate) accumulate a parameter-shaped gradient across micro-batches before a step.
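
The state bookkeeping described above can be sketched in a few lines of Python. This is an illustration rather than the School.Train.Loop source: the field names follow the text, gradients are plain lists, and averaging the accumulated gradient over the micro-batch count (the hypothetical accum_mean helper) is an assumption about how the accumulated value is consumed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainState:
    step: int = 0
    accumulated_loss: float = 0.0


def train_state_init():
    """Zero state: no steps taken, no loss accumulated."""
    return TrainState()


def train_loop_step(state, loss):
    """Fold one step's loss into a new state; the old state is untouched."""
    return TrainState(state.step + 1, state.accumulated_loss + loss)


def history_avg_loss(state):
    """Running mean loss; defined as 0.0 before the first step (assumption)."""
    return state.accumulated_loss / state.step if state.step else 0.0


def grad_accumulate(acc, grad):
    """Element-wise sum of a running gradient and one micro-batch gradient."""
    return [a + g for a, g in zip(acc, grad)]


def accum_mean(acc, num_micro_batches):
    """Hypothetical helper: average the accumulated gradient before a step."""
    return [a / num_micro_batches for a in acc]


state = train_state_init()
for loss in (1.0, 2.0, 3.0):
    state = train_loop_step(state, loss)
avg = history_avg_loss(state)  # 2.0

acc = [0.0, 0.0]
for micro_grad in ([1.0, 2.0], [3.0, 4.0]):
    acc = grad_accumulate(acc, micro_grad)
mean_grad = accum_mean(acc, 2)  # [2.0, 3.0]
```

Keeping the state immutable means each step returns a new value, which is what lets a loop thread it through without side effects.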

Differentiation happens in the step, not in the loop: the loop only threads parameters and state through each batch, and the step takes the grad of a loss. The reference step is the MLP step mlp_train_step, covered in the Model library.
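
The division of labor between loop and step can be illustrated with a small Python sketch. It is not the library's loop: loop_sketch takes a list of rows and a batch size in place of a DataLoader, the per-epoch seed is assumed to be seed + epoch, and sgd_step is a hypothetical step that fits a single weight w in y = w * x with a hand-written gradient standing in for the grad of a loss.

```python
import random


def loop_sketch(params, state, rows, batch_size, step, num_epochs, seed):
    """Thread params and state through every batch of every epoch."""
    for epoch in range(num_epochs):
        order = list(range(len(rows)))
        # Assumption: each epoch reshuffles from a seed derived from the base seed.
        random.Random(seed + epoch).shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = [rows[i] for i in order[start:start + batch_size]]
            # The loop never computes a gradient; it only calls the step.
            params, state = step(params, state, batch)
    return params, state


def sgd_step(w, state, batch, lr=0.05):
    """Hypothetical step: one SGD update on the mean squared error of w * x."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2.0 * (w * x - y) * x for x, y in batch) / n  # d(loss)/dw
    steps, total_loss = state
    return w - lr * grad, (steps + 1, total_loss + loss)


ROWS = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # y = 2 * x
w, (steps, total_loss) = loop_sketch(0.0, (0, 0.0), ROWS, 2, sgd_step, 20, 0)
# w ends close to 2.0 after 40 steps (20 epochs of 2 batches each)
```

Because the loop is a pure function of its arguments, calling it twice with the same inputs returns the same parameters and state.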

School.Train.MetricsExt (src/train/metricsext.ch) provides binary and regression metrics:

  • precision_binary, recall_binary, f1_binary over a prediction and a target with a decision threshold.
  • r_squared, mse_metric, mae_metric for regression.
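
The definitions behind these metrics are standard, and a short Python sketch makes them concrete. The function names mirror the list above, but the code is an illustration and not the School.Train.MetricsExt source. Treating a prediction as positive when it is at or above the threshold, and returning 0.0 when a denominator is zero, are both assumptions.

```python
def _counts(preds, targets, threshold):
    """Count true positives, false positives, and false negatives."""
    tp = fp = fn = 0
    for p, t in zip(preds, targets):
        positive = p >= threshold  # assumption: the threshold is inclusive
        if positive and t == 1:
            tp += 1
        elif positive and t == 0:
            fp += 1
        elif not positive and t == 1:
            fn += 1
    return tp, fp, fn


def precision_binary(preds, targets, threshold):
    tp, fp, _ = _counts(preds, targets, threshold)
    return tp / (tp + fp) if tp + fp else 0.0


def recall_binary(preds, targets, threshold):
    tp, _, fn = _counts(preds, targets, threshold)
    return tp / (tp + fn) if tp + fn else 0.0


def f1_binary(preds, targets, threshold):
    p = precision_binary(preds, targets, threshold)
    r = recall_binary(preds, targets, threshold)
    return 2.0 * p * r / (p + r) if p + r else 0.0


def mse_metric(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae_metric(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def r_squared(preds, targets):
    """1 - SS_res / SS_tot; assumes the targets are not all equal."""
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, targets))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot


# Two true positives, one false positive, one false negative.
p = precision_binary([0.9, 0.8, 0.2, 0.6], [1, 0, 1, 1], 0.5)  # 2/3
```

Predicting the target mean for every row gives an r_squared of 0, and a perfect prediction gives 1.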

School.Train.MetricsExt_Multiclass (src/train/metricsext_multiclass.ch) provides the multiclass metrics over logits and integer labels. Each averaged metric takes an Average mode of Macro, Micro, or Weighted:

  • precision, recall, f1.
  • confusion_matrix, which returns a class-by-class integer count matrix.
  • roc_auc.

Macro-averaged precision, recall, and F1 (examples/p3/multiclass_metrics_demo.ch):

import School.Train.MetricsExt_Multiclass (precision, recall, f1, Average, Macro)
...
p = precision(&logits, &labels, Macro)
r = recall(&logits, &labels, Macro)
f = f1(&logits, &labels, Macro)
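
The three averaging modes differ only in how per-class results are combined, which a small Python sketch can show for precision. This is an illustration under assumptions, not the library's code: predictions are taken as the argmax of each logit row, the confusion matrix is indexed as [true class][predicted class], the class count is inferred from the logit width, and a class with no predictions contributes a precision of 0.0.

```python
from enum import Enum


class Average(Enum):
    MACRO = "macro"
    MICRO = "micro"
    WEIGHTED = "weighted"


def confusion_matrix(logits, labels):
    """Integer counts indexed as [true class][predicted class] (assumption)."""
    num_classes = len(logits[0])
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for row, true_class in zip(logits, labels):
        predicted = max(range(num_classes), key=row.__getitem__)  # argmax
        matrix[true_class][predicted] += 1
    return matrix


def precision(logits, labels, average):
    m = confusion_matrix(logits, labels)
    k = len(m)
    tp = [m[c][c] for c in range(k)]
    predicted = [sum(m[r][c] for r in range(k)) for c in range(k)]  # column sums
    support = [sum(m[c]) for c in range(k)]                         # row sums
    per_class = [tp[c] / predicted[c] if predicted[c] else 0.0 for c in range(k)]
    if average is Average.MACRO:
        return sum(per_class) / k                    # every class counts equally
    if average is Average.MICRO:
        return sum(tp) / sum(predicted)              # pooled over all samples
    return sum(p * s for p, s in zip(per_class, support)) / sum(support)


logits = [[2.0, 1.0, 0.0], [0.0, 1.0, 3.0], [0.0, 0.0, 1.0], [1.0, 0.0, 2.0]]
labels = [0, 1, 2, 2]
cm = confusion_matrix(logits, labels)  # [[1, 0, 0], [0, 0, 1], [0, 0, 2]]
macro = precision(logits, labels, Average.MACRO)        # 5/9
micro = precision(logits, labels, Average.MICRO)        # 3/4
weighted = precision(logits, labels, Average.WEIGHTED)  # 7/12
```

Class 1 is never predicted here, so it pulls the macro average down more than the weighted one, where its single sample carries less weight than class 2's two.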

School.Hpo (src/hpo.ch) provides integer-axis hyperparameter search: grid_search over a GridConfig and random_search over a RandomConfig.
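
The text does not list the fields of GridConfig or RandomConfig, so the Python sketch below only illustrates the two search strategies over integer axes. Everything about its shape is an assumption: axes are given as a dict of candidate lists (grid) or inclusive (lo, hi) ranges (random), the objective is a hypothetical function to minimize, and the first best score wins ties.

```python
import itertools
import random


def grid_search(axes, objective):
    """Evaluate every combination of the integer axes; return (config, score)."""
    names = list(axes)
    best = None
    for combo in itertools.product(*(axes[name] for name in names)):
        config = dict(zip(names, combo))
        score = objective(config)
        if best is None or score < best[1]:
            best = (config, score)
    return best


def random_search(ranges, objective, num_trials, seed):
    """Sample num_trials configs uniformly from inclusive integer ranges."""
    rng = random.Random(seed)  # seeded, so a run can be reproduced
    best = None
    for _ in range(num_trials):
        config = {name: rng.randint(lo, hi) for name, (lo, hi) in ranges.items()}
        score = objective(config)
        if best is None or score < best[1]:
            best = (config, score)
    return best


def toy_objective(config):
    """Hypothetical objective with its minimum at hidden=16, layers=2."""
    return (config["hidden"] - 16) ** 2 + (config["layers"] - 2) ** 2


grid_best = grid_search({"hidden": [8, 16, 32], "layers": [1, 2, 3]}, toy_objective)
# grid_best == ({"hidden": 16, "layers": 2}, 0)
rand_best = random_search({"hidden": (4, 64), "layers": (1, 4)}, toy_objective, 25, 0)
```

Grid search costs the product of the axis sizes, while random search costs exactly num_trials evaluations regardless of how many axes there are.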