/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood;

import dr.evomodel.treedatalikelihood.discrete.MaximizerWrtParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.hmc.JointGradient;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedMultivariateParameter;
import dr.inference.model.Variable;
import dr.inference.operators.hmc.NumericalHessianFromGradient;
import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Laplace approximation to the log marginal likelihood of a {@link Likelihood} with respect to
 * the parameter handled by a {@link MaximizerWrtParameter}.
 *
 * <p>Each (re)evaluation runs the maximizer and then approximates the log marginal likelihood as
 * <pre>
 *   log L(theta*) + (d / 2) log(2 pi) - (1 / 2) sum_i log|H_ii|  [+ logJacobian, if transformed]
 * </pre>
 * Here {@code d} is the parameter dimension and {@code H_ii} are the diagonal entries of the
 * Hessian of the log density. The determinant of the (negative) Hessian is approximated by the
 * product of its diagonal.
 *
 * <p>If the maximizer carries a transform, the Hessian is obtained numerically from the gradient
 * expressed with respect to the transformed parameter. Otherwise an analytic Hessian provider is
 * used when available, falling back to a numerical Hessian from the gradient.
 */
public class ApproximateTreeDataLikelihood
extends AbstractModelLikelihood {
    /** Cached log marginal likelihood; only meaningful while {@link #likelihoodKnown} is true. */
    private double marginalLikelihood;
    /** Optimiser whose likelihood, gradient and (optional) transform drive the approximation. */
    private final MaximizerWrtParameter maximizer;
    /** Parameter of the maximizer's gradient; the approximation integrates over it. */
    private final Parameter parameter;
    /** Likelihood evaluated at the maximizer's result. */
    private final Likelihood likelihood;
    /** False whenever a watched model/variable changed or state was restored. */
    private boolean likelihoodKnown = false;
    /** Source of the diagonal Hessian used for the log-determinant term. */
    private final HessianWrtParameterProvider hessianWrtParameterProvider;
    public static final String APPROXIMATE_LIKELIHOOD = "approximateTreeDataLikelihood";
    /** (d / 2) * log(2 * pi): the Gaussian normalising constant of the Laplace approximation. */
    private final double marginalLikelihoodConst;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(MaximizerWrtParameter.class)};

        @Override
        public String getParserName() {
            return ApproximateTreeDataLikelihood.APPROXIMATE_LIKELIHOOD;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            MaximizerWrtParameter maximizer = (MaximizerWrtParameter) xo.getChild(MaximizerWrtParameter.class);
            return new ApproximateTreeDataLikelihood(maximizer);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Approximates the marginal likelihood of the data given the tree using Laplace approximation";
        }

        @Override
        public Class getReturnType() {
            return ApproximateTreeDataLikelihood.class;
        }
    };

    /**
     * Builds the approximation and evaluates it once.
     *
     * <p>Construction runs the maximizer, computes the marginal likelihood, and registers the
     * parameter and the maximizer likelihood's model for change notifications.
     *
     * @param maximizerWrtParameter optimiser supplying the likelihood, gradient and optional transform
     */
    public ApproximateTreeDataLikelihood(MaximizerWrtParameter maximizerWrtParameter) {
        super(APPROXIMATE_LIKELIHOOD);
        this.maximizer = maximizerWrtParameter;
        this.likelihood = maximizerWrtParameter.getLikelihood();
        GradientWrtParameterProvider gradient = maximizerWrtParameter.getGradient();
        this.parameter = gradient.getParameter();
        this.marginalLikelihoodConst = (double) this.parameter.getDimension() / 2.0 * Math.log(Math.PI * 2);
        this.hessianWrtParameterProvider = this.selectHessianProvider(gradient);
        this.updateParameterMAP();
        this.updateMarginalLikelihood();
        this.addVariable(this.parameter);
        this.addModel(maximizerWrtParameter.getLikelihood().getModel());
    }

    /**
     * Chooses where the Hessian comes from.
     *
     * <p>If the maximizer has a transform, the Hessian is computed numerically on the transformed
     * scale. Otherwise the gradient itself is used when it can supply a Hessian, and a numerical
     * Hessian from the gradient is used as the fallback.
     */
    private HessianWrtParameterProvider selectHessianProvider(GradientWrtParameterProvider gradient) {
        if (this.maximizer.getTransform() != null) {
            return this.constructHessian();
        }
        if (this.isGradientProvidingHessian(gradient)) {
            return (HessianWrtParameterProvider) gradient;
        }
        return new NumericalHessianFromGradient(gradient);
    }

    /**
     * Returns true if {@code gradient} can be used directly as a Hessian provider.
     *
     * <p>This holds when it implements {@link HessianWrtParameterProvider}. For a
     * {@link JointGradient}, every component gradient must implement it as well.
     */
    private boolean isGradientProvidingHessian(GradientWrtParameterProvider gradient) {
        if (!(gradient instanceof HessianWrtParameterProvider)) {
            return false;
        }
        if (gradient instanceof JointGradient) {
            for (GradientWrtParameterProvider component : ((JointGradient) gradient).getDerivativeList()) {
                if (!(component instanceof HessianWrtParameterProvider)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Builds a numerical Hessian on the transformed scale.
     *
     * <p>The maximizer's gradient is re-expressed with respect to a
     * {@link TransformedMultivariateParameter} wrapping {@link #parameter}. The Hessian is then
     * computed numerically from that gradient.
     */
    private HessianWrtParameterProvider constructHessian() {
        GradientWrtParameterProvider transformedGradient = new GradientWrtParameterProvider(){
            private final TransformedMultivariateParameter transformedParameter =
                    new TransformedMultivariateParameter(ApproximateTreeDataLikelihood.this.parameter,
                            (Transform.MultivariableTransform) ApproximateTreeDataLikelihood.this.maximizer.getTransform());

            @Override
            public Likelihood getLikelihood() {
                throw new RuntimeException("should not be called");
            }

            @Override
            public Parameter getParameter() {
                return this.transformedParameter;
            }

            @Override
            public int getDimension() {
                return this.transformedParameter.getDimension();
            }

            @Override
            public double[] getGradientLogDensity() {
                // Gradient on the original scale, mapped through the transform using the
                // current (untransformed) parameter values.
                double[] gradientOriginalScale = ApproximateTreeDataLikelihood.this.maximizer.getGradient().getGradientLogDensity();
                return ApproximateTreeDataLikelihood.this.maximizer.getTransform().updateGradientLogDensity(
                        gradientOriginalScale,
                        ApproximateTreeDataLikelihood.this.parameter.getParameterValues(),
                        0,
                        ApproximateTreeDataLikelihood.this.parameter.getDimension());
            }
        };
        return new NumericalHessianFromGradient(transformedGradient);
    }

    /**
     * Recomputes {@link #marginalLikelihood} at the current parameter values.
     *
     * <p>Laplace's method gives
     * {@code log L + (d/2) log(2 pi) - (1/2) log|-H|}. Here {@code log|-H|} is approximated by the
     * sum of {@code log|H_ii|} over the Hessian diagonal. When a transform is present its log
     * Jacobian is added.
     */
    private void updateMarginalLikelihood() {
        double[] diagonalHessian = this.hessianWrtParameterProvider.getDiagonalHessianLogDensity();
        int dimension = this.parameter.getDimension();
        double logDeterminant = 0.0;
        for (int i = 0; i < dimension; ++i) {
            logDeterminant += Math.log(Math.abs(diagonalHessian[i]));
        }
        double logLikelihoodAtMode = this.likelihood.getLogLikelihood();
        double logJacobian = this.maximizer.getTransform() == null
                ? 0.0
                : this.maximizer.getTransform().logJacobian(this.parameter.getParameterValues(), 0, dimension);
        // Minus sign: the Laplace approximation divides by sqrt(|-H|), i.e. subtracts half its log.
        this.marginalLikelihood = this.marginalLikelihoodConst - 0.5 * logDeterminant + logLikelihoodAtMode + logJacobian;
        this.likelihoodKnown = true;
    }

    /** Runs the maximizer; presumably this leaves {@link #parameter} at its mode (TODO confirm). */
    private void updateParameterMAP() {
        this.maximizer.maximize();
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        this.likelihoodKnown = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        this.likelihoodKnown = false;
    }

    @Override
    protected void storeState() {
        // Nothing cached to store: the value is recomputed after a restore.
    }

    @Override
    protected void restoreState() {
        this.likelihoodKnown = false;
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Returns {@code null}.
     *
     * <p>NOTE(review): most model likelihoods return {@code this} here. Returning {@code null}
     * means enclosing likelihoods will not register this model. Confirm this is intentional.
     */
    @Override
    public Model getModel() {
        return null;
    }

    /**
     * Returns the Laplace-approximated log marginal likelihood.
     *
     * <p>If a watched model or variable changed since the last evaluation, or state was restored,
     * the maximizer is re-run and the value recomputed first.
     */
    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.updateParameterMAP();
            this.updateMarginalLikelihood();
        }
        return this.marginalLikelihood;
    }

    /** Forces recomputation on the next {@link #getLogLikelihood()} call. */
    @Override
    public void makeDirty() {
        this.likelihoodKnown = false;
    }
}

