/*
 * Decompiled with CFR 0.152.
 */
package energy;

import java.io.FileReader;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import weka.classifiers.AbstractClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.UnsupportedClassTypeException;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/**
 * Multivariate adaptive regression splines (MARS) learner for Weka.
 *
 * <p>The model is a weighted sum of basis functions. Basis 0 is a constant. Every other basis is a
 * rectified linear "hinge" {@code max(0, sign * (x[dim] - knot) * parent)}, where {@code parent} is
 * the value of an earlier basis. The order of a hinge is its parent's order plus one, capped by
 * {@code maxOrder}.
 *
 * <p>Training has two phases:
 * <ol>
 *   <li>Forward pass: greedily add the hinge (or mirrored hinge pair) whose least-squares refit most
 *       reduces the residual sum of squares. Stop when the error is negligible, when no candidate
 *       helps, or when the {@code min(2 * maxNumBases, numInstances)} basis slots are nearly used.</li>
 *   <li>Backward pass: repeatedly drop the basis whose removal hurts the fit least. Continue while
 *       that does not increase the error, or while more than {@code maxNumBases} bases remain.</li>
 * </ol>
 *
 * <p>Supports a numeric class or a two-valued nominal class (regressed on its 0/1 index). Nominal
 * predictors are converted with {@link NominalToBinary}.
 */
public class ABMARS
extends AbstractClassifier
implements OptionHandler {
    public static final String userdir = System.getProperty("user.dir");
    private static final long serialVersionUID = 8956524807883202516L;
    /** Hyper-parameters and the fitted model; serialized together with the classifier. */
    private State s = new State();
    /** Filter fitted on the training data when it had nominal predictors, otherwise {@code null}. */
    private Filter nomToBinFilter = null;

    /** Creates a learner with the defaults {@code maxNumBases = -1} and {@code maxOrder = 2}. */
    public ABMARS() {
    }

    /**
     * Creates a learner with explicit limits.
     *
     * @param maxNumBases maximum number of bases (including the constant) kept after pruning; -1
     *     means "use the number of training instances"
     * @param order maximum interaction order of a hinge basis
     */
    public ABMARS(int maxNumBases, int order) {
        this.s.maxNumBases = maxNumBases;
        this.s.maxOrder = order;
    }

    /**
     * Returns the description shown by the Weka GUI.
     *
     * @return a short description of the learner
     */
    public String globalInfo() {
        return "Multivariate adaptive regression splines (MARS) for numeric or two-class targets. "
            + "A forward pass greedily adds hinge basis functions and a backward pass prunes them.";
    }

    /**
     * Fits the model to {@code instances}.
     *
     * <p>The input is not modified. Instances with a missing class value are ignored. The configured
     * options are not changed, so the same object can be rebuilt on different data.
     *
     * @param instances training data with the class index set
     * @throws UnsupportedClassTypeException if the class is nominal with more than two values
     * @throws IllegalArgumentException if no instance has a known class value, or if
     *     {@code maxNumBases} is neither -1 nor positive
     * @throws Exception if the nominal-to-binary filter fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        // Forget any previous model first, so a failed build cannot leave a stale filter paired
        // with a stale model.
        this.nomToBinFilter = null;
        this.s.bases = null;
        this.s.weights = null;

        Instances data = this.prepareTrainingData(instances);
        int numInstances = data.numInstances();
        if (numInstances == 0) {
            throw new IllegalArgumentException("ABMARS: no training instances with a known class value.");
        }
        // -1 is resolved locally; the option itself must stay -1 for later builds.
        int maxNumBases = this.s.maxNumBases == -1 ? numInstances : this.s.maxNumBases;
        if (maxNumBases < 1) {
            throw new IllegalArgumentException("ABMARS: maxNumBases must be -1 or positive, got " + this.s.maxNumBases);
        }
        // Computed in long arithmetic because 2 * maxNumBases can overflow int.
        int maxBases = (int) Math.min(2L * maxNumBases, (long) numInstances);

        int classIndex = data.classIndex();
        int numAttributes = data.numAttributes();
        double[][] attributeValues = new double[numAttributes][];
        for (int a = 0; a < numAttributes; a++) {
            attributeValues[a] = data.attributeToDoubleArray(a);
        }
        Matrix targetValues = new Matrix(attributeValues[classIndex], numInstances);

        Fit fit = initialFit(numInstances, maxBases, targetValues);
        forwardPass(attributeValues, classIndex, collectKnots(attributeValues, classIndex),
                targetValues, this.s.maxOrder, fit);
        boolean[] isPrunedBase = backwardPass(targetValues, maxNumBases, fit);
        this.storeModel(fit, isPrunedBase);
    }

    /**
     * Validates the class type, copies the data, drops instances with a missing class and, if
     * needed, fits {@link #nomToBinFilter} and applies it.
     *
     * @param instances the caller's training data (left untouched)
     * @return a private copy ready for fitting
     * @throws Exception if the class type is unsupported or filtering fails
     */
    private Instances prepareTrainingData(Instances instances) throws Exception {
        if (instances.classAttribute().isNominal() && instances.classAttribute().numValues() > 2) {
            throw new UnsupportedClassTypeException("This classifier only handles continuous output or two-class problems. For multiclass problems, please wrap MARSplines with MultiClassClassifier.");
        }
        Instances data = new Instances(instances);
        // A NaN target would turn every least-squares error into NaN.
        data.deleteWithMissingClass();
        for (int a = 0; a < data.numAttributes(); a++) {
            if (a == data.classIndex() || !data.attribute(a).isNominal()) {
                continue;
            }
            System.out.println("Nominal attributes exist. We will convert nominal attributes to binary, but it is better to do that ahead of time.");
            this.nomToBinFilter = new NominalToBinary();
            this.nomToBinFilter.setInputFormat(data);
            data = Filter.useFilter(data, this.nomToBinFilter);
            break;
        }
        return data;
    }

    /**
     * Builds the one-basis (constant-only) model that the forward pass starts from.
     *
     * @param numInstances number of training rows
     * @param maxBases capacity of the basis array
     * @param targetValues column vector of class values
     * @return the initial fit state
     */
    private static Fit initialFit(int numInstances, int maxBases, Matrix targetValues) {
        Fit fit = new Fit(maxBases);
        // A fresh constant per build; the shared BasisFunction.CONSTANT must not carry the
        // per-model dependant counter.
        fit.bases[0] = new ConstantBasisFunction();
        fit.numBases = 1;
        fit.basisValues = new Matrix(numInstances, 1);
        double[][] column = fit.basisValues.getArray();
        for (int i = 0; i < numInstances; i++) {
            column[i][0] = 1.0;
        }
        fit.weights = fit.basisValues.solve(targetValues);
        fit.squaredError = sumSquaredResiduals(fit.basisValues, fit.weights, targetValues);
        return fit;
    }

    /**
     * Collects the distinct observed values of each predictor; these are the candidate knots.
     *
     * @param attributeValues per-attribute value columns, indexed by attribute index
     * @param classIndex index of the class attribute, which gets an empty set
     * @return one set per attribute index
     */
    private static ArrayList<Set<Double>> collectKnots(double[][] attributeValues, int classIndex) {
        ArrayList<Set<Double>> knots = new ArrayList<Set<Double>>(attributeValues.length);
        for (int a = 0; a < attributeValues.length; a++) {
            double[] column = attributeValues[a];
            // The capacity is kept equal to the row count. HashSet iteration order, and therefore
            // tie-breaking between equally good knots, depends on it.
            Set<Double> distinct = new HashSet<Double>(column.length);
            if (a != classIndex) {
                for (double v : column) {
                    distinct.add(v);
                }
            }
            knots.add(distinct);
        }
        return knots;
    }

    /**
     * Residual sum of squares of {@code design * weights} against {@code targets}.
     *
     * @param design design matrix, one row per instance
     * @param weights column vector of weights
     * @param targets column vector of targets
     * @return the sum of squared residuals
     */
    private static double sumSquaredResiduals(Matrix design, Matrix weights, Matrix targets) {
        Matrix residuals = targets.minus(design.times(weights));
        double sum = 0.0;
        for (int i = 0; i < residuals.getRowDimension(); i++) {
            double r = residuals.get(i, 0);
            sum += r * r;
        }
        return sum;
    }

    /**
     * Greedily adds hinge bases to {@code fit} (the MARS forward pass).
     *
     * <p>Each round tries every (parent basis, predictor, knot) triple. The parent must have order
     * below {@code maxOrder} and a dimension different from the predictor. The two mirrored hinge
     * columns are appended to the design and the model is refit by least squares. If only one
     * hinge side has a nonzero value, only that column is used. The best strictly improving
     * candidate is kept.
     *
     * @param attributeValues per-attribute value columns, indexed by attribute index
     * @param classIndex index of the class attribute (never used as a predictor)
     * @param knots candidate knots per attribute index
     * @param targetValues column vector of class values
     * @param maxOrder maximum interaction order
     * @param fit state to extend in place
     */
    private static void forwardPass(double[][] attributeValues, int classIndex, ArrayList<Set<Double>> knots,
            Matrix targetValues, int maxOrder, Fit fit) {
        int numInstances = targetValues.getRowDimension();
        int maxBases = fit.bases.length;
        while (fit.squaredError > 1.0E-6 * (double) numInstances && fit.numBases < maxBases - 1) {
            int numBases = fit.numBases;
            int leftCol = numBases;
            int rightCol = numBases + 1;
            double[][] parentValues = fit.basisValues.getArray();
            // The existing design plus two scratch columns for the candidate hinge pair.
            Matrix working = new Matrix(numInstances, numBases + 2);
            working.setMatrix(0, numInstances - 1, 0, numBases - 1, fit.basisValues);
            double[][] wbv = working.getArray();

            double bestError = fit.squaredError;
            Matrix bestWeights = null;
            Matrix bestBasisValues = null;
            int bestParent = -1;
            int bestDim = -1;
            double bestKnot = 0.0;
            boolean bestIncludesLeft = true;
            boolean bestIncludesRight = true;

            for (int parent = 0; parent < numBases; parent++) {
                BasisFunction parentBasis = fit.bases[parent];
                if (parentBasis.order() >= maxOrder) {
                    continue;
                }
                for (int dim = 0; dim < attributeValues.length; dim++) {
                    if (dim == classIndex || parentBasis.dimension() == dim) {
                        continue;
                    }
                    double[] x = attributeValues[dim];
                    for (double knot : knots.get(dim)) {
                        boolean hasLeft = false;
                        boolean hasRight = false;
                        for (int i = 0; i < numInstances; i++) {
                            double bv = parentValues[i][parent] * (x[i] - knot);
                            if (bv == 0.0) {
                                wbv[i][leftCol] = 0.0;
                                wbv[i][rightCol] = 0.0;
                            } else if (bv < 0.0) {
                                wbv[i][leftCol] = -bv;
                                wbv[i][rightCol] = 0.0;
                                hasLeft = true;
                            } else {
                                wbv[i][leftCol] = 0.0;
                                wbv[i][rightCol] = bv;
                                hasRight = true;
                            }
                        }
                        // Skip only when both columns are all zero. A one-sided hinge is still a
                        // useful candidate, but its all-zero mirror column must be left out.
                        if (!hasLeft && !hasRight) {
                            continue;
                        }
                        Matrix candidate;
                        if (hasLeft && hasRight) {
                            candidate = working;
                        } else {
                            candidate = working.getMatrix(0, numInstances - 1, 0, numBases);
                            if (!hasLeft) {
                                double[][] c = candidate.getArray();
                                for (int i = 0; i < numInstances; i++) {
                                    c[i][leftCol] = wbv[i][rightCol];
                                }
                            }
                        }
                        Matrix candidateWeights;
                        try {
                            candidateWeights = candidate.solve(targetValues);
                        } catch (RuntimeException singular) {
                            // A rank-deficient design cannot be refit; this candidate is unusable.
                            continue;
                        }
                        double thisError = sumSquaredResiduals(candidate, candidateWeights, targetValues);
                        if (!(thisError < bestError)) {
                            continue;
                        }
                        bestError = thisError;
                        bestWeights = candidateWeights;
                        bestParent = parent;
                        bestDim = dim;
                        bestKnot = knot;
                        bestIncludesLeft = hasLeft;
                        bestIncludesRight = hasRight;
                        // 'working' is overwritten by the next candidate, so snapshot it.
                        bestBasisValues = candidate == working ? working.copy() : candidate;
                    }
                }
            }
            if (bestBasisValues == null) {
                break; // no candidate reduced the error
            }

            RectifiedLinearBasisFunction left = new RectifiedLinearBasisFunction(
                    fit.bases[bestParent].order() + 1, -1.0, bestDim, bestKnot, bestParent);
            int numNewBases;
            if (bestIncludesLeft && bestIncludesRight) {
                fit.bases[numBases] = left;
                fit.bases[numBases + 1] = left.createInverse();
                numNewBases = 2;
            } else if (bestIncludesLeft) {
                fit.bases[numBases] = left;
                numNewBases = 1;
            } else {
                fit.bases[numBases] = left.createInverse();
                numNewBases = 1;
            }
            fit.bases[bestParent].incrementNumDependants(numNewBases);
            fit.basisValues = bestBasisValues;
            fit.weights = bestWeights;
            fit.squaredError = bestError;
            fit.numBases += numNewBases;
        }
    }

    /**
     * Prunes bases from {@code fit} (the MARS backward pass).
     *
     * <p>Each round refits the model without each remaining non-constant basis in turn and removes
     * the one whose removal gives the lowest error. It stops once the best removal would increase
     * the error and at most {@code maxNumBases} bases remain. {@code fit.weights} and
     * {@code fit.squaredError} are updated to match the surviving columns.
     *
     * @param targetValues column vector of class values
     * @param maxNumBases resolved (positive) basis limit
     * @param fit state after the forward pass
     * @return flags marking which of the first {@code fit.numBases} bases were pruned
     */
    private static boolean[] backwardPass(Matrix targetValues, int maxNumBases, Fit fit) {
        int numInstances = targetValues.getRowDimension();
        int numBases = fit.numBases;
        double[][] basisValueArray = fit.basisValues.getArray();
        boolean[] isPrunedBase = new boolean[numBases];
        int numBasesAfterPruning = numBases;
        while (true) {
            int bestBasis = -1;
            double bestError = Double.MAX_VALUE;
            Matrix bestWeights = null;
            for (int b = 1; b < numBases; b++) {
                if (isPrunedBase[b]) {
                    continue;
                }
                Matrix reduced = new Matrix(numInstances, numBasesAfterPruning - 1);
                double[][] r = reduced.getArray();
                int col = 0;
                for (int bb = 0; bb < numBases; bb++) {
                    if (bb == b || isPrunedBase[bb]) {
                        continue;
                    }
                    for (int i = 0; i < numInstances; i++) {
                        r[i][col] = basisValueArray[i][bb];
                    }
                    col++;
                }
                Matrix candidateWeights;
                try {
                    candidateWeights = reduced.solve(targetValues);
                } catch (RuntimeException singular) {
                    continue;
                }
                double thisError = sumSquaredResiduals(reduced, candidateWeights, targetValues);
                if (thisError < bestError) {
                    bestBasis = b;
                    bestError = thisError;
                    bestWeights = candidateWeights;
                }
            }
            if (bestBasis < 0) {
                break; // nothing left that can be removed
            }
            if (bestError > fit.squaredError && numBasesAfterPruning <= maxNumBases) {
                break;
            }
            isPrunedBase[bestBasis] = true;
            int superOrdinateBasis = ((RectifiedLinearBasisFunction) fit.bases[bestBasis]).superOrdinateBasis;
            fit.bases[superOrdinateBasis].decrementNumDependants();
            numBasesAfterPruning--;
            fit.squaredError = bestError;
            fit.weights = bestWeights;
        }
        return isPrunedBase;
    }

    /**
     * Marks the bases whose value is needed at prediction time. These are the unpruned bases plus
     * every ancestor reachable through {@code superOrdinateBasis}, pruned or not.
     *
     * @param bases bases in creation order
     * @param numBases number of bases in use
     * @param isPrunedBase pruning flags from the backward pass
     * @return flags marking needed bases
     */
    private static boolean[] basesNeededForPrediction(BasisFunction[] bases, int numBases, boolean[] isPrunedBase) {
        boolean[] needed = new boolean[numBases];
        // A parent is always created before its children, so it has a lower index. One
        // descending sweep therefore propagates "needed" through any number of generations.
        for (int b = numBases - 1; b >= 0; b--) {
            if (!isPrunedBase[b]) {
                needed[b] = true;
            }
            if (needed[b] && bases[b] instanceof RectifiedLinearBasisFunction) {
                needed[((RectifiedLinearBasisFunction) bases[b]).superOrdinateBasis] = true;
            }
        }
        return needed;
    }

    /**
     * Copies the pruned fit into {@link #s}.
     *
     * <p>Pruned bases get weight 0. They stay in the array only if some surviving basis still
     * depends on them, directly or transitively; otherwise they are set to {@code null}.
     *
     * @param fit state after the backward pass
     * @param isPrunedBase pruning flags from the backward pass
     */
    private void storeModel(Fit fit, boolean[] isPrunedBase) {
        int numBases = fit.numBases;
        boolean[] needed = basesNeededForPrediction(fit.bases, numBases, isPrunedBase);
        BasisFunction[] bases = new BasisFunction[numBases];
        double[] weights = new double[numBases];
        int col = 0;
        for (int b = 0; b < numBases; b++) {
            if (!isPrunedBase[b]) {
                weights[b] = fit.weights.get(col, 0);
                col++;
            }
            bases[b] = needed[b] ? fit.bases[b] : null;
        }
        this.s.bases = bases;
        this.s.weights = weights;
    }

    /**
     * Evaluates a fitted model on one instance.
     *
     * <p>Bases are evaluated in array order, so each basis can read its parent's value from the
     * partially filled value array. {@code null} entries are skipped and their value stays 0.
     *
     * @param bases basis functions, possibly containing {@code null}
     * @param weights weight per basis (same length as {@code bases})
     * @param instance instance in the (possibly filtered) training format
     * @return the weighted sum of basis responses
     */
    protected static double outputForInstance(BasisFunction[] bases, double[] weights, Instance instance) {
        double[] basisValues = new double[bases.length];
        double output = 0.0;
        for (int b = 0; b < bases.length; b++) {
            if (bases[b] == null) {
                continue;
            }
            basisValues[b] = bases[b].response(instance, basisValues);
            output += weights[b] * basisValues[b];
        }
        return output;
    }

    /**
     * Predicts for one instance.
     *
     * @param instance the instance to predict, in the original (unfiltered) format
     * @return {@code {output}} for a numeric class, otherwise {@code {1 - p, p}} where {@code p} is
     *     the model output clamped to [0, 1]
     * @throws IllegalStateException if the model has not been built
     * @throws Exception if filtering fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.s.bases == null || this.s.weights == null) {
            throw new IllegalStateException("ABMARS: model has not been built yet.");
        }
        boolean isNumeric = instance.classAttribute().isNumeric();
        Instance prepared = instance;
        if (this.nomToBinFilter != null) {
            this.nomToBinFilter.input(instance);
            prepared = this.nomToBinFilter.output();
        }
        double output = ABMARS.outputForInstance(this.s.bases, this.s.weights, prepared);
        if (isNumeric) {
            return new double[]{output};
        }
        // The regression output is unbounded. Clamp it so the result is a valid probability
        // distribution; the argmax, and hence the predicted class, is unchanged.
        double p = Math.max(0.0, Math.min(1.0, output));
        return new double[]{1.0 - p, p};
    }

    @Override
    public Enumeration listOptions() {
        return super.listOptions();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        super.setOptions(options);
    }

    @Override
    public String[] getOptions() {
        return super.getOptions();
    }

    /**
     * Counts the bases with a nonzero weight in the fitted model.
     *
     * @return the number of active bases, or 0 if no model has been built
     */
    public int getNumBases() {
        if (this.s.weights == null) {
            return 0;
        }
        int numBases = 0;
        for (double w : this.s.weights) {
            if (w != 0.0) {
                numBases++;
            }
        }
        return numBases;
    }

    /** @return the configured basis limit (-1 means "number of training instances") */
    public int getMaxNumBases() {
        return this.s.maxNumBases;
    }

    /** @param m basis limit to use on the next build; -1 means "number of training instances" */
    public void setMaxNumBases(int m) {
        this.s.maxNumBases = m;
    }

    /** @return the configured maximum interaction order */
    public int getMaxOrder() {
        return this.s.maxOrder;
    }

    /** @param o maximum interaction order to use on the next build */
    public void setMaxOrder(int o) {
        this.s.maxOrder = o;
    }

    /**
     * Lists every basis with a nonzero weight.
     *
     * @return a human-readable model description
     */
    @Override
    public String toString() {
        if (this.s.bases == null || this.s.weights == null) {
            return "MARSplines: No model built yet.";
        }
        StringBuilder sb = new StringBuilder(this.s.bases.length * 100 + this.s.weights.length * 10 + 250);
        sb.append("MARSplines{ numBases=");
        sb.append(this.getNumBases());
        sb.append(", bases=[\n");
        for (int i = 0; i < this.s.weights.length; i++) {
            double w = this.s.weights[i];
            if (w == 0.0) {
                continue;
            }
            sb.append("index=");
            sb.append(i);
            sb.append(", weight=");
            sb.append(w);
            sb.append(", basis=");
            sb.append(this.s.bases[i]);
            sb.append('\n');
        }
        sb.append("]}");
        return sb.toString();
    }

    /**
     * Leave-one-out demo on {@code <user.dir>/resources/instances_numeric}, using the last
     * attribute as the class.
     *
     * @param argv ignored
     * @throws Exception never in practice; errors are reported on stderr with exit status 1
     */
    public static void main(String[] argv) throws Exception {
        ABMARS gab = new ABMARS(4, 2);
        String filename = String.valueOf(userdir) + "/resources/instances_numeric";
        try {
            Instances instances;
            FileReader reader = new FileReader(filename);
            try {
                instances = new Instances(reader);
            } finally {
                reader.close();
            }
            int classIndex = instances.numAttributes() - 1;
            instances.setClassIndex(classIndex);
            for (int heldOut = 0; heldOut < instances.size(); heldOut++) {
                // The copy keeps the class index; the removed instance still refers to it.
                Instances training = new Instances(instances);
                Instance toClassify = training.remove(heldOut);
                gab.buildClassifier(training);
                System.out.println("==================");
                System.out.println("Prediction " + gab.classifyInstance(toClassify));
                System.out.println("Actual " + toClassify.value(classIndex));
                System.out.println("==================");
            }
        }
        catch (Exception e) {
            System.err.println(e.getLocalizedMessage());
            System.exit(1);
        }
    }

    /** Mutable working state shared by the forward and backward passes of one build. */
    private static final class Fit {
        /** Bases in creation order; index 0 is the constant basis. */
        final BasisFunction[] bases;
        /** Number of leading entries of {@link #bases} in use. */
        int numBases;
        /** Design matrix: one row per instance, one column per basis in use (pruning aside). */
        Matrix basisValues;
        /** Least-squares weights for the currently unpruned columns. */
        Matrix weights;
        /** Residual sum of squares of the current model. */
        double squaredError;

        Fit(int capacity) {
            this.bases = new BasisFunction[capacity];
        }
    }

    /** A basis function of the additive model. */
    protected static abstract class BasisFunction {
        /** Per-basis counter of unpruned direct dependants, maintained during fitting. */
        private int numDeps = 0;
        /** Shared constant basis. Builds use their own instance, so this one is never mutated. */
        static final BasisFunction CONSTANT = new ConstantBasisFunction();

        protected BasisFunction() {
        }

        /** @return the number of unpruned bases that use this one as their parent */
        public int numDependants() {
            return this.numDeps;
        }

        /** Decrements the dependant counter and returns the new value. */
        public int decrementNumDependants() {
            --this.numDeps;
            return this.numDeps;
        }

        /** Adds {@code i} to the dependant counter and returns the new value. */
        public int incrementNumDependants(int i) {
            this.numDeps += i;
            return this.numDeps;
        }

        /** @return the attribute index this basis splits on, or -1 for none */
        public abstract int dimension();

        /** @return the interaction order (0 for the constant) */
        public abstract int order();

        /**
         * Evaluates this basis.
         *
         * @param var1 the instance
         * @param var2 values of lower-indexed bases for the same instance
         * @return the basis value
         */
        public abstract double response(Instance var1, double[] var2);
    }

    /** The basis that is identically 1 (the intercept). */
    protected static class ConstantBasisFunction
    extends BasisFunction
    implements Serializable {
        private static final long serialVersionUID = -1053838349543019617L;

        protected ConstantBasisFunction() {
        }

        @Override
        public double response(Instance instance, double[] basisValues) {
            return 1.0;
        }

        @Override
        public int dimension() {
            return -1;
        }

        @Override
        public int order() {
            return 0;
        }

        @Override
        public String toString() {
            return "CONSTANT";
        }
    }

    /** Hinge {@code max(0, sign * (x[dimension] - knot) * parentValue)}. */
    protected static class RectifiedLinearBasisFunction
    extends BasisFunction
    implements Serializable {
        private static final long serialVersionUID = 7701735949852762393L;
        private int order;
        /** -1 selects values below the knot, +1 values above it. */
        private double sign = 1.0;
        private int dimension;
        private double knot;
        /** Index of the parent basis, which always has a lower index than this basis. */
        private int superOrdinateBasis;

        private RectifiedLinearBasisFunction() {
        }

        private RectifiedLinearBasisFunction(int order, double sign, int dimension, double knot, int superOrdinateBasis) {
            this();
            this.order = order;
            this.sign = sign;
            this.dimension = dimension;
            this.knot = knot;
            this.superOrdinateBasis = superOrdinateBasis;
        }

        private RectifiedLinearBasisFunction(RectifiedLinearBasisFunction other) {
            this(other.order, other.sign, other.dimension, other.knot, other.superOrdinateBasis);
        }

        @Override
        public int dimension() {
            return this.dimension;
        }

        @Override
        public int order() {
            return this.order;
        }

        @Override
        public double response(Instance instance, double[] basisValues) {
            if (basisValues[this.superOrdinateBasis] == 0.0) {
                return 0.0;
            }
            double b = this.sign * (instance.value(this.dimension) - this.knot) * basisValues[this.superOrdinateBasis];
            return b > 0.0 ? b : 0.0;
        }

        /** @return the mirror hinge: same knot, dimension and parent, opposite sign */
        public RectifiedLinearBasisFunction createInverse() {
            RectifiedLinearBasisFunction inverse = new RectifiedLinearBasisFunction(this);
            inverse.sign = -this.sign;
            return inverse;
        }

        @Override
        public String toString() {
            return "RectifiedLinear{order=" + this.order + ",sign=" + this.sign + ",dim=" + this.dimension + ",knot=" + this.knot + ",super=" + this.superOrdinateBasis + "}";
        }
    }

    /** Serializable options and fitted model. */
    private static class State
    implements Serializable {
        private static final long serialVersionUID = -8976623443712084405L;
        /** Basis limit after pruning; -1 means "number of training instances". */
        private int maxNumBases = -1;
        /** Maximum interaction order of a hinge basis. */
        private int maxOrder = 2;
        /** Weight per basis; 0 for pruned bases. {@code null} before a build. */
        private double[] weights;
        /** Bases in creation order; {@code null} entries are unused. {@code null} before a build. */
        private BasisFunction[] bases;

        private State() {
        }
    }
}

