/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator$class;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.TraitSetter;

@ScalaSignature(bytes="\u0006\u0001\u00054Q!\u0001\u0002\u0001\r9\u0011q\u0002S;cKJ\fum\u001a:fO\u0006$xN\u001d\u0006\u0003\u0007\u0011\t!\"Y4he\u0016<\u0017\r^8s\u0015\t)a!A\u0003paRLWN\u0003\u0002\b\u0011\u0005\u0011Q\u000e\u001c\u0006\u0003\u0013)\tQa\u001d9be.T!a\u0003\u0007\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005i\u0011aA8sON\u0019\u0001aD\u000b\u0011\u0005A\u0019R\"A\t\u000b\u0003I\tQa]2bY\u0006L!\u0001F\t\u0003\r\u0005s\u0017PU3g!\u00111r#G\u0010\u000e\u0003\tI!\u0001\u0007\u0002\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011!$H\u0007\u00027)\u0011ADB\u0001\bM\u0016\fG/\u001e:f\u0013\tq2D\u0001\u0005J]N$\u0018M\\2f!\t1\u0002\u0001\u0003\u0005\"\u0001\t\u0005\t\u0015!\u0003$\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u\u0007\u0001\u0001\"\u0001\u0005\u0013\n\u0005\u0015\n\"a\u0002\"p_2,\u0017M\u001c\u0005\tO\u0001\u0011\t\u0011)A\u0005Q\u00059Q\r]:jY>t\u0007C\u0001\t*\u0013\tQ\u0013C\u0001\u0004E_V\u0014G.\u001a\u0005\tY\u0001\u0011\t\u0011)A\u0005[\u0005i!m\u0019$fCR,(/Z:Ti\u0012\u00042AL\u00194\u001b\u0005y#B\u0001\u0019\t\u0003%\u0011'o\\1eG\u0006\u001cH/\u0003\u00023_\tI!I]8bI\u000e\f7\u000f\u001e\t\u0004!QB\u0013BA\u001b\u0012\u0005\u0015\t%O]1z\u0011!9\u0004A!A!\u0002\u0013A\u0014\u0001\u00042d!\u0006\u0014\u0018-\\3uKJ\u001c\bc\u0001\u00182sA\u0011!(P\u0007\u0002w)\u0011AHB\u0001\u0007Y&t\u0017\r\\4\n\u0005yZ$A\u0002,fGR|'\u000fC\u0003A\u0001\u0011\u0005\u0011)\u0001\u0004=S:LGO\u0010\u000b\u0005\u0005\u0012+e\t\u0006\u0002 
\u0007\")qg\u0010a\u0001q!)\u0011e\u0010a\u0001G!)qe\u0010a\u0001Q!)Af\u0010a\u0001[!9\u0001\n\u0001b\u0001\n#J\u0015a\u00013j[V\t!\n\u0005\u0002\u0011\u0017&\u0011A*\u0005\u0002\u0004\u0013:$\bB\u0002(\u0001A\u0003%!*\u0001\u0003eS6\u0004\u0003b\u0002)\u0001\u0005\u0004%I!S\u0001\f]Vlg)Z1ukJ,7\u000f\u0003\u0004S\u0001\u0001\u0006IAS\u0001\r]Vlg)Z1ukJ,7\u000f\t\u0005\b)\u0002\u0011\r\u0011\"\u0003V\u0003\u0015\u0019\u0018nZ7b+\u0005A\u0003BB,\u0001A\u0003%\u0001&\u0001\u0004tS\u001el\u0017\r\t\u0005\b3\u0002\u0011\r\u0011\"\u0003V\u0003%Ig\u000e^3sG\u0016\u0004H\u000f\u0003\u0004\\\u0001\u0001\u0006I\u0001K\u0001\u000bS:$XM]2faR\u0004\u0003\"B/\u0001\t\u0003q\u0016aA1eIR\u0011qd\u0018\u0005\u0006Ar\u0003\r!G\u0001\tS:\u001cH/\u00198dK\u0002")
public class HuberAggregator
implements DifferentiableLossAggregator<Instance, HuberAggregator> {
    private final boolean fitIntercept;
    public final double org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon;
    private final Broadcast<double[]> bcFeaturesStd;
    private final Broadcast<Vector> bcParameters;
    private final int dim;
    private final int org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures;
    private final double sigma;
    private final double intercept;
    private double weightSum;
    private double lossSum;
    private final double[] gradientSumArray;
    private volatile boolean bitmap$0;

    /**
     * Returns the running sum of instance weights; {@code add} increases it by the weight of
     * each instance it aggregates.
     */
    @Override
    public double weightSum() {
        return weightSum;
    }

    /**
     * Scala-generated setter: overwrites the running weight sum.
     *
     * @param x$1 the new value of the weight sum
     */
    @Override
    @TraitSetter
    public void weightSum_$eq(double x$1) {
        weightSum = x$1;
    }

    /**
     * Returns the running weighted loss total; {@code add} increases it by each aggregated
     * instance's weighted loss term.
     */
    @Override
    public double lossSum() {
        return lossSum;
    }

    /**
     * Scala-generated setter: overwrites the running loss total.
     *
     * @param x$1 the new value of the loss sum
     */
    @Override
    @TraitSetter
    public void lossSum_$eq(double x$1) {
        lossSum = x$1;
    }

    /**
     * Slow path of the lazily initialised gradient buffer.
     *
     * <p>While holding this instance's monitor, the buffer is obtained from
     * {@code DifferentiableLossAggregator$class.gradientSumArray(this)} unless {@code bitmap$0}
     * shows that it has already been set; the flag is raised only after the field is assigned.
     *
     * @return the gradient sum array
     */
    private double[] gradientSumArray$lzycompute() {
        synchronized (this) {
            if (this.bitmap$0) {
                // Already initialised (possibly by a thread that held the lock before us).
                return this.gradientSumArray;
            }
            this.gradientSumArray = DifferentiableLossAggregator$class.gradientSumArray(this);
            this.bitmap$0 = true;
            return this.gradientSumArray;
        }
    }

    /**
     * Returns the gradient sum array, creating it on first access.
     *
     * <p>When the {@code bitmap$0} flag is already set the cached field is returned without
     * locking; otherwise initialisation is delegated to {@code gradientSumArray$lzycompute()}.
     */
    @Override
    public double[] gradientSumArray() {
        if (!this.bitmap$0) {
            return this.gradientSumArray$lzycompute();
        }
        return this.gradientSumArray;
    }

    /**
     * Merges another aggregator into this one by forwarding to the trait implementation
     * {@code DifferentiableLossAggregator$class.merge}.
     *
     * @param other the aggregator to merge with this one
     * @return whatever the trait implementation returns for this pair
     */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        DifferentiableLossAggregator merged = DifferentiableLossAggregator$class.merge(this, other);
        return merged;
    }

    /**
     * Returns the gradient as computed by the trait implementation
     * {@code DifferentiableLossAggregator$class.gradient}.
     */
    @Override
    public Vector gradient() {
        Vector result = DifferentiableLossAggregator$class.gradient(this);
        return result;
    }

    /**
     * Returns the weight reported by the trait implementation
     * {@code DifferentiableLossAggregator$class.weight}.
     */
    @Override
    public double weight() {
        double result = DifferentiableLossAggregator$class.weight(this);
        return result;
    }

    /**
     * Returns the loss reported by the trait implementation
     * {@code DifferentiableLossAggregator$class.loss}.
     */
    @Override
    public double loss() {
        double result = DifferentiableLossAggregator$class.loss(this);
        return result;
    }

    /**
     * Returns the size of the broadcast parameter vector, as captured by the constructor.
     */
    @Override
    public int dim() {
        return dim;
    }

    /**
     * Returns the number of feature coefficients: {@code dim - 2} when an intercept is fitted,
     * {@code dim - 1} otherwise (set by the constructor). The name is mangled and the method is
     * public because the anonymous closure classes inside {@code add} call it.
     */
    public int org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures() {
        return org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures;
    }

    /**
     * Returns sigma, which the constructor reads from the last slot of the parameter vector.
     */
    private double sigma() {
        return sigma;
    }

    /**
     * Returns the intercept: slot {@code dim - 2} of the parameter vector when an intercept is
     * fitted, {@code 0.0} otherwise (set by the constructor).
     */
    private double intercept() {
        return intercept;
    }

    /**
     * Folds one training instance into the running loss, gradient and weight sums.
     *
     * <p>The margin is the sum over active features of
     * {@code coefficient[i] * (value / featuresStd[i])}, skipping entries whose value or
     * standard deviation is {@code 0.0}, plus the intercept when {@code fitIntercept} is set.
     * With {@code linearLoss = label - margin}:
     * <ul>
     *   <li>if {@code |linearLoss| <= sigma * epsilon}, the loss sum grows by
     *       {@code 0.5 * weight * (sigma + linearLoss^2 / sigma)};</li>
     *   <li>otherwise it grows by
     *       {@code 0.5 * weight * (sigma + 2 * epsilon * |linearLoss| - sigma * epsilon^2)}.</li>
     * </ul>
     * Gradient contributions are written into {@code gradientSumArray()}: slots
     * {@code [0, numFeatures)} for the coefficients, slot {@code dim - 2} for the intercept
     * (only when {@code fitIntercept}), and slot {@code dim - 1} for sigma.
     *
     * <p>An instance with weight exactly {@code 0.0} is ignored and leaves all sums unchanged.
     *
     * @param instance the labelled, weighted feature vector to aggregate
     * @return this aggregator
     * @throws IllegalArgumentException raised by {@code Predef.require} when the feature vector
     *         size differs from the number of features or the weight is negative
     * @throws MatchError if {@code instance} is {@code null}
     */
    @Override
    public HuberAggregator add(Instance instance) {
        Instance instance2 = instance;
        if (instance2 != null) {
            double margin;
            double linearLoss;
            double label = instance2.label();
            double weight = instance2.weight();
            Vector features = instance2.features();
            Predef$.MODULE$.require(this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures() == features.size(), (Function0)new Serializable(this, features){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ HuberAggregator $outer;
                private final Vector features$1;

                public final String apply() {
                    return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when adding new sample."})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures()), BoxesRunTime.boxToInteger((int)this.features$1.size())}))).toString();
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                    this.features$1 = features$1;
                }
            });
            Predef$.MODULE$.require(weight >= 0.0, (Function0)new Serializable(this, weight){
                public static final long serialVersionUID = 0L;
                private final double weight$1;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"instance weight, ", " has to be >= 0.0"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.weight$1)}));
                }
                {
                    this.weight$1 = weight$1;
                }
            });
            // Zero-weight instances contribute nothing; skip all the work below.
            if (weight == 0.0) {
                return this;
            }
            double[] localFeaturesStd = (double[])this.bcFeaturesStd.value();
            // The parameter array is used as-is rather than slicing off a fresh copy of the first
            // numFeatures entries for every instance. Only indices handed out by
            // features.foreachActive are read, and the size check above pins features.size() to
            // numFeatures, so the trailing intercept/sigma slots are not touched here.
            double[] localCoefficients = ((Vector)this.bcParameters.value()).toArray();
            double[] localGradientSumArray = this.gradientSumArray();
            // Accumulates the margin; boxed in a DoubleRef so the closure below can mutate it.
            DoubleRef sum = DoubleRef.create((double)0.0);
            features.foreachActive((Function2)new Serializable(this, localFeaturesStd, localCoefficients, sum){
                public static final long serialVersionUID = 0L;
                private final double[] localFeaturesStd$1;
                private final double[] localCoefficients$1;
                private final DoubleRef sum$1;

                public final void apply(int index2, double value) {
                    this.apply$mcVID$sp(index2, value);
                }

                public void apply$mcVID$sp(int index2, double value) {
                    if (this.localFeaturesStd$1[index2] != 0.0 && value != 0.0) {
                        this.sum$1.elem += this.localCoefficients$1[index2] * (value / this.localFeaturesStd$1[index2]);
                    }
                }
                {
                    this.localFeaturesStd$1 = localFeaturesStd$1;
                    this.localCoefficients$1 = localCoefficients$1;
                    this.sum$1 = sum$1;
                }
            });
            if (this.fitIntercept) {
                sum.elem += this.intercept();
            }
            if (package$.MODULE$.abs(linearLoss = label - (margin = sum.elem)) <= this.sigma() * this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon) {
                // Quadratic branch: residual is within sigma * epsilon.
                this.lossSum_$eq(this.lossSum() + 0.5 * weight * (this.sigma() + package$.MODULE$.pow(linearLoss, 2.0) / this.sigma()));
                double linearLossDivSigma = linearLoss / this.sigma();
                features.foreachActive((Function2)new Serializable(this, weight, localFeaturesStd, localGradientSumArray, linearLossDivSigma){
                    public static final long serialVersionUID = 0L;
                    private final double weight$1;
                    private final double[] localFeaturesStd$1;
                    private final double[] localGradientSumArray$1;
                    private final double linearLossDivSigma$1;

                    public final void apply(int index2, double value) {
                        this.apply$mcVID$sp(index2, value);
                    }

                    public void apply$mcVID$sp(int index2, double value) {
                        if (this.localFeaturesStd$1[index2] != 0.0 && value != 0.0) {
                            this.localGradientSumArray$1[index2] = this.localGradientSumArray$1[index2] + -1.0 * this.weight$1 * this.linearLossDivSigma$1 * (value / this.localFeaturesStd$1[index2]);
                        }
                    }
                    {
                        this.weight$1 = weight$1;
                        this.localFeaturesStd$1 = localFeaturesStd$1;
                        this.localGradientSumArray$1 = localGradientSumArray$1;
                        this.linearLossDivSigma$1 = linearLossDivSigma$1;
                    }
                });
                if (this.fitIntercept) {
                    // Intercept gradient slot.
                    int n = this.dim() - 2;
                    localGradientSumArray[n] = localGradientSumArray[n] + -1.0 * weight * linearLossDivSigma;
                }
                // Sigma gradient slot.
                int n = this.dim() - 1;
                localGradientSumArray[n] = localGradientSumArray[n] + 0.5 * weight * (1.0 - package$.MODULE$.pow(linearLossDivSigma, 2.0));
            } else {
                // Linear branch: residual exceeds sigma * epsilon.
                double sign = linearLoss >= 0.0 ? -1.0 : 1.0;
                this.lossSum_$eq(this.lossSum() + 0.5 * weight * (this.sigma() + 2.0 * this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon * package$.MODULE$.abs(linearLoss) - this.sigma() * this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon * this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon));
                features.foreachActive((Function2)new Serializable(this, weight, localFeaturesStd, localGradientSumArray, sign){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ HuberAggregator $outer;
                    private final double weight$1;
                    private final double[] localFeaturesStd$1;
                    private final double[] localGradientSumArray$1;
                    private final double sign$1;

                    public final void apply(int index2, double value) {
                        this.apply$mcVID$sp(index2, value);
                    }

                    public void apply$mcVID$sp(int index2, double value) {
                        if (this.localFeaturesStd$1[index2] != 0.0 && value != 0.0) {
                            this.localGradientSumArray$1[index2] = this.localGradientSumArray$1[index2] + this.weight$1 * this.sign$1 * this.$outer.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon * (value / this.localFeaturesStd$1[index2]);
                        }
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                        this.weight$1 = weight$1;
                        this.localFeaturesStd$1 = localFeaturesStd$1;
                        this.localGradientSumArray$1 = localGradientSumArray$1;
                        this.sign$1 = sign$1;
                    }
                });
                if (this.fitIntercept) {
                    // Intercept gradient slot.
                    int n = this.dim() - 2;
                    localGradientSumArray[n] = localGradientSumArray[n] + weight * sign * this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon;
                }
                // Sigma gradient slot.
                int n = this.dim() - 1;
                localGradientSumArray[n] = localGradientSumArray[n] + 0.5 * weight * (1.0 - this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon * this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon);
            }
            this.weightSum_$eq(this.weightSum() + weight);
            return this;
        }
        // Residue of the Scala pattern match on the instance: a null instance matches no case.
        throw new MatchError((Object)instance2);
    }

    /**
     * Creates an aggregator bound to the given broadcast parameters.
     *
     * <p>The parameter vector is read as: feature coefficients first, then the intercept in
     * slot {@code dim - 2} (only when {@code fitIntercept} is true), then sigma in the last
     * slot {@code dim - 1}.
     *
     * @param fitIntercept whether the parameter vector carries an intercept term
     * @param epsilon multiplier of sigma that sets the threshold between the two loss branches
     *        in {@code add}
     * @param bcFeaturesStd broadcast per-feature standard deviations used to scale features
     * @param bcParameters broadcast parameter vector laid out as described above
     */
    public HuberAggregator(boolean fitIntercept, double epsilon, Broadcast<double[]> bcFeaturesStd, Broadcast<Vector> bcParameters) {
        this.fitIntercept = fitIntercept;
        this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$epsilon = epsilon;
        this.bcFeaturesStd = bcFeaturesStd;
        this.bcParameters = bcParameters;
        // Trait initialiser runs before the fields derived from the parameter vector are set.
        DifferentiableLossAggregator$class.$init$(this);
        this.dim = ((Vector)bcParameters.value()).size();
        if (fitIntercept) {
            this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures = this.dim() - 2;
        } else {
            this.org$apache$spark$ml$optim$aggregator$HuberAggregator$$numFeatures = this.dim() - 1;
        }
        this.sigma = ((Vector)bcParameters.value()).apply(this.dim() - 1);
        if (fitIntercept) {
            this.intercept = ((Vector)bcParameters.value()).apply(this.dim() - 2);
        } else {
            this.intercept = 0.0;
        }
    }
}

