package ca.pfv.spmf.algorithms.classifiers.general;

import ca.pfv.spmf.algorithms.classifiers.data.Dataset;
import ca.pfv.spmf.algorithms.classifiers.data.Instance;
import ca.pfv.spmf.algorithms.classifiers.data.VirtualDataset;
import ca.pfv.spmf.tools.MemoryLogger;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ca/pfv/spmf/algorithms/classifiers/general/Evaluator.class */
/**
 * Trains classification algorithms on a dataset and measures the resulting
 * classifiers, using either holdout sampling or k-fold cross-validation.
 *
 * <p>For every trained classifier two {@link ClassificationResults} are
 * produced (one on the training split, one on the testing split) together with
 * one {@link TrainingResults}; all three are handed to
 * {@code OverallResults.addResults}.
 */
public class Evaluator {
    /** When {@code true}, progress information is printed to standard output. */
    private boolean DEBUGMODE = false;

    /**
     * Outcome of running one classifier over the instances of one dataset.
     * Filled in by {@code runOnInstancesAnUpdateResults}.
     */
    public class ClassificationResults {
        /** Receives one (actual class, predicted class) pair per classified instance. */
        ConfusionMatrix matrix = new ConfusionMatrix();

        /** Predicted class of every instance, in the order the instances were classified. */
        List<Short> predictedClasses = new ArrayList<>();

        /** Accumulated wall-clock prediction time, in milliseconds. */
        long runtime = 0;

        /** Accumulated {@code MemoryLogger} maximum-memory readings (unit defined by MemoryLogger). */
        Double memory = Double.valueOf(0.0d);

        ClassificationResults() {
        }
    }

    /**
     * Statistics about the training of one classifier, as reported by the
     * classification algorithm that produced it.
     */
    public class TrainingResults {
        /** Accumulated value of {@code ClassificationAlgorithm.getTrainingTime()}. */
        long runtime = 0;

        /** Accumulated value of {@code ClassificationAlgorithm.getTrainingMaxMemory()}. */
        Double memory = Double.valueOf(0.0d);

        /** Rule count contribution of this training run; stays 0 for non rule-based classifiers. */
        double avgRuleCount = 0.0d;

        TrainingResults() {
        }
    }

    /**
     * Classifies every instance of {@code dataset} and records predictions,
     * confusion-matrix entries, elapsed time and peak memory in
     * {@code classificationResults}.
     *
     * @param dataset               instances to classify
     * @param classifier            trained classifier used for prediction
     * @param classificationResults accumulator that is updated in place
     */
    private void runOnInstancesAnUpdateResults(Dataset dataset, Classifier classifier, ClassificationResults classificationResults) {
        MemoryLogger.getInstance().reset();
        long startTime = System.currentTimeMillis();
        for (Instance instance : dataset.getInstances()) {
            short predictedClass = classifier.predict(instance);
            short actualClass = instance.getKlass().shortValue();
            classificationResults.predictedClasses.add(Short.valueOf(predictedClass));
            classificationResults.matrix.add(Short.valueOf(actualClass), Short.valueOf(predictedClass));
        }
        classificationResults.runtime += System.currentTimeMillis() - startTime;
        MemoryLogger.getInstance().checkMemory();
        classificationResults.memory = Double.valueOf(classificationResults.memory.doubleValue() + MemoryLogger.getInstance().getMaxMemory());
    }

    /** Collects the names of the given algorithms, in array order. */
    private static ArrayList<String> collectAlgorithmNames(ClassificationAlgorithm[] algorithms) {
        ArrayList<String> names = new ArrayList<>(algorithms.length);
        for (ClassificationAlgorithm algorithm : algorithms) {
            names.add(algorithm.getName());
        }
        return names;
    }

    /**
     * Trains one algorithm on {@code training}, evaluates the resulting
     * classifier on both splits and appends the outcome to {@code overallResults}.
     *
     * @param algorithm        algorithm to train
     * @param training         split the classifier is trained on
     * @param testing          split held back for evaluation
     * @param ruleCountDivisor value the rule count is divided by before being
     *                         stored (1 for holdout, k for k-fold)
     * @param overallResults   collector receiving the three result objects
     * @throws Exception if training fails
     */
    private void trainAndEvaluate(ClassificationAlgorithm algorithm, Dataset training, Dataset testing,
            int ruleCountDivisor, OverallResults overallResults) throws Exception {
        if (this.DEBUGMODE) {
            System.out.println("Running algorithm ... " + algorithm.getName());
        }
        Classifier classifier = algorithm.trainAndCalculateStats(training);

        TrainingResults trainingResults = new TrainingResults();
        trainingResults.memory = Double.valueOf(trainingResults.memory.doubleValue() + algorithm.getTrainingMaxMemory());
        trainingResults.runtime += algorithm.getTrainingTime();
        if (classifier instanceof RuleClassifier) {
            // Divide in floating point: an integral rule count divided by an int
            // would otherwise be truncated before reaching the double field.
            trainingResults.avgRuleCount += ((RuleClassifier) classifier).getNumberRules() / (double) ruleCountDivisor;
        }

        ClassificationResults resultsOnTraining = new ClassificationResults();
        runOnInstancesAnUpdateResults(training, classifier, resultsOnTraining);
        ClassificationResults resultsOnTesting = new ClassificationResults();
        runOnInstancesAnUpdateResults(testing, classifier, resultsOnTesting);
        overallResults.addResults(resultsOnTraining, resultsOnTesting, trainingResults);
    }

    /**
     * Evaluates the given algorithms with holdout sampling: the dataset is split
     * once by {@code VirtualDataset.splitDatasetForHoldout} and every algorithm
     * is trained on the first part and evaluated on both parts.
     *
     * @param classificationAlgorithmArr algorithms to evaluate
     * @param dataset                    dataset to split
     * @param d                          sampling fraction passed to the splitter; must be
     *                                   strictly between 0 and 1
     * @return the collected results of all algorithms
     * @throws IllegalArgumentException if {@code d} is NaN or not strictly between 0 and 1
     * @throws Exception                if training an algorithm fails
     */
    public OverallResults trainAndRunClassifiersHoldout(ClassificationAlgorithm[] classificationAlgorithmArr, Dataset dataset, double d) throws Exception {
        OverallResults overallResults = new OverallResults(collectAlgorithmNames(classificationAlgorithmArr));
        // Written as a negated "inside the range" test so that NaN is rejected too.
        if (!(d > 0.0d && d < 1.0d)) {
            throw new IllegalArgumentException("Sampling percentage must be in the range (0,1) but was " + d);
        }
        Dataset[] parts = VirtualDataset.splitDatasetForHoldout(dataset, d);
        Dataset training = parts[0];
        Dataset testing = parts[1];
        if (this.DEBUGMODE) {
            System.out.println("===== HOLDOUT SAMPLING =====");
            System.out.println("Holdout sampling with percentage = " + d);
            System.out.println("  - Original dataset: " + dataset.getInstances().size() + " records.");
            System.out.println("  - Training part: " + training.getInstances().size() + " records.");
            System.out.println("  - Testing part: " + testing.getInstances().size() + " records.");
            System.out.println("===== RUNNING =====");
        }
        for (ClassificationAlgorithm algorithm : classificationAlgorithmArr) {
            trainAndEvaluate(algorithm, training, testing, 1, overallResults);
        }
        return overallResults;
    }

    /**
     * Evaluates the given algorithms with k-fold cross-validation. For each of
     * the {@code i} folds the dataset is split by
     * {@code VirtualDataset.splitDatasetForKFold}; every algorithm is trained on
     * the training part of the fold and evaluated on both parts.
     *
     * @param classificationAlgorithmArr algorithms to evaluate
     * @param dataset                    dataset to split into folds
     * @param i                          number of folds (k); must be at least 2
     * @return the collected results of all algorithms over all folds
     * @throws IllegalArgumentException if {@code i} is smaller than 2
     * @throws Exception                if training an algorithm fails
     */
    public OverallResults trainAndRunClassifiersKFold(ClassificationAlgorithm[] classificationAlgorithmArr, Dataset dataset, int i) throws Exception {
        OverallResults overallResults = new OverallResults(collectAlgorithmNames(classificationAlgorithmArr));
        if (i < 2) {
            throw new IllegalArgumentException("k needs to be 2 or more");
        }
        int instanceCount = dataset.getInstances().size();
        int foldSize = (int) Math.ceil(instanceCount * (1.0d / i));
        for (int fold = 0; fold < i; fold++) {
            // Because the fold size is rounded up, late folds could start or end
            // past the last record; keep both bounds inside the dataset.
            int start = Math.min(fold * foldSize, instanceCount);
            int end = (fold == i - 1) ? instanceCount : Math.min(start + foldSize, instanceCount);

            Dataset[] parts = VirtualDataset.splitDatasetForKFold(dataset, start, end);
            Dataset training = parts[0];
            Dataset testing = parts[1];
            if (this.DEBUGMODE) {
                System.out.println("===== KFOLD " + fold + " =====");
                System.out.println(" k = " + i);
                System.out.println("  - Original dataset: " + instanceCount + " records.");
                System.out.println("  - Training part: " + training.getInstances().size() + " records.");
                System.out.println("  - Testing part: " + testing.getInstances().size() + " records.");
                System.out.println("===== RUNNING =====");
            }
            for (ClassificationAlgorithm algorithm : classificationAlgorithmArr) {
                // Train on the training part of the fold (not the held-out part),
                // matching the holdout evaluation.
                trainAndEvaluate(algorithm, training, testing, i, overallResults);
            }
        }
        return overallResults;
    }
}
