/*
 * Decompiled with CFR 0.152.
 */
package ch.wsl.fps.knn.model;

import ch.wsl.fps.knn.gui.KnnExceptionHandler;
import ch.wsl.fps.knn.model.DataCell;
import ch.wsl.fps.knn.model.DataColumn;
import ch.wsl.fps.knn.model.DataRow;
import ch.wsl.fps.knn.model.TargetDataColumn;
import ch.wsl.fps.knn.model.TargetDataRow;
import java.util.Collections;
import java.util.List;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;

/**
 * Static helpers for the k-nearest-neighbour (kNN) estimation.
 *
 * <p>Covers per-column statistics, the correlation of each column with the target column,
 * correlation-weighted distances between rows, and the distance-weighted estimate of the
 * target value.
 */
public class Calculator {
    /**
     * Pre-computes mean and standard deviation for every column.
     *
     * <p>Columns whose data type is not available for calculation get {@code null} for both
     * values.
     *
     * @param allDataCols the columns to update in place
     */
    public static void preCalcAvgAndStdDev(List<DataColumn> allDataCols) {
        for (DataColumn dataCol : allDataCols) {
            if (dataCol.getDataType().isAvailableForCalculation()) {
                Calculator.calcAvgAndStandardDeviation(dataCol);
            } else {
                dataCol.setAvg(null);
                dataCol.setStdAbw(null);
            }
        }
    }

    /**
     * Pre-computes the Pearson correlation coefficient between every column and the target column.
     *
     * <p>If {@code targetCol} is {@code null} or not available for calculation, every column's
     * coefficient is set to {@code null}. Otherwise calculable columns get their coefficient, and
     * the remaining columns get {@code null}.
     *
     * @param allDataCols the columns to update in place
     * @param targetCol   the target column to correlate against; may be {@code null}
     */
    public static void preCalcCorrelationCoefficient(List<DataColumn> allDataCols, TargetDataColumn targetCol) {
        boolean targetUsable = targetCol != null && targetCol.getDataType().isAvailableForCalculation();
        if (!targetUsable) {
            for (DataColumn dataCol : allDataCols) {
                dataCol.setKorrelationsKoeff(null);
            }
            return;
        }
        // The target values are identical for every column, so extract them only once.
        double[] targetVarArray = targetCol.getCells().stream().mapToDouble(DataCell::getCalcValue).toArray();
        for (DataColumn dataCol : allDataCols) {
            if (dataCol.getDataType().isAvailableForCalculation()) {
                Calculator.calcCorrelationCoefficient(dataCol, targetVarArray);
            } else {
                dataCol.setKorrelationsKoeff(null);
            }
        }
    }

    /**
     * Runs the kNN estimation for one target row.
     *
     * <p>Computes the distance of every row to {@code targetRow} and sorts {@code allDataRows} in
     * place by ascending distance. It then assigns 1-based ranks in that order. Finally it writes
     * the weighted estimate from the {@code k} nearest rows into the target cell of
     * {@code targetRow}.
     *
     * @param allDataCols all columns; only selected ones contribute to the distance
     * @param allDataRows reference rows; sorted in place and ranked
     * @param targetCol   the column whose value is estimated
     * @param targetRow   the row receiving the estimate
     * @param k           number of nearest rows used for the estimate
     */
    public static void calc(List<DataColumn> allDataCols, List<DataRow> allDataRows, TargetDataColumn targetCol, TargetDataRow targetRow, int k) {
        for (DataRow dataRow : allDataRows) {
            Calculator.calcDistanceForEntry(allDataCols, dataRow, targetRow);
        }
        // Double.compare is a total order (NaN sorts last). Truncating the signum of a difference
        // maps NaN to 0, which violates the Comparator contract and can make the sort throw.
        Collections.sort(allDataRows, (e1, e2) -> Double.compare(e1.getDistance(), e2.getDistance()));
        // Rank = sorted position. A running counter replaces indexOf(), which is O(n) per row
        // and would give rows that are equal under equals() the same rank.
        int rank = 1;
        for (DataRow dataRow : allDataRows) {
            dataRow.setRank(rank++);
        }
        Calculator.calcTargetValueEstimation(k, allDataRows, targetCol, targetRow);
    }

    /**
     * Computes mean and (sample) standard deviation of the column's calc values and stores them
     * on the column.
     */
    private static void calcAvgAndStandardDeviation(DataColumn dataCol) {
        SummaryStatistics stats = new SummaryStatistics();
        for (DataCell cell : dataCol.getCells()) {
            stats.addValue(cell.getCalcValue());
        }
        dataCol.setAvg(stats.getMean());
        dataCol.setStdAbw(stats.getStandardDeviation());
    }

    /**
     * Computes the Pearson correlation between the column's calc values and the given target
     * values, and stores it on the column.
     *
     * <p>If the computation throws (for instance on arrays of different length), the exception
     * is logged and 0.0 is stored.
     */
    private static void calcCorrelationCoefficient(DataColumn dataCol, double[] targetVarArray) {
        PearsonsCorrelation correlation = new PearsonsCorrelation();
        double[] refVarArray = dataCol.getCells().stream().mapToDouble(DataCell::getCalcValue).toArray();
        double korrelationsKoeff = 0.0;
        try {
            korrelationsKoeff = correlation.correlation(refVarArray, targetVarArray);
        } catch (Exception e) {
            // Deliberate best effort: log and fall back to 0.0 (column contributes nothing).
            KnnExceptionHandler.logReduced(e);
        }
        dataCol.setKorrelationsKoeff(korrelationsKoeff);
    }

    /**
     * Computes and stores the weighted Euclidean distance between {@code dataRow} and
     * {@code targetRow}.
     *
     * <p>The distance is sqrt(sum over selected columns of (r^2 / sd^2) * (xRef - xTarget)^2).
     * Here r is the column's correlation coefficient and sd its standard deviation.
     */
    private static void calcDistanceForEntry(List<DataColumn> allDataCols, DataRow dataRow, TargetDataRow targetRow) {
        double sum = 0.0;
        for (DataColumn dataCol : allDataCols) {
            if (!dataCol.isSelected()) {
                continue;
            }
            double korrelKoeff = dataCol.getKorrelationsKoeff();
            double stdAbw = dataCol.getStdAbw();
            double weight = (korrelKoeff * korrelKoeff) / (stdAbw * stdAbw);
            // A column with zero standard deviation (e.g. a constant column) gives a NaN or
            // infinite weight. Including it would turn every distance into NaN, so it is skipped.
            if (!Double.isFinite(weight)) {
                continue;
            }
            double diff = dataRow.getCell(dataCol).getCalcValue() - targetRow.getCell(dataCol).getCalcValue();
            sum += weight * (diff * diff);
        }
        dataRow.setDistance(Math.sqrt(sum));
    }

    /**
     * Estimates the target value as the weighted mean of the target values of the first
     * {@code k} rows. Each row's weight is 1 / (1 + distance).
     *
     * <p>Expects {@code allDataRows} to be sorted by ascending distance. If {@code k <= 0} or
     * there are no rows, the estimate is 0/0 = NaN.
     */
    private static void calcTargetValueEstimation(int k, List<DataRow> allDataRows, TargetDataColumn targetCol, TargetDataRow targetRow) {
        String targetTitle = targetCol.getTitle();
        int neighbourCount = Math.min(k, allDataRows.size());
        double sumOfWeights = 0.0;
        double sumOfWeightedTargetValues = 0.0;
        for (int i = 0; i < neighbourCount; i++) {
            DataRow dataRow = allDataRows.get(i);
            double weight = 1.0 / (1.0 + dataRow.getDistance());
            sumOfWeights += weight;
            sumOfWeightedTargetValues += dataRow.getCell(targetTitle).getCalcValue() * weight;
        }
        double avg = sumOfWeightedTargetValues / sumOfWeights;
        targetRow.getCell(targetTitle).setOriginalValue(avg);
    }
}

