package aima.core.learning.neural;

import aima.core.util.Util;
import aima.core.util.math.Matrix;
import aima.core.util.math.Vector;

/* loaded from: input_file:lib/aima-core-3.0.0.jar:aima/core/learning/neural/Layer.class */
/**
 * A single fully-connected layer of a feed-forward neural network.
 *
 * <p>The layer holds a weight matrix (one row per neuron, one column per input), a bias vector
 * (one entry per neuron) and an activation function. Besides the parameters themselves it
 * remembers the values produced by the most recent {@link #feedForward(Vector)} call and the two
 * most recent weight/bias updates, so that a training algorithm can consult them.
 */
public class Layer {
    private final Matrix weightMatrix;
    // Package-private on purpose: other classes in this package may access these directly.
    Vector biasVector;
    Vector lastBiasUpdateVector;
    private final ActivationFunction activationFunction;
    private Vector lastActivationValues;
    private Vector lastInducedField;
    private Matrix lastWeightUpdateMatrix;
    private Matrix penultimateWeightUpdateMatrix;
    private Vector penultimateBiasUpdateVector;
    private Vector lastInput;

    /**
     * Creates a layer from existing weights and biases.
     *
     * <p>The given matrix and vector are stored by reference, not copied, so later weight/bias
     * updates are visible through them.
     *
     * @param weightMatrix weights, one row per neuron and one column per input
     * @param biasVector biases, one entry per neuron
     * @param activationFunction function applied to each neuron's induced field
     * @throws IllegalArgumentException if the bias vector length differs from the number of rows
     *     of the weight matrix (such a layer could never be fed forward)
     */
    public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction activationFunction) {
        int rows = weightMatrix.getRowDimension();
        int columns = weightMatrix.getColumnDimension();
        if (biasVector.getRowDimension() != rows) {
            throw new IllegalArgumentException(String.format(
                    "bias vector length (%d) must equal the number of weight-matrix rows (%d)",
                    biasVector.getRowDimension(), rows));
        }
        this.activationFunction = activationFunction;
        this.weightMatrix = weightMatrix;
        this.lastWeightUpdateMatrix = new Matrix(rows, columns);
        this.penultimateWeightUpdateMatrix = new Matrix(rows, columns);
        this.biasVector = biasVector;
        this.lastBiasUpdateVector = new Vector(rows);
        this.penultimateBiasUpdateVector = new Vector(rows);
    }

    /**
     * Creates a layer whose weights and biases are filled with random values.
     *
     * @param numberOfNeurons number of neurons (rows of the weight matrix, length of the biases)
     * @param numberOfInputs number of inputs (columns of the weight matrix)
     * @param lowerLimit first bound passed to {@code Util.generateRandomDoubleBetween}
     * @param upperLimit second bound passed to {@code Util.generateRandomDoubleBetween}
     * @param activationFunction function applied to each neuron's induced field
     */
    public Layer(int numberOfNeurons, int numberOfInputs, double lowerLimit, double upperLimit,
            ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
        this.weightMatrix = new Matrix(numberOfNeurons, numberOfInputs);
        this.lastWeightUpdateMatrix = new Matrix(this.weightMatrix.getRowDimension(),
                this.weightMatrix.getColumnDimension());
        this.penultimateWeightUpdateMatrix = new Matrix(this.weightMatrix.getRowDimension(),
                this.weightMatrix.getColumnDimension());
        this.biasVector = new Vector(numberOfNeurons);
        this.lastBiasUpdateVector = new Vector(this.biasVector.getRowDimension());
        this.penultimateBiasUpdateVector = new Vector(this.biasVector.getRowDimension());
        initializeMatrix(this.weightMatrix, lowerLimit, upperLimit);
        initializeVector(this.biasVector, lowerLimit, upperLimit);
    }

    /**
     * Propagates an input through this layer.
     *
     * <p>Computes the induced field {@code W * input + b} and applies the activation function to
     * each entry. The input, the induced field and (a copy of) the activations are remembered
     * for later retrieval.
     *
     * @param inputVector input values, one per column of the weight matrix
     * @return a new vector holding the activation value of each neuron
     */
    public Vector feedForward(Vector inputVector) {
        this.lastInput = inputVector;
        Matrix fieldMatrix = this.weightMatrix.times(inputVector).plus(this.biasVector);
        int neurons = numberOfNeurons();
        Vector inducedField = new Vector(neurons);
        Vector activations = new Vector(neurons);
        // Single pass: copy the induced field out of the column matrix and activate it.
        for (int i = 0; i < neurons; i++) {
            double field = fieldMatrix.get(i, 0);
            inducedField.setValue(i, field);
            activations.setValue(i, this.activationFunction.activation(field));
        }
        // inducedField never escapes this method, so no defensive copy is needed.
        this.lastInducedField = inducedField;
        // The returned vector is handed to the caller; keep an independent copy internally.
        this.lastActivationValues = activations.copyVector();
        return activations;
    }

    /** Returns the (live, mutable) weight matrix of this layer. */
    public Matrix getWeightMatrix() {
        return this.weightMatrix;
    }

    /** Returns the (live, mutable) bias vector of this layer. */
    public Vector getBiasVector() {
        return this.biasVector;
    }

    /** Returns the number of neurons, i.e. the number of rows of the weight matrix. */
    public int numberOfNeurons() {
        return this.weightMatrix.getRowDimension();
    }

    /** Returns the number of inputs, i.e. the number of columns of the weight matrix. */
    public int numberOfInputs() {
        return this.weightMatrix.getColumnDimension();
    }

    /** Returns the activations from the last {@link #feedForward(Vector)}, or null if none yet. */
    public Vector getLastActivationValues() {
        return this.lastActivationValues;
    }

    /** Returns the induced field from the last {@link #feedForward(Vector)}, or null if none yet. */
    public Vector getLastInducedField() {
        return this.lastInducedField;
    }

    /** Returns the most recently accepted weight update. */
    public Matrix getLastWeightUpdateMatrix() {
        return this.lastWeightUpdateMatrix;
    }

    /** Replaces the most recent weight update. */
    public void setLastWeightUpdateMatrix(Matrix matrix) {
        this.lastWeightUpdateMatrix = matrix;
    }

    /** Returns the weight update accepted before the most recent one. */
    public Matrix getPenultimateWeightUpdateMatrix() {
        return this.penultimateWeightUpdateMatrix;
    }

    /** Replaces the weight update accepted before the most recent one. */
    public void setPenultimateWeightUpdateMatrix(Matrix matrix) {
        this.penultimateWeightUpdateMatrix = matrix;
    }

    /** Returns the most recently accepted bias update. */
    public Vector getLastBiasUpdateVector() {
        return this.lastBiasUpdateVector;
    }

    /** Replaces the most recent bias update. */
    public void setLastBiasUpdateVector(Vector vector) {
        this.lastBiasUpdateVector = vector;
    }

    /** Returns the bias update accepted before the most recent one. */
    public Vector getPenultimateBiasUpdateVector() {
        return this.penultimateBiasUpdateVector;
    }

    /** Replaces the bias update accepted before the most recent one. */
    public void setPenultimateBiasUpdateVector(Vector vector) {
        this.penultimateBiasUpdateVector = vector;
    }

    /** Adds the most recent weight update to the weight matrix, in place. */
    public void updateWeights() {
        this.weightMatrix.plusEquals(this.lastWeightUpdateMatrix);
    }

    /**
     * Adds the most recent bias update to the bias vector, in place.
     *
     * <p>{@code plusEquals} already mutates {@code biasVector}; the vector object is kept (as
     * {@link #updateWeights()} does for the weights) so references obtained earlier through
     * {@link #getBiasVector()} stay in sync across updates.
     */
    public void updateBiases() {
        this.biasVector.plusEquals(this.lastBiasUpdateVector);
    }

    /** Returns the input given to the last {@link #feedForward(Vector)}, or null if none yet. */
    public Vector getLastInputValues() {
        return this.lastInput;
    }

    /** Returns the activation function of this layer. */
    public ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    /**
     * Records a new weight update: the current last update becomes the penultimate one and the
     * given matrix becomes the last one. The weights themselves are not changed.
     */
    public void acceptNewWeightUpdate(Matrix matrix) {
        setPenultimateWeightUpdateMatrix(getLastWeightUpdateMatrix());
        setLastWeightUpdateMatrix(matrix);
    }

    /**
     * Records a new bias update: the current last update becomes the penultimate one and the
     * given vector becomes the last one. The biases themselves are not changed.
     */
    public void acceptNewBiasUpdate(Vector vector) {
        setPenultimateBiasUpdateVector(getLastBiasUpdateVector());
        setLastBiasUpdateVector(vector);
    }

    /**
     * Returns {@code target - lastActivationValues}.
     *
     * @param vector target values, one per neuron
     * @return the difference between the target and the last activations
     */
    public Vector errorVectorFrom(Vector vector) {
        return vector.minus(getLastActivationValues());
    }

    /** Fills every entry of the matrix with a random value drawn between the given bounds. */
    private static void initializeMatrix(Matrix matrix, double lowerLimit, double upperLimit) {
        for (int row = 0; row < matrix.getRowDimension(); row++) {
            for (int column = 0; column < matrix.getColumnDimension(); column++) {
                matrix.set(row, column, Util.generateRandomDoubleBetween(lowerLimit, upperLimit));
            }
        }
    }

    /** Fills every entry of the vector with a random value drawn between the given bounds. */
    private static void initializeVector(Vector vector, double lowerLimit, double upperLimit) {
        for (int i = 0; i < vector.size(); i++) {
            vector.setValue(i, Util.generateRandomDoubleBetween(lowerLimit, upperLimit));
        }
    }
}
