package ca.pfv.spmf.algorithms.classifiers.naive_bayes_text_classifier;

import ca.pfv.spmf.tools.MemoryLogger;
import ca.pfv.spmf.tools.textprocessing.PorterStemmer;
import ca.pfv.spmf.tools.textprocessing.StopWordAnalyzer;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/* loaded from: input_file:ca/pfv/spmf/algorithms/classifiers/naive_bayes_text_classifier/AlgoNaiveBayesClassifier.class */
/**
 * Naive Bayes style text classifier working on directories of plain-text documents.
 * <p>
 * Every sub-directory of the training directory is treated as one class; its name is the class
 * name and the files it contains are that class's training documents. Each regular file found
 * directly in the test directory is scored against every class, and the best class is written as
 * a {@code fileName<TAB>className} line to {@code output.tsv} in the output directory (the same
 * line is echoed to standard output).
 * <p>
 * All per-run state is kept in mutable fields and no synchronization is performed, so an instance
 * must not be used from several threads at once. The same instance may be reused for several
 * runs one after another.
 */
public class AlgoNaiveBayesClassifier {
    /**
     * Weight given to each occurrence of a word when counting from the in-memory corpus.
     * NOTE(review): the on-disk path uses {@link #ON_DISK_OCCURRENCE_WEIGHT} (20) instead, so the
     * two modes do not produce the same probabilities for the same data. Both values are kept
     * as found; confirm which one is intended before unifying them.
     */
    private static final int IN_MEMORY_OCCURRENCE_WEIGHT = 50;
    /** Weight given to each occurrence of a word when re-reading the training files from disk. */
    private static final double ON_DISK_OCCURRENCE_WEIGHT = 20.0d;
    /** Constant added to the weighted occurrence count so that unseen words never score zero. */
    private static final double SMOOTHING_NUMERATOR = 50.0d;
    /** Constant added to the class word count in the denominator. */
    private static final double SMOOTHING_DENOMINATOR = 100.0d;

    /** Names of the classes (training sub-directories) of the current run. */
    private ArrayList<String> mClassNames;
    private StopWordAnalyzer mAnalyzer;
    private PorterStemmer mStemmer;
    /** Number of training files per class name. */
    HashMap<String, Integer> classProb;
    private String mTestDataDirectory = "";
    private String mTrainingDataDirectory = "";
    /** When true the stemmed training corpus is loaded once and kept in {@link #mMemFiles}. */
    private boolean mInMemoryFlag = false;
    /** Training files per class name. */
    private HashMap<String, List<File>> mFileLists = new HashMap<>();
    private String mOutputDirectory = "";
    /** Stemmed training words per class; only filled when {@link #mInMemoryFlag} is set. */
    private ArrayList<MemoryFile> mMemFiles = new ArrayList<>();
    long mStartTimestamp = 0;
    long mEndTimeStamp = 0;

    /**
     * Trains on the training directory and classifies every file of the test directory.
     *
     * @param str  path of the training directory (one sub-directory per class)
     * @param str2 path of the directory holding the files to classify
     * @param str3 path of the directory in which {@code output.tsv} is written
     * @param z    {@code true} to load the whole stemmed training corpus into memory first,
     *             {@code false} to re-read the training files from disk whenever a word
     *             probability is needed
     * @throws Exception if a directory cannot be listed, a file cannot be read or written, or no
     *                   class is available to classify a test file
     */
    public void runAlgorithm(String str, String str2, String str3, boolean z) throws Exception {
        this.mTrainingDataDirectory = str;
        this.mTestDataDirectory = str2;
        this.mOutputDirectory = str3;
        this.mInMemoryFlag = z;
        runAlgorithm();
    }

    /**
     * Performs one complete run using the directories and mode stored in the fields.
     *
     * @throws Exception see {@link #runAlgorithm(String, String, String, boolean)}
     */
    private void runAlgorithm() throws Exception {
        this.mStartTimestamp = System.currentTimeMillis();
        this.mAnalyzer = new StopWordAnalyzer();
        this.mStemmer = new PorterStemmer();
        this.classProb = new HashMap<>();
        this.mClassNames = new ArrayList<>();
        // Drop whatever a previous run on this instance left behind; otherwise the in-memory
        // corpus would be appended a second time and every word would be counted twice.
        this.mFileLists.clear();
        this.mMemFiles.clear();

        File outputFile = new File(this.mOutputDirectory + "/output.tsv");
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile))) {
            File[] testFiles = listDirectory(this.mTestDataDirectory, "test");
            File[] classDirectories = listDirectory(this.mTrainingDataDirectory, "training");

            ArrayList<OccurrenceProbabilties> probabilityCache = new ArrayList<>();
            for (File classDirectory : classDirectories) {
                File[] trainingFiles = classDirectory.listFiles();
                if (trainingFiles == null) {
                    // Not a directory (or unreadable): it cannot describe a class, skip it.
                    continue;
                }
                String className = classDirectory.getName();
                this.mClassNames.add(className);
                OccurrenceProbabilties classCache = new OccurrenceProbabilties();
                classCache.setClassName(className);
                classCache.setOccuranceMap(new HashMap<>());
                probabilityCache.add(classCache);
                this.mFileLists.put(className, Arrays.asList(trainingFiles));
                this.classProb.put(className, Integer.valueOf(trainingFiles.length));
            }
            int classCount = this.mClassNames.size();

            if (this.mInMemoryFlag) {
                System.out.println("Loading Data in to memory.... May take a while depending upon the size of the data");
                loadIntoMemory();
            }

            for (File testFile : testFiles) {
                if (!testFile.isFile()) {
                    // Sub-directories of the test directory cannot be read as documents.
                    continue;
                }
                System.out.println("---------------Computing for Test File:" + testFile.getName() + "-----------");
                String predictedClass = classify(testFile, probabilityCache, classCount);
                System.out.println(testFile.getName() + "\t" + predictedClass);
                writer.write(testFile.getName() + "\t" + predictedClass + "\n");
            }
        }
        this.mEndTimeStamp = System.currentTimeMillis();
    }

    /**
     * Lists the entries of a directory, failing with a descriptive error instead of a
     * {@link NullPointerException} when the path is not a readable directory.
     */
    private static File[] listDirectory(String path, String role) throws IOException {
        File[] entries = new File(path).listFiles();
        if (entries == null) {
            throw new IOException("Cannot list " + role + " directory: " + path);
        }
        return entries;
    }

    /**
     * Scores one test file against every class and returns the name of the best class.
     * Ties are resolved in favour of the class name that sorts first.
     */
    private String classify(File testFile, ArrayList<OccurrenceProbabilties> probabilityCache, int classCount) throws Exception {
        // The words of the test file do not depend on the class, so read the file only once.
        TestRecord record = readOneTestFile(testFile);
        TreeMap<String, BigDecimal> scores = new TreeMap<>();
        for (String className : this.mClassNames) {
            BigDecimal score = new BigDecimal("1.0");
            for (String word : record.getWords()) {
                double probability = getFromExistingProbability(word, probabilityCache, className);
                if (probability == 0.0d) {
                    // 0.0 means "not cached yet": computed probabilities are always positive
                    // because of the smoothing constants.
                    probability = this.mInMemoryFlag
                            ? calculateProbabilityInMemory(word, probabilityCache, className)
                            : calculateProbability(word, probabilityCache, className);
                    rememberProbability(probabilityCache, className, word, probability);
                }
                score = score.multiply(new BigDecimal(probability));
            }
            // Floating-point division: an int / int division here truncates the prior to 0 (or
            // to a whole number). The denominator is the same for every class, so it only
            // scales the scores and does not influence which class wins.
            double prior = (double) this.classProb.get(className).intValue() / classCount;
            scores.put(className, score.multiply(new BigDecimal(prior)));
        }

        Map.Entry<String, BigDecimal> best = null;
        for (Map.Entry<String, BigDecimal> candidate : scores.entrySet()) {
            if (best == null || candidate.getValue().compareTo(best.getValue()) > 0) {
                best = candidate;
            }
        }
        if (best == null) {
            throw new IllegalStateException("No training classes found in: " + this.mTrainingDataDirectory);
        }
        return best.getKey();
    }

    /** Stores a freshly computed word probability in the cache entry of the given class. */
    private static void rememberProbability(ArrayList<OccurrenceProbabilties> probabilityCache, String className, String word, double probability) {
        for (OccurrenceProbabilties classCache : probabilityCache) {
            // Exact match, like the lookup in getFromExistingProbability, so that classes whose
            // names differ only by case never write into each other's cache.
            if (classCache.getClassName().equals(className)) {
                classCache.getOccuranceMap().put(word, Double.valueOf(probability));
                return;
            }
        }
    }

    /**
     * Normalises one line of a training document (non-letters to spaces, lower case, collapsed
     * white space), removes stop words and splits the remainder into tokens.
     */
    private String[] tokenizeTrainingLine(String line) {
        String normalized = line.replaceAll("\\P{L}", " ").toLowerCase().replaceAll("\n", " ").replaceAll("\\s+", " ");
        return this.mAnalyzer.removeStopWords(normalized).split("\\s+");
    }

    /**
     * Reads all training files of every class, stems their words and keeps the stems that are
     * longer than one character in {@link #mMemFiles} (one entry per class).
     *
     * @throws IOException if a training file cannot be read
     */
    private void loadIntoMemory() throws IOException {
        for (String className : this.mClassNames) {
            MemoryFile memoryFile = new MemoryFile();
            memoryFile.setClassname(className);
            ArrayList<String> stems = new ArrayList<>();
            for (File trainingFile : this.mFileLists.get(className)) {
                try (BufferedReader reader = new BufferedReader(new FileReader(trainingFile))) {
                    String line;
                    while ((line = reader.readLine()) != null) {
                        for (String token : tokenizeTrainingLine(line)) {
                            String stem = this.mStemmer.stem(token);
                            if (stem.length() > 1) {
                                stems.add(stem);
                            }
                        }
                    }
                }
            }
            memoryFile.setContent(stems);
            this.mMemFiles.add(memoryFile);
        }
    }

    /**
     * Computes the smoothed, weighted relative frequency of a word in a class from the corpus
     * held in memory.
     *
     * @param str       the (stemmed) word
     * @param arrayList the probability cache; not used by this method
     * @param str2      the class name
     * @return {@code (50 * occurrences + 50) / (wordsInClass + 100)}
     */
    private double calculateProbabilityInMemory(String str, ArrayList<OccurrenceProbabilties> arrayList, String str2) {
        // long accumulators so that very large corpora cannot overflow the weighted count.
        long weightedOccurrences = 0L;
        long classWordCount = 0L;
        for (MemoryFile memoryFile : this.mMemFiles) {
            if (memoryFile.getClassname().equals(str2)) {
                weightedOccurrences += (long) Collections.frequency(memoryFile.getContent(), str) * IN_MEMORY_OCCURRENCE_WEIGHT;
                classWordCount += memoryFile.getContent().size();
            }
        }
        return (weightedOccurrences + SMOOTHING_NUMERATOR) / (classWordCount + SMOOTHING_DENOMINATOR);
    }

    /**
     * Computes the smoothed, weighted relative frequency of a word in a class by re-reading the
     * class's training files from disk.
     *
     * @param str       the (stemmed) word
     * @param arrayList the probability cache; not used by this method
     * @param str2      the class name
     * @return {@code (20 * occurrences + 50) / (wordsInClass + 100)}, where only stems longer
     *         than one character count as words
     * @throws Exception if a training file cannot be read
     */
    private double calculateProbability(String str, ArrayList<OccurrenceProbabilties> arrayList, String str2) throws Exception {
        double weightedOccurrences = 0.0d;
        // Only the number of kept stems is needed, so count them instead of storing them.
        int classWordCount = 0;
        for (File trainingFile : this.mFileLists.get(str2)) {
            try (BufferedReader reader = new BufferedReader(new FileReader(trainingFile))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    for (String token : tokenizeTrainingLine(line)) {
                        String stem = this.mStemmer.stem(token);
                        if (stem.length() > 1) {
                            classWordCount++;
                        }
                        if (stem.equalsIgnoreCase(str)) {
                            weightedOccurrences += ON_DISK_OCCURRENCE_WEIGHT;
                        }
                    }
                }
            }
        }
        return (weightedOccurrences + SMOOTHING_NUMERATOR) / (classWordCount + SMOOTHING_DENOMINATOR);
    }

    /**
     * Looks up a previously computed probability.
     *
     * @param str       the (stemmed) word
     * @param arrayList the per-class probability caches
     * @param str2      the class name (matched exactly)
     * @return the cached probability, or {@code 0.0} when the class or the word is not cached
     */
    public double getFromExistingProbability(String str, ArrayList<OccurrenceProbabilties> arrayList, String str2) {
        double probability = 0.0d;
        for (OccurrenceProbabilties classCache : arrayList) {
            if (classCache.getClassName().equals(str2)) {
                // Direct map lookup instead of scanning every key of the map.
                Double cached = classCache.getOccuranceMap().get(str);
                if (cached != null) {
                    probability = cached.doubleValue();
                }
            }
        }
        return probability;
    }

    /**
     * Reads a test document and returns its stemmed words (stop words removed, stems of one
     * character or less dropped). The record id is the number formed by all digits of the file
     * name.
     * <p>
     * Uses the stop-word analyzer and stemmer that are created when a run starts, so it can only
     * be called once {@link #runAlgorithm(String, String, String, boolean)} has been invoked.
     *
     * @param file the document to read
     * @return the record holding the id and the words of the document
     * @throws Exception if the file cannot be read; a {@link NumberFormatException} is thrown
     *                   when the file name contains no digits or its digits do not fit an int
     */
    public TestRecord readOneTestFile(File file) throws Exception {
        TestRecord testRecord = new TestRecord();
        ArrayList<String> words = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String normalized = line.toLowerCase().replaceAll("\\P{L}", " ").replaceAll("\n", " ").replaceAll("\\s+", " ").trim();
                for (String token : this.mAnalyzer.removeStopWords(normalized).split("\\s+")) {
                    String stem = this.mStemmer.stem(token);
                    if (stem.length() > 1) {
                        words.add(stem);
                    }
                }
            }
        }
        testRecord.setRecordId(Integer.parseInt(file.getName().replaceAll("\\D+", "")));
        testRecord.setWords(words);
        return testRecord;
    }

    /**
     * Prints the duration of the last run and the maximum memory reported by the memory logger.
     * NOTE(review): nothing in this class resets or updates the memory logger, so the memory
     * figure reflects whatever other code recorded there - confirm whether a memory check
     * should be added to the run.
     */
    public void printStatistics() {
        System.out.println("========== Naive Bayes Classifier Stats ============");
        System.out.println(" Total time ~: " + (this.mEndTimeStamp - this.mStartTimestamp) + " ms");
        System.out.println(" Max memory:" + MemoryLogger.getInstance().getMaxMemory() + " mb ");
        System.out.println("=====================================");
    }
}
