Class Trainer
java.lang.Object
io.github.kirstenali.deepj.training.Trainer
A small, reusable training loop wrapper.
This library supports different model/data shapes (e.g. supervised Tensor->Tensor models,
and causal language models that operate on token ids). Rather than duplicating full trainers,
Trainer delegates a single training step to a pluggable Trainer.StepFunction.
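The delegation described above can be sketched as follows. This is a minimal illustration, not the deepj source: the Step interface, the Loop class and the halvingLoss helper are hypothetical names, and the real Trainer.StepFunction signature may differ.

```java
// Illustrative sketch only: Step and Loop are hypothetical names, not deepj types.
final class StepDelegationSketch {

    /** Hypothetical pluggable step: consumes a batch size, returns that step's loss. */
    @FunctionalInterface
    interface Step {
        double run(int batchSize);
    }

    /** Owns the iteration; knows nothing about the model or data behind the step. */
    static final class Loop {
        private final Step step;

        Loop(Step step) {
            this.step = step;
        }

        /** Runs maxSteps steps and returns the loss reported by the last one. */
        double run(int maxSteps, int batchSize) {
            double lastLoss = Double.NaN;
            for (int i = 0; i < maxSteps; i++) {
                lastLoss = step.run(batchSize);
            }
            return lastLoss;
        }
    }

    /** Toy step whose loss halves on every call, standing in for a real model update. */
    static Step halvingLoss() {
        double[] loss = {1.0};
        return batchSize -> {
            loss[0] *= 0.5;
            return loss[0];
        };
    }

    public static void main(String[] args) {
        Loop loop = new Loop(halvingLoss());
        System.out.println(loop.run(3, 8)); // prints 0.125
    }
}
```

Because the loop depends only on the step interface, a supervised model and a causal language model could each supply their own step without a second loop implementation.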
Nested Class Summary
Nested Classes
Modifier and Type    Class
static interface     Trainer.StepFunction
static interface     Trainer.StepHook
Constructor Summary
Constructors
Method Summary
Modifier and Type, Method, Description

TrainingResult  train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss)
    Train until maxSteps or until EMA loss goes below targetEmaLoss (if provided).
TrainingResult  train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss, int releaseEverySteps)
TrainingResult  train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss, int releaseEverySteps, Trainer.StepHook stepHook)
    Train until maxSteps or until EMA loss goes below targetEmaLoss (if provided).
TrainingResult  train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss, Trainer.StepHook stepHook)
    Train until maxSteps or until EMA loss goes below targetEmaLoss (if provided).
double          trainStep(int batchSize)
Constructor Details
Trainer
Method Details
trainStep
public double trainStep(int batchSize)
train
public TrainingResult train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss)
Train until maxSteps or until the EMA loss goes below targetEmaLoss (if provided). Uses the default periodic backend release cadence.
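The stopping rule can be illustrated with a small standalone sketch. It rests on assumptions the page does not confirm: that the EMA is updated as emaBeta * ema + (1 - emaBeta) * loss, that it is seeded with the first loss, and that a null targetEmaLoss means "not provided". The class and method names are hypothetical.

```java
// Illustrative sketch only: the EMA formula and its seeding are assumptions,
// not taken from the deepj Trainer implementation.
final class EmaStopSketch {

    /**
     * Replays a fixed sequence of per-step losses and returns how many steps
     * run before the loop stops, either at maxSteps or at the EMA target.
     */
    static int stepsUntilStop(double[] losses, int maxSteps, double emaBeta, Double targetEmaLoss) {
        double ema = Double.NaN;
        int steps = 0;
        while (steps < maxSteps && steps < losses.length) {
            double loss = losses[steps];
            steps++;
            // Seed the EMA with the first loss, then blend in each new loss.
            ema = Double.isNaN(ema) ? loss : emaBeta * ema + (1.0 - emaBeta) * loss;
            // A null target disables early stopping (assumed meaning of "if provided").
            if (targetEmaLoss != null && ema < targetEmaLoss) {
                break;
            }
        }
        return steps;
    }

    public static void main(String[] args) {
        double[] losses = {4.0, 2.0, 1.0, 0.5};
        // EMA with beta 0.5: 4.0, 3.0, 2.0, so the 2.5 target is crossed at step 3.
        System.out.println(stepsUntilStop(losses, 4, 0.5, 2.5));  // prints 3
        System.out.println(stepsUntilStop(losses, 4, 0.5, null)); // prints 4
    }
}
```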
train
public TrainingResult train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss, int releaseEverySteps)
train
public TrainingResult train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss, Trainer.StepHook stepHook)
Train until maxSteps or until the EMA loss goes below targetEmaLoss (if provided). Uses the default periodic backend release cadence.
train
public TrainingResult train(int maxSteps, int batchSize, int logEvery, double emaBeta, Double targetEmaLoss, int releaseEverySteps, Trainer.StepHook stepHook)
Train until maxSteps or until the EMA loss goes below targetEmaLoss (if provided). A releaseEverySteps value <= 0 disables periodic release, but the final release still runs.
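The release cadence can be pictured with a short counting sketch. It assumes that a periodic release fires whenever the 1-based step index is a multiple of releaseEverySteps, which this page does not state; only the "<= 0 disables periodic release, final release still runs" rule comes from the description above. The class and method names are hypothetical.

```java
// Illustrative sketch only: the "step % releaseEverySteps == 0" trigger is an
// assumption; the documented part is that values <= 0 disable periodic release
// while the final release still runs.
final class ReleaseCadenceSketch {

    /** Counts the backend releases that would happen over totalSteps steps. */
    static int countReleases(int totalSteps, int releaseEverySteps) {
        int releases = 0;
        for (int step = 1; step <= totalSteps; step++) {
            // Periodic release, skipped entirely when the cadence is <= 0.
            if (releaseEverySteps > 0 && step % releaseEverySteps == 0) {
                releases++;
            }
        }
        // Final release after the loop, independent of the cadence setting.
        releases++;
        return releases;
    }

    public static void main(String[] args) {
        System.out.println(countReleases(10, 3)); // prints 4 (steps 3, 6, 9, plus final)
        System.out.println(countReleases(10, 0)); // prints 1 (final release only)
    }
}
```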