Training loop and data
This chapter covers the data layer (School.Data), the training loop and
training metrics (School.Train), and hyperparameter search (School.Hpo).
Datasets
School.Data.Dataset (src/data/dataset.ch) holds tabular data and the column
transforms.
- Dataset { x, y, rows } is a flat dataset of a feature tensor, an integer label tensor, and a row count.
- dataset_from_lists(xs, ys, rows) builds one from flat lists.
- FloatCol and IntCol wrap a float column and an integer column, respectively.
- normalize(col, lo, hi) rescales a float column into [lo, hi].
- standardize(col) rescales a float column to zero mean and unit variance.
- one_hot(labels, num_classes) expands an IntCol of labels into a one-hot FloatCol.
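The following is a minimal Python sketch of what normalize, standardize, and one_hot compute. It works on plain Python lists rather than School's FloatCol and IntCol. It also makes a few assumptions that the text above does not state: standardize uses the population variance, and a constant column maps to lo (for normalize) or to all zeros (for standardize) instead of dividing by zero.

```python
import math


def normalize(col, lo, hi):
    """Min-max rescale a list of floats into [lo, hi]."""
    cmin, cmax = min(col), max(col)
    span = cmax - cmin
    if span == 0:
        # Constant column: every value maps to lo (an assumed convention).
        return [lo for _ in col]
    return [lo + (v - cmin) * (hi - lo) / span for v in col]


def standardize(col):
    """Shift to zero mean and scale to unit population variance."""
    n = len(col)
    mean = sum(col) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in col) / n)
    if std == 0:
        # Constant column: return zeros instead of dividing by zero.
        return [0.0 for _ in col]
    return [(v - mean) / std for v in col]


def one_hot(labels, num_classes):
    """Expand integer labels into rows of 0.0/1.0 indicators."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows
```

For instance, normalize([1.0, 2.0, 3.0], 0.0, 1.0) gives [0.0, 0.5, 1.0]. one_hot([0, 2, 1], 3) gives the rows [1, 0, 0], [0, 0, 1], and [0, 1, 0], which match the labels in the one-hot example that follows.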
One-hot encoding three labels into three classes (examples/p3/onehot_demo.ch):
```
import School.Data.Dataset (IntCol, one_hot, FloatCol)
...
labels = IntCol { values: to_tensor([cast(0, int64), cast(2, int64), cast(1, int64)]) }
encoded = one_hot(labels, cast(3, int64))
```
The data loader
School.Data.Loader (src/data/loader.ch) turns a dataset into shuffled
mini-batches.
- dataloader_init(features, labels, rows, feat_dim, batch_size, shuffle, seed, drop_last) builds a DataLoader over a rank-2 feature tensor and an integer label tensor.
- dataloader_next(loader) returns the advanced loader, the next Batch { x, y }, and a flag that is true while batches remain.
- dataloader_reset(loader, seed) reshuffles for a new epoch from a fresh seed.
- split(n, ratios, seed) partitions n row indices into groups sized by the given ratios, for example train, validation, and test splits.
The loader is deterministic: the same inputs and seed produce the same batch order.
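The batching and splitting behavior described above can be sketched in Python. This is not the School.Data.Loader implementation. The sketch works on row indices and shuffles with a random.Random generator seeded per call, so the same seed always yields the same order. As an assumption, split gives any rounding remainder to the first group. Calling epoch_batches again with a new seed plays the role of dataloader_reset.

```python
import random


def epoch_batches(rows, batch_size, shuffle, seed, drop_last):
    """Row-index mini-batches for one epoch; same arguments, same batches."""
    order = list(range(rows))
    if shuffle:
        # A private generator seeded per call keeps the order reproducible
        # and independent of Python's global random state.
        random.Random(seed).shuffle(order)
    batches = [order[i:i + batch_size] for i in range(0, rows, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the short final batch
    return batches


def split(n, ratios, seed):
    """Partition range(n) into disjoint groups sized by ratios (summing to 1)."""
    order = list(range(n))
    random.Random(seed).shuffle(order)
    sizes = [int(n * r) for r in ratios]
    sizes[0] += n - sum(sizes)  # rounding remainder goes to the first group
    parts, start = [], 0
    for size in sizes:
        parts.append(order[start:start + size])
        start += size
    return parts
```

With 5 rows and a batch size of 2, epoch_batches produces batches of sizes 2, 2, and 1. With drop_last set, the single-row batch is discarded.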
Parsing a small in-source table, splitting it, and reading the first batch
(examples/p3/csv_to_dataloader.ch):
```
import School.Data.Dataset (Dataset, dataset_from_lists)
import School.Data.Loader (Batch, DataLoader, dataloader_init, dataloader_next, split)
...
ds = dataset_from_lists(fixture_features(), fixture_labels(), rows)
ratios = to_tensor([cast(0.5, f32), cast(0.25, f32), cast(0.25, f32)])
parts = split(rows, &ratios, cast(0, int64))
loader = dataloader_init(features, labels, rows, feat_dim, cast(2, int64), true, cast(0, int64), false)
first = dataloader_next(loader)
```
The training loop
School.Train.Loop (src/train/loop.ch) provides the loop state and the epoch
combinator.
- TrainState { step, accumulated_loss } accumulates loss across steps. train_state_init() builds the zero state, train_loop_step(state, loss) and train_history_log(state, loss) fold a step's loss in, and history_avg_loss(state) returns the running mean loss.
- loop(initial_params, initial_state, loader, step, num_epochs, seed) is the pure epoch-by-batch combinator. It walks each epoch's batches, threading the parameters and state through a user-supplied step, and reshuffles between epochs. Given the same inputs, it produces the same result.
- TrainConfig { num_epochs, log_every_n_steps, accumulate_grad_batches } carries the loop configuration.
- The gradient-accumulation helpers (GradAccumState, accum_state_init, grad_accumulate) accumulate a parameter-shaped gradient across micro-batches before a step.
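Here is a Python sketch of the loss bookkeeping and gradient accumulation. The names mirror the ones above, but the code is illustrative, not School's API:

- TrainState is a frozen dataclass, so each fold returns a new value.
- A "parameter-shaped" gradient is modelled as a dict of floats keyed by parameter name.
- average_grad is a hypothetical helper that divides the accumulated sum by accumulate_grad_batches before the optimizer step.
- history_avg_loss returns 0.0 before any step, which is also an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainState:
    step: int = 0
    accumulated_loss: float = 0.0


def train_loop_step(state, loss):
    """Fold one step's loss in, returning a new state."""
    return TrainState(state.step + 1, state.accumulated_loss + loss)


def history_avg_loss(state):
    """Running mean loss; 0.0 before any step (an assumed convention)."""
    return state.accumulated_loss / state.step if state.step else 0.0


def accum_state_init(params):
    """Zero gradient with the same keys as the parameters."""
    return {name: 0.0 for name in params}


def grad_accumulate(acc, grad):
    """Add one micro-batch gradient into the running sum."""
    return {name: acc[name] + grad[name] for name in acc}


def average_grad(acc, accumulate_grad_batches):
    """Hypothetical helper: mean gradient over the accumulated micro-batches."""
    return {name: g / accumulate_grad_batches for name, g in acc.items()}
```

After accumulate_grad_batches micro-batches, the optimizer step consumes the averaged gradient. The accumulator then starts over from accum_state_init.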
The step the loop calls is where differentiation happens. The loop never
differentiates; the step does, by taking the grad of a loss. The MLP step
mlp_train_step is the reference step, covered in the
Model library.
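The split between a non-differentiating loop and a differentiating step can be sketched in Python:

- loop is a pure fold over epochs and batches. It reshuffles each epoch from seed + epoch, which is an assumed seeding scheme.
- make_sgd_step is a hypothetical step for a one-feature linear model y = w*x + b. It computes the mean squared error and its analytic gradient, applies one SGD update, and folds the loss into a (steps, total_loss) state.

A School step takes the grad of its loss instead of hand-deriving it as this sketch does.

```python
import random


def loop(initial_params, initial_state, data, batch_size, step, num_epochs, seed):
    """Pure epoch-by-batch fold: thread (params, state) through step."""
    params, state = initial_params, initial_state
    for epoch in range(num_epochs):
        order = list(range(len(data)))
        # Reshuffle between epochs from a per-epoch seed (assumed scheme).
        random.Random(seed + epoch).shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = [data[i] for i in order[start:start + batch_size]]
            # The loop only calls step; it never differentiates anything.
            params, state = step(params, state, batch)
    return params, state


def make_sgd_step(lr):
    """Hypothetical step: MSE of y = w*x + b, analytic gradient, one SGD update."""
    def step(params, state, batch):
        w, b = params
        n = len(batch)
        loss = grad_w = grad_b = 0.0
        for x, y in batch:
            err = w * x + b - y
            loss += err * err / n
            grad_w += 2.0 * err * x / n  # d(loss)/dw
            grad_b += 2.0 * err / n      # d(loss)/db
        steps, total_loss = state
        new_params = (w - lr * grad_w, b - lr * grad_b)
        return new_params, (steps + 1, total_loss + loss)
    return step
```

Running loop on noise-free points from y = 2x + 1 with make_sgd_step(0.3) drives (w, b) toward (2, 1). Nothing in loop depends on hidden state, so repeating the call with the same seed returns exactly the same parameters and state.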
Training metrics
School.Train.MetricsExt (src/train/metricsext.ch) provides binary and
regression metrics:
- precision_binary, recall_binary, and f1_binary score a prediction against a target at a decision threshold.
- r_squared, mse_metric, and mae_metric score regression outputs.
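For reference, here are the standard definitions behind these metrics, sketched in Python. The sketch rests on three assumptions:

- Targets are 0/1 integers.
- A prediction counts as positive when it is greater than or equal to the threshold.
- An undefined ratio, such as precision with no predicted positives, is reported as 0.0.

School's exact conventions for these edge cases are not stated here.

```python
def _binary_counts(preds, targets, threshold):
    """True positives, false positives, false negatives at a threshold."""
    tp = fp = fn = 0
    for p, t in zip(preds, targets):
        predicted_pos = p >= threshold
        if predicted_pos and t == 1:
            tp += 1
        elif predicted_pos:
            fp += 1
        elif t == 1:
            fn += 1
    return tp, fp, fn


def precision_binary(preds, targets, threshold):
    tp, fp, _ = _binary_counts(preds, targets, threshold)
    return tp / (tp + fp) if tp + fp else 0.0


def recall_binary(preds, targets, threshold):
    tp, _, fn = _binary_counts(preds, targets, threshold)
    return tp / (tp + fn) if tp + fn else 0.0


def f1_binary(preds, targets, threshold):
    p = precision_binary(preds, targets, threshold)
    r = recall_binary(preds, targets, threshold)
    return 2 * p * r / (p + r) if p + r else 0.0


def mse_metric(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae_metric(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def r_squared(preds, targets):
    """1 - SS_res / SS_tot; raises ZeroDivisionError for a constant target."""
    mean = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, targets))
    ss_tot = sum((t - mean) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot
```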
School.Train.MetricsExt_Multiclass (src/train/metricsext_multiclass.ch) provides
the multiclass metrics over logits and integer labels, with an Average mode of
Macro, Micro, or Weighted:
- precision, recall, and f1.
- confusion_matrix, which returns a class-by-class integer count matrix.
- roc_auc.
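The following Python sketch uses precision to show how the three averaging modes differ. It assumes that:

- the predicted class is the argmax of each logit row;
- confusion-matrix rows are true classes and columns are predicted classes;
- a class with no predictions contributes a precision of 0.0.

The modes then work as follows:

- Macro takes the unweighted mean of per-class precision.
- Micro pools true positives and predictions across classes.
- Weighted weights each class by its number of true examples (its support).

Recall follows the same pattern, with row sums in place of column sums.

```python
from enum import Enum


class Average(Enum):
    MACRO = "macro"
    MICRO = "micro"
    WEIGHTED = "weighted"


def _argmax(row):
    return max(range(len(row)), key=lambda k: row[k])


def confusion_matrix(logits, labels):
    """Integer counts: row = true class, column = predicted class."""
    k = len(logits[0])
    cm = [[0] * k for _ in range(k)]
    for row, label in zip(logits, labels):
        cm[label][_argmax(row)] += 1
    return cm


def precision(logits, labels, average):
    cm = confusion_matrix(logits, labels)
    k = len(cm)
    tp = [cm[c][c] for c in range(k)]
    predicted = [sum(cm[r][c] for r in range(k)) for c in range(k)]  # column sums
    support = [sum(cm[c]) for c in range(k)]                         # row sums
    per_class = [tp[c] / predicted[c] if predicted[c] else 0.0 for c in range(k)]
    if average is Average.MACRO:
        return sum(per_class) / k
    if average is Average.MICRO:
        return sum(tp) / sum(predicted)
    # WEIGHTED: weight each class by its number of true examples.
    return sum(p * s for p, s in zip(per_class, support)) / sum(support)
```

For instance, take true labels [0, 1, 2, 2] and predicted classes [0, 1, 1, 2]. Per-class precision is 1.0, 0.5, and 1.0, so Macro gives about 0.833, Micro gives 0.75, and Weighted gives 0.875.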
Macro-averaged precision, recall, and F1 (examples/p3/multiclass_metrics_demo.ch):
```
import School.Train.MetricsExt_Multiclass (precision, recall, f1, Average, Macro)
...
p = precision(&logits, &labels, Macro)
r = recall(&logits, &labels, Macro)
f = f1(&logits, &labels, Macro)
```
Hyperparameter search
School.Hpo (src/hpo.ch) provides integer-axis hyperparameter search:
grid_search over a GridConfig and random_search over a RandomConfig.
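This section does not spell out the fields of GridConfig or RandomConfig, so the following Python sketch uses a hypothetical stand-in: a dict mapping each hyperparameter name to an inclusive integer range (lo, hi), plus an objective to minimize. grid_search tries every combination. random_search draws num_trials points uniformly from the same ranges with a seeded generator. That makes it reproducible, and it can never beat the exhaustive grid on the same space.

```python
import itertools
import random


def grid_search(axes, objective):
    """Try every integer point; axes maps name -> inclusive (lo, hi)."""
    names = list(axes)
    ranges = [range(lo, hi + 1) for lo, hi in axes.values()]
    best = None
    for values in itertools.product(*ranges):
        params = dict(zip(names, values))
        score = objective(params)
        if best is None or score < best[1]:  # minimize; ties keep the first
            best = (params, score)
    return best


def random_search(axes, objective, num_trials, seed):
    """Sample num_trials integer points uniformly, reproducibly from seed."""
    rng = random.Random(seed)
    best = None
    for _ in range(num_trials):
        params = {name: rng.randint(lo, hi) for name, (lo, hi) in axes.items()}
        score = objective(params)
        if best is None or score < best[1]:
            best = (params, score)
    return best
```

Grid search costs the product of the axis sizes, while random search costs exactly num_trials evaluations. That is why random search is the usual choice once the grid grows large.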