package ca.pfv.spmf.algorithms.classifiers.logisticregression;

import java.util.ArrayList;
import java.util.Arrays;

/* loaded from: input_file:ca/pfv/spmf/algorithms/classifiers/logisticregression/MainTestLogisticRegression.class */
/**
 * Example program showing how to train and query {@link AlgoBinaryLogisticRegression}.
 *
 * <p>A small training set of five instances, each with three continuous feature values, is
 * built together with one boolean label per instance. The model is configured with a fixed
 * iteration count and learning rate, trained, and its statistics are printed. The predicted
 * boolean class and the value returned by {@code predictDouble} (printed as "probability") are
 * then shown for each training instance and for two additional instances that were not used
 * for training. Finally the learned weights are printed.
 */
public class MainTestLogisticRegression {

    /** Value passed to {@code setIterationCount} before training. */
    private static final int ITERATION_COUNT = 11000;

    /** Value passed to {@code setLearningRate} before training. */
    private static final double LEARNING_RATE = 0.005d;

    /**
     * Runs the example.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        // Training data: every instance has three continuous feature values.
        ArrayList<InstanceContinuous> instances = new ArrayList<InstanceContinuous>();
        instances.add(new InstanceContinuous(new double[]{5.0d, 100.0d, 300.0d}));
        instances.add(new InstanceContinuous(new double[]{5.0d, 130.0d, 400.0d}));
        instances.add(new InstanceContinuous(new double[]{10.0d, 200.0d, 600.0d}));
        instances.add(new InstanceContinuous(new double[]{-1.0d, 10.0d, 60.0d}));
        instances.add(new InstanceContinuous(new double[]{-6.0d, 3.0d, 60.0d}));

        // One label per training instance, in the same order as the instances above.
        ArrayList<Boolean> labels =
                new ArrayList<Boolean>(Arrays.asList(false, false, false, true, true));

        AlgoBinaryLogisticRegression algo = new AlgoBinaryLogisticRegression();
        algo.setIterationCount(ITERATION_COUNT);
        algo.setLearningRate(LEARNING_RATE);
        algo.train(instances, labels);
        algo.printStats();

        // Predict every training instance first, then two instances unseen during training.
        ArrayList<InstanceContinuous> toPredict = new ArrayList<InstanceContinuous>(instances);
        toPredict.add(new InstanceContinuous(new double[]{663.0d, 700.0d, 900.0d}));
        toPredict.add(new InstanceContinuous(new double[]{-1.0d, 0.0d, 3.0d}));
        for (int i = 0; i < toPredict.size(); i++) {
            // Output numbering is 1-based, matching "Prediction instance 1" .. "Prediction instance 7".
            printPrediction(algo, i + 1, toPredict.get(i));
        }

        System.out.println("weights " + Arrays.toString(algo.weights));
    }

    /**
     * Prints the predicted boolean class and the {@code predictDouble} value for one instance.
     *
     * @param algo     the trained model to query
     * @param number   the 1-based instance number shown in the output line
     * @param instance the instance to classify
     */
    private static void printPrediction(AlgoBinaryLogisticRegression algo, int number,
            InstanceContinuous instance) {
        // predictBoolean is evaluated before predictDouble, as in a left-to-right concatenation.
        System.out.println("Prediction instance " + number + ": " + algo.predictBoolean(instance)
                + " probability: " + algo.predictDouble(instance));
    }
}
