Class CrossEntropyLoss

java.lang.Object
io.github.kirstenali.deepj.loss.CrossEntropyLoss
All Implemented Interfaces:
LossFunction

public final class CrossEntropyLoss extends Object implements LossFunction
Cross-entropy loss with integer class targets.

Expected shapes:

  • predicted (logits): [nTokens x vocab]
  • actual (class indices): [nTokens x 1] where each entry is an integer in [0, vocab)
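The shapes above map onto the usual definition of per-token cross-entropy, loss_i = logsumexp(z_i) - z_i[y_i], where z_i is row i of the logits and y_i is its target index. The following is a minimal standalone sketch, not deepj's implementation. It assumes plain double[][] arrays in place of the library's own tensor type, and it assumes a mean over tokens (this page does not say whether the loss is averaged or summed). The class name CrossEntropySketch and its methods are hypothetical. The gradient method uses the standard result (softmax - oneHot) / nTokens for the mean reduction.

```java
/**
 * Standalone sketch of cross-entropy with integer class targets.
 * Assumptions: double[][] stands in for deepj's tensor type, and the loss is
 * averaged over tokens. Names here are hypothetical, not deepj's API.
 */
public final class CrossEntropySketch {

    private CrossEntropySketch() {}

    /** Validates that targets[i][0] is an integer in [0, vocab) and returns it. */
    private static int classAt(double[][] targets, int i, int vocab) {
        double raw = targets[i][0];
        int cls = (int) raw;
        if (cls < 0 || cls >= vocab || cls != raw) {
            throw new IllegalArgumentException("target at row " + i + " is not an integer in [0, " + vocab + ")");
        }
        return cls;
    }

    /** Numerically stable log-sum-exp: subtract the row max before exponentiating. */
    private static double logSumExp(double[] row) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : row) max = Math.max(max, v);
        double sumExp = 0.0;
        for (double v : row) sumExp += Math.exp(v - max);
        return max + Math.log(sumExp);
    }

    /**
     * @param logits  shape [nTokens x vocab]
     * @param targets shape [nTokens x 1], each entry an integer in [0, vocab)
     * @return mean of -log softmax(logits[i])[target_i] over all rows
     */
    public static double loss(double[][] logits, double[][] targets) {
        int nTokens = logits.length;
        if (targets.length != nTokens) {
            throw new IllegalArgumentException("logits and targets must have the same number of rows");
        }
        double total = 0.0;
        for (int i = 0; i < nTokens; i++) {
            int cls = classAt(targets, i, logits[i].length);
            total += logSumExp(logits[i]) - logits[i][cls];
        }
        return total / nTokens;
    }

    /** Gradient of the mean loss with respect to the logits: (softmax - oneHot) / nTokens. */
    public static double[][] gradient(double[][] logits, double[][] targets) {
        int nTokens = logits.length;
        double[][] grad = new double[nTokens][];
        for (int i = 0; i < nTokens; i++) {
            double[] row = logits[i];
            int cls = classAt(targets, i, row.length);
            double lse = logSumExp(row);
            grad[i] = new double[row.length];
            for (int j = 0; j < row.length; j++) {
                grad[i][j] = Math.exp(row[j] - lse) / nTokens; // softmax probability, scaled
            }
            grad[i][cls] -= 1.0 / nTokens; // subtract the one-hot target
        }
        return grad;
    }

    public static void main(String[] args) {
        double[][] logits = {{0.0, 0.0}, {0.0, 0.0}}; // 2 tokens, vocab of 2
        double[][] targets = {{0}, {1}};              // [nTokens x 1] class indices
        System.out.println(loss(logits, targets));    // prints ln 2, about 0.6931
    }
}
```

Subtracting the row maximum before exponentiating keeps very large logits (for example 1000.0) from overflowing Math.exp, which is why the sketch routes both the loss and the gradient through logSumExp.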

This class also provides helper methods for the common language-modeling case, in which targets are supplied as an int[].
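Because the loss itself expects targets shaped [nTokens x 1], an int[] of token ids has to be laid out as a single column before it matches that contract. The sketch below shows one way to do this, again with plain double[][] arrays standing in for deepj's tensor type. The class TargetColumns and its method toColumn are hypothetical and are not the names of the helpers this class actually provides. The sketch also checks that every id lies in [0, vocab).

```java
import java.util.Arrays;

/**
 * Hypothetical helper: converts int[] class ids into the [nTokens x 1]
 * column layout expected for targets. Not deepj's actual API.
 */
public final class TargetColumns {

    private TargetColumns() {}

    /** Returns a [targets.length x 1] column, rejecting ids outside [0, vocab). */
    public static double[][] toColumn(int[] targets, int vocab) {
        if (vocab <= 0) {
            throw new IllegalArgumentException("vocab must be positive");
        }
        double[][] column = new double[targets.length][1];
        for (int i = 0; i < targets.length; i++) {
            int t = targets[i];
            if (t < 0 || t >= vocab) {
                throw new IllegalArgumentException(
                        "target " + t + " at position " + i + " is outside [0, " + vocab + ")");
            }
            column[i][0] = t; // one class index per row
        }
        return column;
    }

    public static void main(String[] args) {
        double[][] col = toColumn(new int[]{3, 0, 7}, 10);
        System.out.println(Arrays.deepToString(col)); // [[3.0], [0.0], [7.0]]
    }
}
```

Validating at conversion time reports a bad token id together with its position, before the id is used to index into a row of logits.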