/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.function.CollectionUDF;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

/**
 * Implementation of the {@code mapRemove(map, keys)} collection function: returns a copy of the
 * given map with every key listed in {@code keys} removed. The input map is never mutated.
 *
 * <p>The function is registered with {@link NullPolicy#ARG0}; per Calcite's contract for that
 * policy, generated code is expected to produce {@code null} when the map operand is {@code null}.
 * The runtime method additionally guards against {@code null} arguments itself.
 */
public class MapRemoveFunctionImpl extends ImplementorUDF {

    public MapRemoveFunctionImpl() {
        super(new MapRemoveImplementor(), NullPolicy.ARG0);
    }

    /**
     * Infers the result type as the type of the first (map) operand, forced to be nullable.
     */
    @Override
    public SqlReturnTypeInference getReturnTypeInference() {
        return binding -> {
            RelDataType mapType = binding.getOperandType(0);
            return binding.getTypeFactory().createTypeWithNullability(mapType, true);
        };
    }

    /**
     * NOTE(review): returns {@code null}; presumably this means no operand-type metadata/checking
     * is supplied for this function. Confirm against {@link ImplementorUDF}.
     */
    @Override
    public UDFOperandMetadata getOperandMetadata() {
        return null;
    }

    /**
     * Runtime entry point invoked by the generated code (looked up reflectively by name and
     * {@code (Object, Object)} signature in {@link MapRemoveImplementor}).
     *
     * @param mapArg the source map; expected to be a {@link Map}
     * @param keysArg the keys to remove; expected to be a {@link List}
     * @return {@code mapArg} itself when either argument is {@code null}; otherwise a new map
     *     containing the entries of {@code mapArg} whose keys are not listed in {@code keysArg}
     * @throws IllegalArgumentException if a non-null argument has the wrong type
     */
    public static Object mapRemove(Object mapArg, Object keysArg) {
        if (mapArg == null || keysArg == null) {
            // Nothing to remove, or nothing to remove from: hand the input back unchanged.
            return mapArg;
        }
        verifyArgTypes(mapArg, keysArg);
        // Map values produced by the engine are keyed by field name; verified to be a Map above.
        @SuppressWarnings("unchecked")
        Map<String, Object> map = (Map<String, Object>) mapArg;
        return removeKeys(map, (List<?>) keysArg);
    }

    /**
     * Ensures the first argument is a {@link Map} and the second a {@link List}.
     *
     * @throws IllegalArgumentException naming the offending runtime class otherwise
     */
    private static void verifyArgTypes(Object mapArg, Object keysArg) {
        if (!(mapArg instanceof Map)) {
            throw new IllegalArgumentException("First argument must be a map, got: " + String.valueOf(mapArg.getClass()));
        }
        if (!(keysArg instanceof List)) {
            throw new IllegalArgumentException("Second argument must be an array/list, got: " + String.valueOf(keysArg.getClass()));
        }
    }

    /**
     * Copies {@code originalMap} and removes each non-null element of {@code keysToRemove},
     * matching it against map keys by its {@code toString()} form.
     *
     * <p>The copy is a {@link LinkedHashMap} so surviving entries keep the iteration order of the
     * input map; a plain {@code HashMap} copy would reorder them arbitrarily.
     */
    private static Map<String, Object> removeKeys(
            Map<String, Object> originalMap, List<?> keysToRemove) {
        Map<String, Object> resultMap = new LinkedHashMap<>(originalMap);
        for (Object keyObj : keysToRemove) {
            // Null entries in the key list are ignored rather than treated as the key "null".
            if (keyObj != null) {
                resultMap.remove(keyObj.toString());
            }
        }
        return resultMap;
    }

    /** Generates a call to {@link MapRemoveFunctionImpl#mapRemove(Object, Object)}. */
    public static class MapRemoveImplementor implements NotNullImplementor {
        @Override
        public Expression implement(
                RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
            Method target = Types.lookupMethod(
                    MapRemoveFunctionImpl.class, "mapRemove", Object.class, Object.class);
            return Expressions.call(target, translatedOperands.get(0), translatedOperands.get(1));
        }
    }
}

