Package org.DeepJ.ann

Class Tensor

java.lang.Object
org.DeepJ.ann.Tensor

public class Tensor extends Object
  • Field Details

    • data

      public double[][] data
    • rows

      public int rows
    • cols

      public int cols
  • Constructor Details

    • Tensor

      public Tensor(int rows, int cols)
    • Tensor

      public Tensor(double[][] data)
  • Method Details

    • iterate

      public void iterate(BiConsumer<Integer,Integer> operation)
    • iterate

      public static void iterate(BiConsumer<Integer,Integer> operation, Tensor t)
    • iterate

      public static void iterate(BiConsumer<Integer,Integer> operation, Tensor t, Tensor u)
    • random

      public static Tensor random(int rows, int cols, Random rand)
    • matmul

      public Tensor matmul(Tensor other)
    • multiply

      public Tensor multiply(Tensor other)
    • multiplyBroadcastCols

      public Tensor multiplyBroadcastCols(Tensor colVector)
    • sum

      public double sum()
    • transpose

      public Tensor transpose()
    • multiplyScalar

      public Tensor multiplyScalar(double scalar)
    • addScalar

      public Tensor addScalar(double scalar)
    • divideScalar

      public Tensor divideScalar(double scalar)
    • softmaxRows

      public static Tensor softmaxRows(Tensor logits)
    • softmaxBackward

      public static Tensor softmaxBackward(Tensor upstreamGrad, Tensor softmaxOutput)
    • add

      public Tensor add(Tensor other)
    • addBroadcastCols

      public Tensor addBroadcastCols(Tensor colVector)
    • subtract

      public Tensor subtract(Tensor other)
    • mseLoss

      public double mseLoss(Tensor target)
    • unflattenToTensor

      public static Tensor unflattenToTensor(double[] flat, int rows, int cols)
    • flattenTensor

      public static double[] flattenTensor(Tensor t)
    • meanAlongRows

      public Tensor meanAlongRows()
    • varianceAlongRows

      public Tensor varianceAlongRows()
    • sumAlongRows

      public Tensor sumAlongRows()
    • sumAlongCols

      public Tensor sumAlongCols()
    • divideBroadcastCols

      public Tensor divideBroadcastCols(Tensor colVector)
    • subtractBroadcastCols

      public Tensor subtractBroadcastCols(Tensor colVector)
    • addBroadcastRows

      public Tensor addBroadcastRows(Tensor rowVector)
    • multiplyBroadcastRows

      public Tensor multiplyBroadcastRows(Tensor rowVector)
    • sqrt

      public Tensor sqrt()
    • pow

      public Tensor pow(double exponent)
    • ones

      public static Tensor ones(int rows, int cols)
    • zeros

      public static Tensor zeros(int rows, int cols)
    • causalMask

      public static Tensor causalMask(int size)
    • print

      public void print(String label)