Schedules
Learning-rate schedules map a step count to a learning rate. They live under two
prefixes: School.Schedule for the warmup and step-decay schedules, and
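School.SchedExt for the decay, cyclic, and plateau schedules. Each schedule
pairs a config record with a function that returns the learning rate. Most
functions take the step; the plateau schedule instead takes a metric and its
running state, and returns the updated state alongside the rate. Argument order
differs between the two prefixes: School.Schedule functions take (step, config),
while School.SchedExt functions take (config, step).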
This chapter is a reference for the schedules defined in src/schedule.ch and
src/schedext.ch.
School.Schedule
Defined in src/schedule.ch.
| Schedule | Function | Config fields |
|---|---|---|
| Cosine with warmup | cosine_with_warmup(step, config) | warmup_steps, total_steps, min_lr, max_lr |
| Linear warmup | linear_warmup(step, config) | warmup_steps, target_lr |
| Step decay | step_decay(step, config) | initial_lr, decay_factor, decay_steps |
cosine_with_warmup ramps linearly from zero to max_lr over warmup_steps, then
follows a cosine curve down to min_lr by total_steps. linear_warmup ramps to
target_lr over warmup_steps and holds. step_decay multiplies initial_lr by
decay_factor each time the step passes one of the milestones in decay_steps.
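
The three behaviours described above can be sketched in Python. This is an illustration of the prose, not the code in src/schedule.ch: the functions take unpacked fields instead of config records, the half-cosine formula is the common one and is assumed here, and treating decay_steps as a sorted list of milestones that count once step >= milestone is also an assumption.

```python
import math
from bisect import bisect_right


def linear_warmup(step, warmup_steps, target_lr):
    """Ramp linearly from 0 to target_lr over warmup_steps, then hold."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return target_lr
    return target_lr * step / warmup_steps


def cosine_with_warmup(step, warmup_steps, total_steps, min_lr, max_lr):
    """Linear ramp to max_lr, then a half-cosine down to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    if step >= total_steps:
        return min_lr
    # Assumed form: progress runs 0..1 across the post-warmup span.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def step_decay(step, initial_lr, decay_factor, decay_steps):
    """Multiply initial_lr by decay_factor once per milestone reached."""
    # Assumption: decay_steps is sorted ascending; a milestone counts
    # as passed once step >= milestone.
    passed = bisect_right(decay_steps, step)
    return initial_lr * decay_factor ** passed
```

For instance, with the hypothetical milestones [30, 60] and decay_factor 0.5, step_decay leaves the rate untouched before step 30, halves it from step 30, and quarters it from step 60.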
Reading the cosine schedule at several steps (examples/p0/cosine_warmup_curve.ch):

    import School.Schedule (CosineWarmupConfig, cosine_with_warmup)
    ...
    def cwc_cfg() -> CosineWarmupConfig =
      CosineWarmupConfig {
        warmup_steps: cast(10, int64),
        total_steps: cast(100, int64),
        min_lr: cast(0.001, f32),
        max_lr: cast(0.1, f32)
      }
    def main() -> f32 = {
      _ = cosine_with_warmup(cast(0, int64), cwc_cfg())
      _ = cosine_with_warmup(cast(5, int64), cwc_cfg())
      peak = cosine_with_warmup(cast(10, int64), cwc_cfg())
      _ = cosine_with_warmup(cast(55, int64), cwc_cfg())
      _ = cosine_with_warmup(cast(100, int64), cwc_cfg())
      peak
    }

At step 10, the end of the warmup, the rate is at its peak, max_lr. The config is
consumed by value, so the example rebuilds it with cwc_cfg() for each call.
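
The example discards every reading except the peak. As a rough check, the other values can be worked out by hand, assuming the common half-cosine form after warmup (an assumption; src/schedule.ch may differ in detail), with w = warmup_steps and T = total_steps:

$$
\mathrm{lr}(t) =
\begin{cases}
\mathrm{max\_lr}\cdot\dfrac{t}{w}, & t < w \\[6pt]
\mathrm{min\_lr} + \tfrac{1}{2}\,(\mathrm{max\_lr}-\mathrm{min\_lr})\left(1+\cos\!\left(\pi\,\dfrac{t-w}{T-w}\right)\right), & w \le t \le T
\end{cases}
$$

With w = 10, T = 100, min_lr = 0.001, and max_lr = 0.1:

$$
\begin{aligned}
\mathrm{lr}(0) &= 0.1\cdot\tfrac{0}{10} = 0 \\
\mathrm{lr}(5) &= 0.1\cdot\tfrac{5}{10} = 0.05 \\
\mathrm{lr}(10) &= 0.001 + \tfrac{1}{2}(0.099)(1+\cos 0) = 0.1 \\
\mathrm{lr}(55) &= 0.001 + \tfrac{1}{2}(0.099)\left(1+\cos\tfrac{\pi}{2}\right) = 0.0505 \\
\mathrm{lr}(100) &= 0.001 + \tfrac{1}{2}(0.099)(1+\cos\pi) = 0.001
\end{aligned}
$$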
School.SchedExt
Defined in src/schedext.ch.
| Schedule | Function | Config fields |
|---|---|---|
| Exponential decay | exponential_decay(config, step) | initial_lr, decay_rate |
| Polynomial decay | polynomial_decay(config, step) | initial_lr, final_lr, total_steps, power |
| One-cycle | one_cycle(config, step) | max_lr, total_steps, pct_start, min_lr |
| Cyclic | cyclic(config, step) | base_lr, max_lr, period_steps |
| Reduce on plateau | reduce_on_plateau_init(initial_lr), reduce_on_plateau_step(metric, state, config) | factor, patience, min_lr, threshold |
exponential_decay scales initial_lr by decay_rate raised to the step.
polynomial_decay interpolates from initial_lr to final_lr over total_steps
with the given power. one_cycle ramps up over the first pct_start fraction of
total_steps to max_lr and back down toward min_lr. cyclic oscillates between
base_lr and max_lr with the given period.
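
A minimal Python sketch of these four schedules follows. It illustrates the descriptions above and is not the code in src/schedext.ch: the functions take unpacked fields instead of config records, and several details are assumptions, namely that one_cycle starts at min_lr and moves linearly in both phases, that cyclic is a triangular wave, and that period_steps is the length of one full cycle.

```python
def exponential_decay(initial_lr, decay_rate, step):
    """initial_lr scaled by decay_rate raised to the step."""
    return initial_lr * decay_rate ** step


def polynomial_decay(initial_lr, final_lr, total_steps, power, step):
    """Interpolate from initial_lr to final_lr; hold final_lr afterwards."""
    t = min(step, total_steps) / total_steps
    return final_lr + (initial_lr - final_lr) * (1.0 - t) ** power


def one_cycle(max_lr, total_steps, pct_start, min_lr, step):
    """Ramp up to max_lr over the first pct_start fraction, then back down."""
    # Assumptions: the ramp starts at min_lr and both phases are linear.
    peak_step = pct_start * total_steps
    step = min(step, total_steps)
    if step < peak_step:
        return min_lr + (max_lr - min_lr) * step / peak_step
    if total_steps == peak_step:
        return max_lr  # no room for a descent phase
    remaining = (total_steps - step) / (total_steps - peak_step)
    return min_lr + (max_lr - min_lr) * remaining


def cyclic(base_lr, max_lr, period_steps, step):
    """Triangular oscillation between base_lr and max_lr."""
    # Assumption: one full up-and-down cycle takes period_steps steps.
    phase = (step % period_steps) / period_steps  # in [0, 1)
    tri = 1.0 - abs(2.0 * phase - 1.0)            # 0 at cycle start, 1 at midpoint
    return base_lr + (max_lr - base_lr) * tri
```

Under these assumptions, cyclic returns base_lr at the start of each period and max_lr at its midpoint, and polynomial_decay with power 1.0 reduces to plain linear interpolation.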
reduce_on_plateau is stateful: reduce_on_plateau_init(initial_lr) builds a
ReduceOnPlateauState, and reduce_on_plateau_step(metric, state, config) returns
the next learning rate and next state, cutting the rate by factor after patience
epochs without improvement beyond threshold, down to min_lr.
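
The init/step pairing can be sketched in Python as below. Only the function names, the argument order, the ReduceOnPlateauState name, and the four config fields come from the text above. Everything else is an assumption for illustration: the ReduceOnPlateauConfig name, the state fields (lr, best, bad_epochs), treating a lower metric as better, an absolute threshold, and cutting as soon as the count of non-improving calls reaches patience.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReduceOnPlateauConfig:  # hypothetical name for the config record
    factor: float
    patience: int
    min_lr: float
    threshold: float


@dataclass(frozen=True)
class ReduceOnPlateauState:  # field layout is assumed
    lr: float
    best: float
    bad_epochs: int


def reduce_on_plateau_init(initial_lr):
    """Start with no best metric seen and no bad epochs counted."""
    return ReduceOnPlateauState(lr=initial_lr, best=float("inf"), bad_epochs=0)


def reduce_on_plateau_step(metric, state, config):
    """Return (next learning rate, next state) without mutating state."""
    # Assumption: lower is better, and an improvement must beat the best
    # metric by more than threshold.
    if metric < state.best - config.threshold:
        return state.lr, ReduceOnPlateauState(state.lr, metric, 0)
    bad = state.bad_epochs + 1
    if bad >= config.patience:
        # Cut by factor, never dropping below min_lr, and restart the count.
        new_lr = max(state.lr * config.factor, config.min_lr)
        return new_lr, ReduceOnPlateauState(new_lr, state.best, 0)
    return state.lr, ReduceOnPlateauState(state.lr, state.best, bad)


cfg = ReduceOnPlateauConfig(factor=0.5, patience=2, min_lr=0.01, threshold=0.0)
state = reduce_on_plateau_init(0.1)
history = []
for metric in [1.0, 0.8, 0.8, 0.8, 0.8, 0.8]:
    lr, state = reduce_on_plateau_step(metric, state, cfg)
    history.append(lr)
```

In this run the metric stops improving after the second reading, so under the stated assumptions the rate is halved after two further readings and halved again after two more. Returning a fresh state instead of mutating the old one mirrors the "next learning rate and next state" contract described above.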
Reading exponential decay at step 10 (examples/p4/cosine_warmup_demo.ch):

    import School.SchedExt (ExponentialDecayConfig, exponential_decay)
    ...
    cfg = ExponentialDecayConfig {
      initial_lr: cast(0.1, f32),
      decay_rate: cast(0.9, f32)
    }
    lr_step10 = exponential_decay(cfg, cast(10, int64))
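
As a rough check on lr_step10, assuming the schedule computes initial_lr times decay_rate raised to the step, as the description of exponential_decay suggests:

$$
\mathrm{lr}(10) = 0.1 \times 0.9^{10} = 0.1 \times 0.3486784401 \approx 0.0349
$$

Because the fields are f32, the value actually returned may differ from this in the last few digits.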