/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.sql;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaUserDefinedUntypedAggregation {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("Java Spark SQL user-defined DataFrames aggregation example").getOrCreate();
        spark.udf().register("myAverage", (UserDefinedAggregateFunction)new MyAverage());
        Dataset df = spark.read().json("examples/src/main/resources/employees.json");
        df.createOrReplaceTempView("employees");
        df.show();
        Dataset result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
        result.show();
        spark.stop();
    }

    public static class MyAverage
    extends UserDefinedAggregateFunction {
        private StructType inputSchema;
        private StructType bufferSchema;

        public MyAverage() {
            ArrayList<StructField> inputFields = new ArrayList<StructField>();
            inputFields.add(DataTypes.createStructField((String)"inputColumn", (DataType)DataTypes.LongType, (boolean)true));
            this.inputSchema = DataTypes.createStructType(inputFields);
            ArrayList<StructField> bufferFields = new ArrayList<StructField>();
            bufferFields.add(DataTypes.createStructField((String)"sum", (DataType)DataTypes.LongType, (boolean)true));
            bufferFields.add(DataTypes.createStructField((String)"count", (DataType)DataTypes.LongType, (boolean)true));
            this.bufferSchema = DataTypes.createStructType(bufferFields);
        }

        public StructType inputSchema() {
            return this.inputSchema;
        }

        public StructType bufferSchema() {
            return this.bufferSchema;
        }

        public DataType dataType() {
            return DataTypes.DoubleType;
        }

        public boolean deterministic() {
            return true;
        }

        public void initialize(MutableAggregationBuffer buffer) {
            buffer.update(0, (Object)0L);
            buffer.update(1, (Object)0L);
        }

        public void update(MutableAggregationBuffer buffer, Row input) {
            if (!input.isNullAt(0)) {
                long updatedSum = buffer.getLong(0) + input.getLong(0);
                long updatedCount = buffer.getLong(1) + 1L;
                buffer.update(0, (Object)updatedSum);
                buffer.update(1, (Object)updatedCount);
            }
        }

        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
            long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
            buffer1.update(0, (Object)mergedSum);
            buffer1.update(1, (Object)mergedCount);
        }

        public Double evaluate(Row buffer) {
            return (double)buffer.getLong(0) / (double)buffer.getLong(1);
        }
    }
}

