/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.AutoCorrelatedBranchRatesDistribution;
import dr.evomodel.branchratemodel.AutoCorrelatedGradientWrtIncrements;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;

/**
 * Gradient provider that post-processes the gradient returned by a wrapped
 * {@link GradientWrtParameterProvider}.
 *
 * <p>The wrapped gradient is treated as a concatenation of equally sized slices, one
 * per supplied {@link AutoCorrelatedGradientWrtIncrements}. Each slice is pushed through
 * a post-order tree traversal that sums the (unit-transformed) contributions of a
 * node's subtree and rescales the running sum by the node's branch length.
 *
 * <p>NOTE(review): judging by the names involved, the wrapped gradient is presumably
 * taken with respect to branch rates and the result with respect to rate increments;
 * confirm against the callers.
 */
public class BranchRateGradientWrtIncrements
implements GradientWrtParameterProvider,
Reportable {
    private final GradientWrtParameterProvider rateGradientProvider;
    private final List<AutoCorrelatedGradientWrtIncrements> priorGradientProvider;
    private final List<DifferentiableBranchRates> branchRates = new ArrayList<>();
    private final Tree tree;
    private final List<AutoCorrelatedBranchRatesDistribution.BranchVarianceScaling> scaling = new ArrayList<>();
    private final List<AutoCorrelatedBranchRatesDistribution.BranchRateUnits> units = new ArrayList<>();
    /** Number of entries in the list handed to the constructor. */
    private final int listDim;
    /** Dimension of the first distribution's parameter; used as the slice length for every entry. */
    private final int paramDim;
    /** Compound of the parameters reported by each list entry, in list order. */
    private final Parameter parameter;

    /**
     * Builds the provider from a wrapped gradient and a non-empty list of prior-gradient wrappers.
     *
     * <p>The branch-rate model, scaling and units of every entry's distribution are recorded
     * per entry. The tree and the slice length are taken from the first entry only.
     *
     * @param gradientWrtParameterProvider provider whose gradient is transformed by this class
     * @param list                         prior-gradient wrappers, one per slice of the wrapped gradient
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws NullPointerException     if {@code list} is {@code null}
     */
    public BranchRateGradientWrtIncrements(GradientWrtParameterProvider gradientWrtParameterProvider, List<AutoCorrelatedGradientWrtIncrements> list) {
        // Fail early with a clear message; the tree and slice length come from the first entry.
        if (list.isEmpty()) {
            throw new IllegalArgumentException("BranchRateGradientWrtIncrements requires at least one AutoCorrelatedGradientWrtIncrements");
        }
        this.rateGradientProvider = gradientWrtParameterProvider;
        this.priorGradientProvider = list;
        this.listDim = list.size();

        AutoCorrelatedBranchRatesDistribution firstDistribution = null;
        CompoundParameter compoundParameter = new CompoundParameter(null);
        for (AutoCorrelatedGradientWrtIncrements priorGradient : list) {
            AutoCorrelatedBranchRatesDistribution distribution = priorGradient.getDistribution();
            if (firstDistribution == null) {
                firstDistribution = distribution;
            }
            this.branchRates.add(distribution.getBranchRateModel());
            this.scaling.add(distribution.getScaling());
            this.units.add(distribution.getUnits());
            compoundParameter.addParameter(priorGradient.getParameter());
        }
        this.parameter = compoundParameter;
        this.paramDim = firstDistribution.getParameter().getDimension();
        this.tree = firstDistribution.getTree();
    }

    /**
     * @return the likelihood of the wrapped gradient provider
     */
    @Override
    public Likelihood getLikelihood() {
        return this.rateGradientProvider.getLikelihood();
    }

    /**
     * @return the compound parameter assembled in the constructor
     */
    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    /**
     * @return the dimension of the compound parameter
     */
    @Override
    public int getDimension() {
        return this.parameter.getDimension();
    }

    /**
     * Fetches the wrapped gradient and transforms it slice by slice.
     *
     * <p>For entry {@code i}, the {@code paramDim} values starting at {@code i * paramDim}
     * are copied out, run through the post-order accumulation, and the result is written
     * to the same position of the returned array.
     *
     * @return a new array with the same length as the wrapped gradient
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] wrappedGradient = this.rateGradientProvider.getGradientLogDensity();
        double[] result = new double[wrappedGradient.length];
        double[] inputSlice = new double[this.paramDim];
        // Scratch buffer reused across entries without clearing. Presumably every index is
        // overwritten by each traversal (one per non-root node) -- TODO confirm.
        double[] outputSlice = new double[inputSlice.length];
        for (int model = 0; model < this.listDim; ++model) {
            int offset = model * this.paramDim;
            System.arraycopy(wrappedGradient, offset, inputSlice, 0, this.paramDim);
            this.recursePostOrderToAccumulateGradient(this.tree.getRoot(), inputSlice, outputSlice, model);
            System.arraycopy(outputSlice, 0, result, offset, this.paramDim);
        }
        return result;
    }

    /**
     * Post-order traversal that accumulates gradient contributions from the tips towards the root.
     *
     * <p>Internal nodes are read through children 0 and 1 only, i.e. the traversal assumes
     * exactly two children per internal node. The root contributes nothing of its own and
     * writes no output entry.
     *
     * @param nodeRef node currently visited
     * @param dArray  input slice, indexed by the branch-rate model's parameter index for a node
     * @param dArray2 output slice, written at the same parameter index
     * @param n       position in the per-entry lists (branch rates, scaling, units)
     * @return the accumulated sum for the subtree rooted at {@code nodeRef}, including the
     *         node's own contribution unless it is the root
     */
    private double recursePostOrderToAccumulateGradient(NodeRef nodeRef, double[] dArray, double[] dArray2, int n) {
        double subtreeSum = 0.0;
        if (!this.tree.isExternal(nodeRef)) {
            subtreeSum += this.recursePostOrderToAccumulateGradient(this.tree.getChild(nodeRef, 0), dArray, dArray2, n);
            subtreeSum += this.recursePostOrderToAccumulateGradient(this.tree.getChild(nodeRef, 1), dArray, dArray2, n);
        }
        if (!this.tree.isRoot(nodeRef)) {
            DifferentiableBranchRates rates = this.branchRates.get(n);
            int parameterIndex = rates.getParameterIndexFromNode(nodeRef);
            // Add this node's own contribution before rescaling, so that the returned value
            // (and hence the parent's sum) includes it.
            subtreeSum += this.units.get(n).inverseTransformGradient(dArray[parameterIndex], rates.getUntransformedBranchRate(this.tree, nodeRef));
            dArray2[parameterIndex] = this.scaling.get(n).inverseRescaleIncrement(subtreeSum, this.tree.getBranchLength(nodeRef));
        }
        return subtreeSum;
    }

    /**
     * Delegates to {@link GradientWrtParameterProvider#getReportAndCheckForError}, passing
     * negative and positive infinity and a {@code null} final argument.
     *
     * @return the report text produced by that helper
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null);
    }
}

