/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.rcf;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnType;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.algorithms.rcf.RCFModelSerDeSer;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.utils.ModelSerDeSer;

/**
 * Fixed-in-time Random Cut Forest (FIT RCF) anomaly detection over a {@link DataFrame}.
 *
 * <p>Each row of the input frame is split into one timestamp column (identified by the
 * configured time field name) and feature columns. The features and the timestamp are fed to a
 * {@link ThresholdedRandomCutForest}; the forest's RCF score and anomaly grade are reported
 * per row.
 *
 * <p>Instances hold a mutable forest reference and a {@link SimpleDateFormat}; nothing in this
 * class synchronizes access to either.
 */
@Function(value=FunctionName.FIT_RCF)
public class FixedInTimeRandomCutForest implements TrainAndPredictable {
    @Generated
    private static final Logger log = LogManager.getLogger(FixedInTimeRandomCutForest.class);
    public static final String VERSION = "1.0.0";
    private static final int DEFAULT_NUMBER_OF_TREES = 30;
    private static final int DEFAULT_SHINGLE_SIZE = 8;
    private static final int DEFAULT_OUTPUT_AFTER = 32;
    private static final int DEFAULT_SAMPLES_SIZE = 256;
    private static final double DEFAULT_TIME_DECAY = 1.0E-4;
    private static final double DEFAULT_ANOMALY_RATE = 0.005;
    private static final String DEFAULT_TIME_FIELD = "timestamp";
    private static final String DEFAULT_DATE_FORMAT = "yyyy-MM-dd HH:mm:ss";
    private static final String DEFAULT_TIME_ZONE = "UTC";
    private static final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper();

    private final int numberOfTrees;
    private final int shingleSize;
    private final int sampleSize;
    private final int outputAfter;
    private final double timeDecay;
    private final double anomalyRate;
    private final String timeField;
    private final String dateFormat;
    private final String timeZone;
    private final DateFormat simpleDateFormat;
    private ThresholdedRandomCutForest forest;

    /**
     * Creates an instance configured entirely with the default settings.
     *
     * <p>Delegates to {@link #FixedInTimeRandomCutForest(MLAlgoParams)} so that the time field,
     * date format and forest settings are never left unset.
     */
    public FixedInTimeRandomCutForest() {
        this(null);
    }

    /**
     * Creates an instance from the given parameters, substituting the class defaults for every
     * setting the parameters leave unset.
     *
     * @param parameters FIT RCF parameters; {@code null} means "all defaults". A non-null value
     *                   is cast to {@link FitRCFParams}.
     */
    public FixedInTimeRandomCutForest(MLAlgoParams parameters) {
        FitRCFParams rcfParams = parameters == null ? FitRCFParams.builder().build() : (FitRCFParams) parameters;
        this.numberOfTrees = Optional.ofNullable(rcfParams.getNumberOfTrees()).orElse(DEFAULT_NUMBER_OF_TREES);
        this.shingleSize = Optional.ofNullable(rcfParams.getShingleSize()).orElse(DEFAULT_SHINGLE_SIZE);
        this.sampleSize = Optional.ofNullable(rcfParams.getSampleSize()).orElse(DEFAULT_SAMPLES_SIZE);
        this.outputAfter = Optional.ofNullable(rcfParams.getOutputAfter()).orElse(DEFAULT_OUTPUT_AFTER);
        this.timeDecay = Optional.ofNullable(rcfParams.getTimeDecay()).orElse(DEFAULT_TIME_DECAY);
        this.anomalyRate = Optional.ofNullable(rcfParams.getAnomalyRate()).orElse(DEFAULT_ANOMALY_RATE);
        this.timeField = Optional.ofNullable(rcfParams.getTimeField()).orElse(DEFAULT_TIME_FIELD);
        this.dateFormat = Optional.ofNullable(rcfParams.getDateFormat()).orElse(DEFAULT_DATE_FORMAT);
        this.timeZone = Optional.ofNullable(rcfParams.getTimeZone()).orElse(DEFAULT_TIME_ZONE);
        this.simpleDateFormat = newDateFormat(this.dateFormat, this.timeZone);
    }

    /**
     * Restores the forest from a serialized model so that {@link #predict(MLInput)} can be used.
     *
     * @param model  model whose content holds the serialized forest state
     * @param params not used by this implementation
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params) {
        this.forest = restoreForest(model);
    }

    /** Drops the reference to the currently loaded forest. */
    @Override
    public void close() {
        this.forest = null;
    }

    /**
     * Scores every row of the input data frame with the currently loaded forest.
     *
     * @param mlInput input carrying a {@link DataFrameInputDataset} and optional parameters
     * @return prediction output with one row per input row (time field, "score", "anomaly_grade")
     * @throws IllegalArgumentException if no forest has been loaded
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        if (this.forest == null) {
            // Fail with a descriptive error instead of a NullPointerException deep inside process().
            throw new IllegalArgumentException("No model found for FIT RCF prediction.");
        }
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        List<Map<String, Object>> predictResult = process(dataFrame, this.forest, mlInput.getParameters());
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
    }

    /**
     * Loads the forest from {@code model} and then scores the input with it.
     *
     * @param mlInput input carrying a {@link DataFrameInputDataset} and optional parameters
     * @param model   serialized model to load; must not be {@code null}
     * @return prediction output, see {@link #predict(MLInput)}
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for FIT RCF prediction.");
        }
        this.forest = restoreForest(model);
        return predict(mlInput);
    }

    /**
     * Builds a new forest, feeds it every row of the input data frame and serializes the result.
     *
     * @param mlInput input carrying a {@link DataFrameInputDataset} and optional parameters
     * @return a model in state {@link MLModelState#TRAINED} whose content is the Base64-encoded
     *         serialized forest state
     */
    @Override
    public MLModel train(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        ThresholdedRandomCutForest trainedForest = createThresholdedRandomCutForest(dataFrame);
        // The per-row results are discarded; the call is made to update the forest.
        process(dataFrame, trainedForest, mlInput.getParameters());
        ThresholdedRandomCutForestState state = trcfMapper.toState(trainedForest);
        return MLModel.builder()
            .name(FunctionName.FIT_RCF.name())
            .algorithm(FunctionName.FIT_RCF)
            .version(VERSION)
            .content(ModelSerDeSer.encodeBase64(RCFModelSerDeSer.serializeTRCF(state)))
            .modelState(MLModelState.TRAINED)
            .build();
    }

    /**
     * Builds a new forest and returns the per-row results produced while feeding it the input.
     *
     * <p>The time field and date format configured on this instance are used; the forest is not
     * stored on the instance.
     *
     * @param mlInput input carrying a {@link DataFrameInputDataset}
     * @return prediction output with one row per input row (time field, "score", "anomaly_grade")
     */
    @Override
    public MLOutput trainAndPredict(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        ThresholdedRandomCutForest freshForest = createThresholdedRandomCutForest(dataFrame);
        List<Map<String, Object>> predictResult = process(dataFrame, freshForest, null);
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
    }

    /**
     * Feeds every row of {@code dataFrame} to {@code forest} and collects the per-row outcome.
     *
     * <p>Columns named like the time field supply the row's timestamp; all other columns are
     * read as doubles, in column order, and form the point passed to the forest. A row without
     * a time column is processed with timestamp {@code -1}.
     *
     * @param dataFrame  rows to process
     * @param forest     forest that receives each point
     * @param parameters optional per-call settings; when non-null they replace the instance's
     *                   time field and date format
     * @return one map per row with the time field name, "score" and "anomaly_grade" as keys
     * @throws MLValidationException if the time column is neither LONG nor STRING, or a STRING
     *                               timestamp cannot be parsed
     */
    private List<Map<String, Object>> process(DataFrame dataFrame, ThresholdedRandomCutForest forest, MLAlgoParams parameters) {
        String timeField = this.timeField;
        DateFormat dateFormat = this.simpleDateFormat;
        if (parameters != null) {
            // NOTE(review): settings missing from the per-call parameters fall back to the
            // library defaults, not to this instance's configuration - confirm this is intended.
            FitRCFParams rcfParams = (FitRCFParams) parameters;
            timeField = Optional.ofNullable(rcfParams.getTimeField()).orElse(DEFAULT_TIME_FIELD);
            dateFormat = newDateFormat(
                Optional.ofNullable(rcfParams.getDateFormat()).orElse(DEFAULT_DATE_FORMAT),
                Optional.ofNullable(rcfParams.getTimeZone()).orElse(DEFAULT_TIME_ZONE));
        }

        ColumnMeta[] columnMetas = dataFrame.columnMetas();
        // The number of feature columns is the same for every row, so size the point once.
        int featureCount = 0;
        for (ColumnMeta columnMeta : columnMetas) {
            if (!timeField.equals(columnMeta.getName())) {
                featureCount++;
            }
        }

        List<Map<String, Object>> predictResult = new ArrayList<>();
        for (int rowNum = 0; rowNum < dataFrame.size(); ++rowNum) {
            Row row = dataFrame.getRow(rowNum);
            long timestamp = -1L;
            double[] point = new double[featureCount];
            int featureIndex = 0;
            for (int i = 0; i < columnMetas.length; ++i) {
                ColumnMeta columnMeta = columnMetas[i];
                ColumnValue value = row.getValue(i);
                if (timeField.equals(columnMeta.getName())) {
                    timestamp = readTimestamp(columnMeta.getColumnType(), value, dateFormat);
                } else {
                    point[featureIndex++] = value.doubleValue();
                }
            }
            AnomalyDescriptor descriptor = forest.process(point, timestamp);
            Map<String, Object> result = new HashMap<>();
            result.put(timeField, timestamp);
            result.put("score", descriptor.getRCFScore());
            result.put("anomaly_grade", descriptor.getAnomalyGrade());
            predictResult.add(result);
        }
        return predictResult;
    }

    /**
     * Converts the value of a time column to a {@code long} timestamp.
     *
     * @param columnType declared type of the time column
     * @param value      cell value of the time column
     * @param dateFormat format used to parse STRING values
     * @return the LONG value as is, or the parsed date's {@code getTime()} for STRING values
     * @throws MLValidationException if the column type is neither LONG nor STRING, or parsing fails
     */
    private static long readTimestamp(ColumnType columnType, ColumnValue value, DateFormat dateFormat) {
        if (columnType == ColumnType.LONG) {
            return value.longValue();
        }
        if (columnType == ColumnType.STRING) {
            String text = value.stringValue();
            try {
                return dateFormat.parse(text).getTime();
            } catch (ParseException e) {
                // The parse failure is logged with its stack trace; the thrown exception
                // carries only the message.
                log.error("Failed to parse timestamp {}", text, e);
                throw new MLValidationException("Failed to parse timestamp " + text);
            }
        }
        throw new MLValidationException("Wrong data type of time field. Should use LONG or STRING, but got " + columnType);
    }

    /** Creates a date format for {@code pattern} that interprets dates in {@code zoneId}. */
    private static DateFormat newDateFormat(String pattern, String zoneId) {
        DateFormat format = new SimpleDateFormat(pattern);
        format.setTimeZone(TimeZone.getTimeZone(zoneId));
        return format;
    }

    /** Deserializes the forest state stored in {@code model} and maps it back to a forest. */
    private static ThresholdedRandomCutForest restoreForest(MLModel model) {
        ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model);
        return trcfMapper.toModel(state);
    }

    /**
     * Builds an empty forest from this instance's settings.
     *
     * <p>The forest dimension is {@code shingleSize * (columnCount - 1)}, i.e. the frame is
     * expected to contain exactly one non-feature (time) column.
     *
     * @param dataFrame frame whose column count determines the forest dimension
     * @return a newly built forest
     */
    private ThresholdedRandomCutForest createThresholdedRandomCutForest(DataFrame dataFrame) {
        int featureColumns = dataFrame.columnMetas().length - 1;
        return ThresholdedRandomCutForest.builder()
            .dimensions(this.shingleSize * featureColumns)
            .sampleSize(this.sampleSize)
            .numberOfTrees(this.numberOfTrees)
            .timeDecay(this.timeDecay)
            .outputAfter(this.outputAfter)
            .initialAcceptFraction((double) this.outputAfter * 1.0 / (double) this.sampleSize)
            .parallelExecutionEnabled(false)
            .compact(true)
            .precision(Precision.FLOAT_32)
            .boundingBoxCacheFraction(1.0)
            .shingleSize(this.shingleSize)
            .internalShinglingEnabled(true)
            .anomalyRate(this.anomalyRate)
            .forestMode(ForestMode.STANDARD)
            .build();
    }
}

