package ca.pfv.spmf.algorithms.classifiers.decisiontree.id3;

import ca.pfv.spmf.algorithms.classifiers.data.Dataset;
import ca.pfv.spmf.algorithms.classifiers.data.Instance;
import ca.pfv.spmf.algorithms.classifiers.general.ClassificationAlgorithm;
import ca.pfv.spmf.algorithms.classifiers.general.Classifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ca/pfv/spmf/algorithms/classifiers/decisiontree/id3/AlgoID3.class */
/**
 * Builds a decision tree classifier with the ID3 algorithm.
 *
 * <p>At every node the attribute with the highest information gain (base-2 Shannon entropy of the
 * class values) is chosen as the split attribute, and one child is created per value of that
 * attribute observed among the node's instances. A node becomes a leaf when all of its instances
 * share one class, when no attribute is left to split on, or when the best gain is zero; in the
 * last two cases the leaf predicts the most frequent class.
 *
 * <p>The algorithm keeps per-run state in mutable fields without synchronization.
 */
public class AlgoID3 extends ClassificationAlgorithm {
    /** Natural log of 2, used to convert natural-log entropy terms to bits. */
    private static final double LOG_2 = Math.log(2.0d);

    /** Class values of the training dataset; entropy terms are summed in this list's order. */
    private List<Short> targetAttributeValues;

    /** Wall-clock time (ms) at which the last call to {@link #train} started. */
    private long startTime;

    /** Wall-clock time (ms) at which the last call to {@link #train} finished. */
    private long endTime;

    /**
     * Recursively builds the ID3 subtree for a set of instances.
     *
     * @param remainingAttributes indices of the attributes still available for splitting
     * @param instances the training instances that reach this node
     * @param classFrequencies precomputed class frequencies of {@code instances}, or {@code null}
     *     to compute them here (the root reuses the dataset's map)
     * @return a {@link ClassNode} leaf, or a {@link DecisionNode} with one child per observed value
     *     of the chosen attribute
     */
    private Node id3(int[] remainingAttributes, List<Instance> instances, Map<Short, Long> classFrequencies) {
        Map<Short, Long> frequencies =
                classFrequencies == null ? calculateFrequencyOfClassValues(instances) : classFrequencies;

        // Nothing left to split on: predict the majority class.
        if (remainingAttributes.length == 0) {
            return newClassNode(mostFrequentClass(frequencies));
        }
        // Pure node: every instance has the same class.
        if (frequencies.size() == 1) {
            return newClassNode(frequencies.keySet().iterator().next());
        }

        double globalEntropy = entropy(frequencies, instances.size());

        // Choose the attribute with the highest gain; on ties the later attribute wins (>=).
        int bestAttribute = remainingAttributes[0];
        double highestGain = Double.NEGATIVE_INFINITY;
        for (int attribute : remainingAttributes) {
            double gain = calculateGain(attribute, instances, globalEntropy);
            if (gain >= highestGain) {
                highestGain = gain;
                bestAttribute = attribute;
            }
        }

        // No split reduces the entropy, so stop and predict the majority class.
        // NOTE(review): exact comparison; rounding could leave a tiny non-zero gain.
        if (highestGain == 0.0d) {
            return newClassNode(mostFrequentClass(frequencies));
        }

        int[] childAttributes = withoutAttribute(remainingAttributes, bestAttribute);
        Map<Short, List<Instance>> partitions = partitionByAttribute(instances, bestAttribute);

        DecisionNode decisionNode = new DecisionNode();
        decisionNode.attribute = bestAttribute;
        decisionNode.nodes = new Node[partitions.size()];
        decisionNode.attributeValues = new Short[partitions.size()];
        int child = 0;
        for (Map.Entry<Short, List<Instance>> partition : partitions.entrySet()) {
            decisionNode.attributeValues[child] = partition.getKey();
            // Child class frequencies are recomputed from the partition (null argument).
            decisionNode.nodes[child] = id3(childAttributes, partition.getValue(), null);
            child++;
        }
        return decisionNode;
    }

    /** Creates a leaf node predicting {@code className}. */
    private static ClassNode newClassNode(Short className) {
        ClassNode node = new ClassNode();
        node.className = className;
        return node;
    }

    /**
     * Returns the class with the highest frequency, or {@code null} if the map is empty.
     * On ties the class encountered first in the map's iteration order wins (strict {@code >}).
     */
    private static Short mostFrequentClass(Map<Short, Long> classFrequencies) {
        long highestFrequency = 0L;
        Short mostFrequent = null;
        for (Map.Entry<Short, Long> entry : classFrequencies.entrySet()) {
            long frequency = entry.getValue();
            if (frequency > highestFrequency) {
                highestFrequency = frequency;
                mostFrequent = entry.getKey();
            }
        }
        return mostFrequent;
    }

    /** Returns a copy of {@code attributes} without {@code removed}, which must occur exactly once. */
    private static int[] withoutAttribute(int[] attributes, int removed) {
        int[] remaining = new int[attributes.length - 1];
        int next = 0;
        for (int attribute : attributes) {
            if (attribute != removed) {
                remaining[next++] = attribute;
            }
        }
        return remaining;
    }

    /** Groups instances by their value for {@code attribute}, keeping input order within each group. */
    private static Map<Short, List<Instance>> partitionByAttribute(List<Instance> instances, int attribute) {
        Map<Short, List<Instance>> partitions = new HashMap<>();
        for (Instance instance : instances) {
            partitions.computeIfAbsent(instance.getItems()[attribute], value -> new ArrayList<>()).add(instance);
        }
        return partitions;
    }

    /**
     * Computes the base-2 Shannon entropy of a class distribution.
     *
     * <p>Only classes listed in {@link #targetAttributeValues} contribute, visited in that list's
     * order so the floating-point sum is evaluated in a fixed order.
     *
     * @param classFrequencies number of instances per class
     * @param total number of instances the frequencies were counted over
     * @return the entropy in bits
     */
    private double entropy(Map<Short, Long> classFrequencies, long total) {
        double entropy = 0.0d;
        for (Short classValue : this.targetAttributeValues) {
            Long frequency = classFrequencies.get(classValue);
            if (frequency != null) {
                // Floating-point division: integer division would truncate every probability to 0 or 1.
                double probability = frequency.doubleValue() / total;
                entropy -= (probability * Math.log(probability)) / LOG_2;
            }
        }
        return entropy;
    }

    /**
     * Computes the information gain of splitting {@code instances} on {@code attribute}: the entropy
     * before the split minus the size-weighted entropy of each attribute value's subset.
     *
     * <p>The instances are scanned once, grouping class counts by attribute value, instead of once
     * per distinct value.
     *
     * @param attribute index of the candidate split attribute
     * @param instances instances reaching the current node
     * @param globalEntropy entropy of {@code instances} before the split
     * @return the information gain in bits
     */
    private double calculateGain(int attribute, List<Instance> instances, double globalEntropy) {
        Map<Short, Map<Short, Long>> classFrequenciesByValue = new HashMap<>();
        for (Instance instance : instances) {
            classFrequenciesByValue
                    .computeIfAbsent(instance.getItems()[attribute], value -> new HashMap<>())
                    .merge(instance.getKlass(), 1L, Long::sum);
        }

        double total = instances.size();
        double expectedEntropy = 0.0d;
        for (Map<Short, Long> classFrequencies : classFrequenciesByValue.values()) {
            long subsetSize = 0L;
            for (long count : classFrequencies.values()) {
                subsetSize += count;
            }
            // Each subset's weight is paired with that same subset's entropy.
            expectedEntropy += (subsetSize / total) * entropy(classFrequencies, subsetSize);
        }
        return globalEntropy - expectedEntropy;
    }

    /** Counts how many of {@code instances} belong to each class. */
    private Map<Short, Long> calculateFrequencyOfClassValues(List<Instance> instances) {
        Map<Short, Long> frequencies = new HashMap<>();
        for (Instance instance : instances) {
            frequencies.merge(instance.getKlass(), 1L, Long::sum);
        }
        return frequencies;
    }

    /** Prints the wall-clock time taken by the most recent call to {@link #train}. */
    public void printStatistics() {
        System.out.println("Time to construct decision tree = " + (this.endTime - this.startTime) + " ms");
        System.out.println();
    }

    /**
     * Builds an ID3 decision tree over all attributes of {@code dataset} and records its build time.
     *
     * @param dataset the training data
     * @return a {@link ClassifierID3} wrapping the constructed tree
     */
    @Override
    protected Classifier train(Dataset dataset) {
        this.startTime = System.currentTimeMillis();
        DecisionTree decisionTree = new DecisionTree(dataset.getMapItemToString(), dataset.getAttributes());

        int attributeCount = dataset.getAttributes().size();
        int[] allAttributes = new int[attributeCount];
        for (int index = 0; index < attributeCount; index++) {
            allAttributes[index] = index;
        }

        this.targetAttributeValues = dataset.getListOfClassValues();
        decisionTree.root = id3(allAttributes, dataset.getInstances(), dataset.getMapClassToFrequency());
        this.endTime = System.currentTimeMillis();
        return new ClassifierID3(decisionTree);
    }

    /** Returns the display name of this algorithm. */
    @Override
    public String getName() {
        return "ID3";
    }
}
