/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.unload;

import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.unload.UnloadModelNodeRequest;
import org.opensearch.ml.common.transport.unload.UnloadModelNodeResponse;
import org.opensearch.ml.common.transport.unload.UnloadModelNodesRequest;
import org.opensearch.ml.common.transport.unload.UnloadModelNodesResponse;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/**
 * Nodes-level transport action that unloads the requested models on each target node.
 *
 * <p>Each node unloads the model ids locally via {@link MLModelManager#unloadModel} and reports a
 * per-model status string. When the per-node responses are aggregated in {@link #newResponse}:
 * <ul>
 *   <li>every model reported as {@code "unloaded"} or {@code "not_found"} by some node is recorded
 *       together with the ids of those nodes;</li>
 *   <li>{@code "not_found"} entries are stripped from the node responses that are returned;</li>
 *   <li>the recorded models get {@code model_state = UNLOADED} written to the
 *       {@code .plugins-ml-model} index (best effort, failures are only logged);</li>
 *   <li>an {@link MLSyncUpAction} request carrying the removed worker nodes per model is sent to
 *       all nodes returned by {@link DiscoveryNodeHelper#getAllNodes()}.</li>
 * </ul>
 */
public class TransportUnloadModelAction
extends TransportNodesAction<UnloadModelNodesRequest, UnloadModelNodesResponse, UnloadModelNodeRequest, UnloadModelNodeResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportUnloadModelAction.class);
    /** Per-node status reported for a model that was unloaded on that node. */
    private static final String UNLOADED_STATUS = "unloaded";
    /** Per-node status reported for a model id the node did not find (presumably not loaded there). */
    private static final String NOT_FOUND_STATUS = "not_found";
    /** System index holding model documents. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";
    /** Field of a model document that stores its {@link MLModelState}. */
    private static final String MODEL_STATE_FIELD = "model_state";

    private final MLModelManager mlModelManager;
    private final ClusterService clusterService;
    private final Client client;
    private final DiscoveryNodeHelper nodeFilter;
    private final MLStats mlStats;

    @Inject
    public TransportUnloadModelAction(TransportService transportService, ActionFilters actionFilters, MLModelManager mlModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, DiscoveryNodeHelper nodeFilter, MLStats mlStats) {
        super("cluster:admin/opensearch/ml/unload_model", threadPool, clusterService, transportService, actionFilters, UnloadModelNodesRequest::new, UnloadModelNodeRequest::new, "management", UnloadModelNodeResponse.class);
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.client = client;
        this.nodeFilter = nodeFilter;
        this.mlStats = mlStats;
    }

    /**
     * Aggregates the per-node unload results and kicks off the follow-up bookkeeping.
     *
     * <p>Side effects (only when {@code responses} is non-null): {@code "not_found"} entries are
     * removed from each node response's status map, the affected models are marked
     * {@code UNLOADED} in the model index, and a sync-up request is sent to all nodes. Both
     * follow-up requests are asynchronous; their failures are logged and never affect the
     * returned response.
     *
     * @param nodesRequest the original request (unused here)
     * @param responses    successful per-node responses; may be {@code null}
     * @param failures     nodes that failed to respond
     * @return the aggregated response built from {@code responses} and {@code failures}
     */
    @Override
    protected UnloadModelNodesResponse newResponse(UnloadModelNodesRequest nodesRequest, List<UnloadModelNodeResponse> responses, List<FailedNodeException> failures) {
        if (responses != null) {
            Map<String, List<String>> removedNodeMap = collectRemovedNodes(responses);

            Map<String, String[]> removedNodes = new HashMap<>();
            for (Map.Entry<String, List<String>> entry : removedNodeMap.entrySet()) {
                String[] nodeIds = entry.getValue().toArray(new String[0]);
                removedNodes.put(entry.getKey(), nodeIds);
                log.debug("removed node for model: {}, {}", entry.getKey(), Arrays.toString(nodeIds));
            }
            MLSyncUpInput syncUpInput = MLSyncUpInput.builder().removedWorkerNodes(removedNodes).build();
            MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(this.nodeFilter.getAllNodes(), syncUpInput);

            // Issue the follow-up requests from a stashed (fresh) thread context; presumably so the
            // system-index write does not run with the caller's headers/permissions — confirm.
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                if (removedNodeMap.isEmpty()) {
                    this.syncUpUnloadedModels(syncUpRequest);
                } else {
                    this.markModelsUnloaded(removedNodeMap.keySet().toArray(new String[0]), syncUpRequest);
                }
            }
        }
        return new UnloadModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /**
     * Groups, per model id, the ids of nodes that reported the model as unloaded or not found.
     *
     * <p>Mutates each response: entries whose status is {@code "not_found"} are removed from its
     * status map so they are not reported back to the caller.
     *
     * @param responses non-null per-node responses
     * @return model id to the list of node ids that no longer host it
     */
    private static Map<String, List<String>> collectRemovedNodes(List<UnloadModelNodeResponse> responses) {
        Map<String, List<String>> removedNodeMap = new HashMap<>();
        for (UnloadModelNodeResponse response : responses) {
            Map<String, String> modelUnloadStatus = response.getModelUnloadStatus();
            List<String> notFoundModels = new ArrayList<>();
            for (Map.Entry<String, String> entry : modelUnloadStatus.entrySet()) {
                String status = entry.getValue();
                if (UNLOADED_STATUS.equals(status) || NOT_FOUND_STATUS.equals(status)) {
                    removedNodeMap.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(response.getNode().getId());
                }
                if (NOT_FOUND_STATUS.equals(status)) {
                    notFoundModels.add(entry.getKey());
                }
            }
            // Removed after the loop: removing inside the entrySet iteration would throw
            // ConcurrentModificationException.
            notFoundModels.forEach(modelUnloadStatus::remove);
        }
        return removedNodeMap;
    }

    /**
     * Writes {@code model_state = UNLOADED} for each model in a single bulk request, then runs the
     * sync-up whether the bulk request succeeded or failed (failures are only logged).
     *
     * @param modelIds      ids of the models to mark as unloaded; non-empty
     * @param syncUpRequest sync-up request to send once the bulk request completes
     */
    private void markModelsUnloaded(String[] modelIds, MLSyncUpNodesRequest syncUpRequest) {
        BulkRequest bulkRequest = new BulkRequest();
        for (String modelId : modelIds) {
            UpdateRequest updateRequest = new UpdateRequest();
            updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.<String, Object>of(MODEL_STATE_FIELD, MLModelState.UNLOADED));
            bulkRequest.add(updateRequest);
        }
        this.client.bulk(bulkRequest, ActionListener.runAfter(
            ActionListener.wrap(
                r -> log.debug("updated model state as unloaded for : {}", Arrays.toString(modelIds)),
                e -> log.error("Failed to update model state as unloaded", e)),
            () -> this.syncUpUnloadedModels(syncUpRequest)));
    }

    /**
     * Sends the sync-up request; the outcome is only logged.
     *
     * @param syncUpRequest request carrying the removed worker nodes per model
     */
    private void syncUpUnloadedModels(MLSyncUpNodesRequest syncUpRequest) {
        this.client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(
            r -> log.debug("sync up removed nodes successfully"),
            e -> log.error("failed to sync up removed node", e)));
    }

    @Override
    protected UnloadModelNodeRequest newNodeRequest(UnloadModelNodesRequest request) {
        return new UnloadModelNodeRequest(request);
    }

    @Override
    protected UnloadModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new UnloadModelNodeResponse(in);
    }

    @Override
    protected UnloadModelNodeResponse nodeOperation(UnloadModelNodeRequest request) {
        return this.createUnloadModelNodeResponse(request.getUnloadModelNodesRequest());
    }

    /**
     * Unloads the requested models on the local node and records node-level stats.
     *
     * @param unloadModelNodesRequest request carrying the model ids to unload
     * @return the local node together with the per-model unload status
     */
    private UnloadModelNodeResponse createUnloadModelNodeResponse(UnloadModelNodesRequest unloadModelNodesRequest) {
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
        Map<String, String> modelUnloadStatus;
        try {
            String[] modelIds = unloadModelNodesRequest.getModelIds();
            modelUnloadStatus = this.mlModelManager.unloadModel(modelIds);
        } finally {
            // Always release the in-flight task counter, even when unloading throws; otherwise
            // the executing-task gauge would leak permanently.
            this.mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
        }
        return new UnloadModelNodeResponse(this.clusterService.localNode(), modelUnloadStatus);
    }
}

