/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.physical;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperatorActions;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor;
import org.opensearch.transport.client.node.NodeClient;

public class MLOperator
extends MLCommonsOperatorActions {
    private final PhysicalPlan input;
    private final Map<String, Literal> arguments;
    private final NodeClient nodeClient;
    private Iterator<ExprValue> iterator;

    public void open() {
        super.open();
        Map<String, Object> args = this.processArgs(this.arguments);
        String categoryField = this.arguments.containsKey("category_field") ? (String)this.arguments.get("category_field").getValue() : null;
        final boolean isPrediction = !((String)args.get("action")).equals("train");
        final Iterator<String> trainIter = Collections.singletonList("train").iterator();
        List<Pair<DataFrame, DataFrame>> inputDataFrames = this.generateCategorizedInputDataset(this.input, categoryField);
        final List<MLOutput> mlOutputs = inputDataFrames.stream().map(pair -> this.getMLOutput((DataFrame)pair.getRight(), args, this.nodeClient)).toList();
        final Iterator<Pair<DataFrame, DataFrame>> inputDataFramesIter = inputDataFrames.iterator();
        final Iterator<MLOutput> mlOutputIter = mlOutputs.iterator();
        this.iterator = new Iterator<ExprValue>(){
            private DataFrame inputDataFrame = null;
            private Iterator<Row> inputRowIter = null;
            private MLOutput mlOutput = null;
            private Iterator<Row> resultRowIter = null;

            @Override
            public boolean hasNext() {
                if (isPrediction) {
                    return this.inputRowIter != null && this.inputRowIter.hasNext() || inputDataFramesIter.hasNext();
                }
                boolean res = trainIter.hasNext();
                if (res) {
                    trainIter.next();
                }
                return res;
            }

            @Override
            public ExprValue next() {
                if (isPrediction) {
                    if (this.inputRowIter == null || !this.inputRowIter.hasNext()) {
                        Pair pair = (Pair)inputDataFramesIter.next();
                        this.inputDataFrame = (DataFrame)pair.getLeft();
                        this.inputRowIter = this.inputDataFrame.iterator();
                        this.mlOutput = (MLOutput)mlOutputIter.next();
                        this.resultRowIter = ((MLPredictionOutput)this.mlOutput).getPredictionResult().iterator();
                    }
                    return MLOperator.this.buildPPLResult(true, this.inputRowIter, this.inputDataFrame, this.mlOutput, this.resultRowIter);
                }
                return MLOperator.this.buildPPLResult(false, null, null, (MLOutput)mlOutputs.getFirst(), null);
            }
        };
    }

    public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor, C context) {
        return (R)visitor.visitML((PhysicalPlan)this, context);
    }

    public boolean hasNext() {
        return this.iterator.hasNext();
    }

    public ExprValue next() {
        return this.iterator.next();
    }

    public List<PhysicalPlan> getChild() {
        return Collections.singletonList(this.input);
    }

    protected Map<String, Object> processArgs(Map<String, Literal> arguments) {
        HashMap<String, Object> res = new HashMap<String, Object>();
        arguments.forEach((k, v) -> res.put((String)k, v.getValue()));
        return res;
    }

    @Generated
    public MLOperator(PhysicalPlan input, Map<String, Literal> arguments, NodeClient nodeClient) {
        this.input = input;
        this.arguments = arguments;
        this.nodeClient = nodeClient;
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MLOperator)) {
            return false;
        }
        MLOperator other = (MLOperator)((Object)o);
        if (!other.canEqual((Object)this)) {
            return false;
        }
        PhysicalPlan this$input = this.getInput();
        PhysicalPlan other$input = other.getInput();
        if (this$input == null ? other$input != null : !this$input.equals(other$input)) {
            return false;
        }
        Map<String, Literal> this$arguments = this.getArguments();
        Map<String, Literal> other$arguments = other.getArguments();
        if (this$arguments == null ? other$arguments != null : !((Object)this$arguments).equals(other$arguments)) {
            return false;
        }
        NodeClient this$nodeClient = this.getNodeClient();
        NodeClient other$nodeClient = other.getNodeClient();
        return !(this$nodeClient == null ? other$nodeClient != null : !this$nodeClient.equals(other$nodeClient));
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLOperator;
    }

    @Generated
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        PhysicalPlan $input = this.getInput();
        result = result * 59 + ($input == null ? 43 : $input.hashCode());
        Map<String, Literal> $arguments = this.getArguments();
        result = result * 59 + ($arguments == null ? 43 : ((Object)$arguments).hashCode());
        NodeClient $nodeClient = this.getNodeClient();
        result = result * 59 + ($nodeClient == null ? 43 : $nodeClient.hashCode());
        return result;
    }

    @Generated
    public PhysicalPlan getInput() {
        return this.input;
    }

    @Generated
    public Map<String, Literal> getArguments() {
        return this.arguments;
    }

    @Generated
    public NodeClient getNodeClient() {
        return this.nodeClient;
    }
}

