Package io.github.kirstenali.deepj.loss
Class CrossEntropyLoss
java.lang.Object
io.github.kirstenali.deepj.loss.CrossEntropyLoss
- All Implemented Interfaces:
LossFunction
Cross-entropy loss with integer class targets.
Expected shapes:
- predicted (logits): [nTokens x vocab]
- actual (class indices): [nTokens x 1] where each entry is an integer in [0, vocab)
This class also provides helpers for common language-modeling usage where targets are provided as int[].
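The computation these shapes imply (a log-softmax over each row of logits, then the negative log-probability of the target class) can be sketched as below. This is a minimal illustration, not deepj's implementation. Plain double[][] and int[] stand in for Tensor, because this page does not show Tensor's API. The loss is assumed to be the mean over the nTokens rows; the page does not say whether deepj averages or sums. The names CrossEntropySketch and crossEntropy are hypothetical.

```java
import java.util.Arrays;

/** Illustrative sketch only: double[][] stands in for deepj's Tensor (assumption). */
final class CrossEntropySketch {

    /**
     * Mean cross-entropy over rows (assumption: mean, not sum).
     * logits:  [nTokens][vocab] raw scores
     * targets: nTokens class indices, each expected in [0, vocab)
     */
    static double crossEntropy(double[][] logits, int[] targets) {
        if (logits.length != targets.length) {
            throw new IllegalArgumentException("row count mismatch");
        }
        double total = 0.0;
        for (int i = 0; i < logits.length; i++) {
            double[] row = logits[i];
            int t = targets[i];
            if (t < 0 || t >= row.length) {
                throw new IllegalArgumentException("target out of range: " + t);
            }
            // log-sum-exp with the row max subtracted so Math.exp cannot overflow
            double max = Arrays.stream(row).max().getAsDouble();
            double sumExp = 0.0;
            for (double z : row) {
                sumExp += Math.exp(z - max);
            }
            double logSumExp = max + Math.log(sumExp);
            // -log softmax(row)[t] == logSumExp - row[t]
            total += logSumExp - row[t];
        }
        return total / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = { {0.0, 0.0}, {0.0, 0.0} };
        int[] targets = {0, 1};
        // Uniform logits over 2 classes: loss is ln 2, roughly 0.6931
        System.out.println(crossEntropy(logits, targets));
    }
}
```

Subtracting the row maximum does not change the softmax. It keeps large logits, such as 1000.0, from overflowing to infinity.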
Constructor Summary
- CrossEntropyLoss()
Method Summary
- static Tensor fromIntTargets(int[] targets): Builds a [n x 1] Tensor from int[] targets.
- static Tensor gradient: Convenience helper: gradient w.r.t.
- static double loss: Convenience helper: compute loss from logits and int targets.
- double loss: Specified by interface LossFunction.
- static int[] toIntTargets(Tensor actual): Converts a [n x 1] Tensor of class indices into an int[].
Constructor Details
CrossEntropyLoss
public CrossEntropyLoss()
Method Details
loss
- Specified by:
loss in interface LossFunction
gradient
- Specified by:
gradient in interface LossFunction
loss
Convenience helper: compute loss from logits and int targets.
gradient
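For softmax followed by cross-entropy, the gradient with respect to each logit row reduces to softmax(z_i) - onehot(t_i). For a loss averaged over rows, that is then divided by nTokens. The sketch below computes this on plain arrays. It is a hypothetical illustration, not the library's code. It assumes the gradient is taken with respect to the logits and that the loss is a mean over rows; neither is confirmed on this page. For a summed loss, drop the 1/nTokens factor. CrossEntropyGradSketch and its gradient method are illustrative names, and double[][] stands in for Tensor.

```java
import java.util.Arrays;

/** Illustrative sketch only: double[][] stands in for deepj's Tensor (assumption). */
final class CrossEntropyGradSketch {

    /**
     * dL/dz[i][j] = (softmax(z[i])[j] - (j == targets[i] ? 1 : 0)) / nTokens
     * Assumption: the loss being differentiated is the mean over rows.
     */
    static double[][] gradient(double[][] logits, int[] targets) {
        int n = logits.length;
        if (n != targets.length) {
            throw new IllegalArgumentException("row count mismatch");
        }
        double[][] grad = new double[n][];
        for (int i = 0; i < n; i++) {
            double[] row = logits[i];
            int t = targets[i];
            if (t < 0 || t >= row.length) {
                throw new IllegalArgumentException("target out of range: " + t);
            }
            // Stable softmax: subtract the row max before exponentiating
            double max = Double.NEGATIVE_INFINITY;
            for (double z : row) {
                max = Math.max(max, z);
            }
            double[] g = new double[row.length];
            double sumExp = 0.0;
            for (int j = 0; j < row.length; j++) {
                g[j] = Math.exp(row[j] - max);
                sumExp += g[j];
            }
            for (int j = 0; j < row.length; j++) {
                g[j] /= sumExp;          // softmax probability
                if (j == t) {
                    g[j] -= 1.0;         // subtract the one-hot target
                }
                g[j] /= n;               // 1/nTokens factor of the mean loss
            }
            grad[i] = g;
        }
        return grad;
    }

    public static void main(String[] args) {
        double[][] g = gradient(new double[][] { {0.0, 0.0} }, new int[] {0});
        System.out.println(Arrays.toString(g[0])); // [-0.5, 0.5]
    }
}
```

Each gradient row sums to zero, because the softmax probabilities sum to one and the one-hot vector subtracts exactly one.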
toIntTargets
Converts a [n x 1] Tensor of class indices into an int[].
fromIntTargets
Builds a [n x 1] Tensor from int[] targets.
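The two conversion helpers move between the [n x 1] column layout taken by the LossFunction methods and the int[] layout common in language modeling. The sketch below is a hypothetical round-trip on plain arrays, not deepj's code. It assumes Tensor stores class indices as double values, which this page does not confirm, and uses double[][] of shape [n][1] in place of Tensor. The class name TargetConversionSketch is illustrative.

```java
import java.util.Arrays;

/** Illustrative sketch only: double[][] of shape [n][1] stands in for Tensor (assumption). */
final class TargetConversionSketch {

    /** Sketch of toIntTargets: [n x 1] column of class indices -> int[n]. */
    static int[] toIntTargets(double[][] actual) {
        int[] out = new int[actual.length];
        for (int i = 0; i < actual.length; i++) {
            if (actual[i].length != 1) {
                throw new IllegalArgumentException(
                        "expected [n x 1]; row " + i + " has " + actual[i].length + " columns");
            }
            double v = actual[i][0];
            // Reject non-integral (or NaN) values instead of silently truncating them
            if (v != Math.rint(v)) {
                throw new IllegalArgumentException("non-integer class index at row " + i + ": " + v);
            }
            out[i] = (int) v;
        }
        return out;
    }

    /** Sketch of fromIntTargets: int[n] -> [n x 1] column, one index per row. */
    static double[][] fromIntTargets(int[] targets) {
        double[][] out = new double[targets.length][1];
        for (int i = 0; i < targets.length; i++) {
            out[i][0] = targets[i];
        }
        return out;
    }

    public static void main(String[] args) {
        int[] targets = {3, 0, 7};
        double[][] column = fromIntTargets(targets);
        System.out.println(Arrays.deepToString(column));           // [[3.0], [0.0], [7.0]]
        System.out.println(Arrays.toString(toIntTargets(column))); // [3, 0, 7]
    }
}
```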