/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.math.expr;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.TokenSource;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.ParseTreeListener;
import org.antlr.v4.runtime.tree.ParseTreeWalker;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ApplyFunction;
import org.apache.druid.math.expr.ApplyFunctionExpr;
import org.apache.druid.math.expr.BigIntegerExpr;
import org.apache.druid.math.expr.BinaryOpExprBase;
import org.apache.druid.math.expr.ConstantExpr;
import org.apache.druid.math.expr.Evals;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprListenerImpl;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Function;
import org.apache.druid.math.expr.FunctionExpr;
import org.apache.druid.math.expr.IdentifierExpr;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.LambdaExpr;
import org.apache.druid.math.expr.StringExpr;
import org.apache.druid.math.expr.UnaryExpr;
import org.apache.druid.math.expr.antlr.ExprLexer;
import org.apache.druid.math.expr.antlr.ExprParser;

/**
 * Parses expression strings into {@link Expr} trees and provides the tree rewrites built on top of them: constant
 * folding ({@link #flatten}) and rewriting of "unapplied" bindings into {@code map}/{@code fold} style apply
 * functions ({@link #applyUnappliedBindings}, {@link #foldUnappliedBindings}).
 *
 * <p>The built-in {@link Function} and {@link ApplyFunction} registries are populated reflectively when this class is
 * loaded, from the concrete public member classes of those types, keyed by lower-cased {@code name()}.
 */
public class Parser {
    private static final Logger log = new Logger(Parser.class);
    private static final Map<String, Function> FUNCTIONS;
    private static final Map<String, ApplyFunction> APPLY_FUNCTIONS;

    static {
        FUNCTIONS = instantiateNestedImplementations(Function.class, Function::name);
        APPLY_FUNCTIONS = instantiateNestedImplementations(ApplyFunction.class, ApplyFunction::name);
    }

    /**
     * Instantiates every non-abstract public member class of {@code baseType} that is a subtype of it (via its no-arg
     * constructor) and indexes the instances by lower-cased name. A class that fails to instantiate is logged and
     * skipped so that one broken entry does not prevent the rest of the registry from loading.
     *
     * @param baseType type whose member classes are scanned
     * @param nameOf   extracts the registry name from an instance
     * @return an immutable name-to-instance map
     */
    private static <T> Map<String, T> instantiateNestedImplementations(
            Class<T> baseType,
            java.util.function.Function<? super T, String> nameOf
    ) {
        Map<String, T> byName = new HashMap<>();
        for (Class<?> clazz : baseType.getClasses()) {
            if (Modifier.isAbstract(clazz.getModifiers()) || !baseType.isAssignableFrom(clazz)) {
                continue;
            }
            try {
                // Class.newInstance() is deprecated; the constructor route also wraps constructor failures properly
                T instance = baseType.cast(clazz.getDeclaredConstructor().newInstance());
                byName.put(StringUtils.toLowerCase(nameOf.apply(instance)), instance);
            }
            catch (Exception e) {
                log.error(e, "failed to instantiate %s.. ignoring", clazz.getName());
            }
        }
        return ImmutableMap.copyOf(byName);
    }

    /**
     * Looks up a built-in function by name, case-insensitively.
     *
     * @return the registered function, or {@code null} if no function has that name
     */
    public static Function getFunction(String name) {
        return FUNCTIONS.get(StringUtils.toLowerCase(name));
    }

    /**
     * Looks up a built-in apply function by name, case-insensitively.
     *
     * @return the registered apply function, or {@code null} if no apply function has that name
     */
    public static ApplyFunction getApplyFunction(String name) {
        return APPLY_FUNCTIONS.get(StringUtils.toLowerCase(name));
    }

    /**
     * Returns a memoizing supplier that parses (and flattens) {@code in} on first use. The supplier yields
     * {@code null} when {@code in} is {@code null}.
     */
    public static Supplier<Expr> lazyParse(@Nullable String in, ExprMacroTable macroTable) {
        return Suppliers.memoize(() -> in == null ? null : parse(in, macroTable));
    }

    /**
     * Parses {@code in} into an expression tree and constant-folds it with {@link #flatten}.
     *
     * @throws RE if the parse produces no expression
     */
    public static Expr parse(String in, ExprMacroTable macroTable) {
        return parse(in, macroTable, true);
    }

    /**
     * Parses {@code in} into an expression tree.
     *
     * @param in          expression text
     * @param macroTable  macros available to the expression
     * @param withFlatten whether to constant-fold the resulting tree with {@link #flatten}
     * @throws RE if the parse listener produces no expression
     */
    @VisibleForTesting
    public static Expr parse(String in, ExprMacroTable macroTable, boolean withFlatten) {
        ExprLexer lexer = new ExprLexer(new ANTLRInputStream(in));
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        ExprParser parser = new ExprParser(tokens);
        parser.setBuildParseTree(true);
        ParseTree parseTree = parser.expr();
        ExprListenerImpl listener = new ExprListenerImpl(parseTree, macroTable);
        new ParseTreeWalker().walk(listener, parseTree);
        Expr parsed = listener.getAST();
        if (parsed == null) {
            throw new RE("Failed to parse expression: %s", in);
        }
        return withFlatten ? flatten(parsed) : parsed;
    }

    /** Creates an identifier expression for {@code identifier}. */
    public static Expr identifier(String identifier) {
        return new IdentifierExpr(identifier);
    }

    /** Creates a string constant expression for {@code constant}. */
    public static Expr constant(String constant) {
        return new StringExpr(constant);
    }

    /**
     * Constant-folds {@code expr}: every operator, function, apply function or macro call whose inputs are all
     * constants is evaluated with empty bindings and replaced by the resulting constant. A second pass then evaluates
     * any remaining {@link BigIntegerExpr} nodes.
     */
    public static Expr flatten(Expr expr) {
        return expr.visit(childExpr -> {
            if (isConstantFoldable(childExpr)) {
                return childExpr.eval(InputBindings.nilBindings()).toExpr();
            }
            return childExpr;
        }).visit(childExpr -> {
            if (childExpr instanceof BigIntegerExpr) {
                return childExpr.eval(InputBindings.nilBindings()).toExpr();
            }
            return childExpr;
        });
    }

    /** Returns true when {@code expr} is a node kind that {@link #flatten} can evaluate because its inputs are constant. */
    private static boolean isConstantFoldable(Expr expr) {
        if (expr instanceof BinaryOpExprBase) {
            BinaryOpExprBase binary = (BinaryOpExprBase) expr;
            return Evals.isAllConstants(binary.left, binary.right);
        }
        if (expr instanceof UnaryExpr) {
            return ((UnaryExpr) expr).expr instanceof ConstantExpr;
        }
        if (expr instanceof FunctionExpr) {
            return Evals.isAllConstants(((FunctionExpr) expr).args);
        }
        if (expr instanceof ApplyFunctionExpr) {
            ApplyFunctionExpr applyFunctionExpr = (ApplyFunctionExpr) expr;
            // constant array arguments are not enough: the lambda body must not reference any outside variable either
            return Evals.isAllConstants(applyFunctionExpr.argsExpr)
                   && applyFunctionExpr.analyzeInputs().getFreeVariables().isEmpty();
        }
        if (expr instanceof ExprMacroTable.ExprMacroFunctionExpr) {
            return Evals.isAllConstants(((ExprMacroTable.ExprMacroFunctionExpr) expr).getArgs());
        }
        return false;
    }

    /**
     * Rewrites {@code expr} so that those of {@code bindingsToApply} that it requires are applied element-wise. Apply
     * functions and array-function arguments that reference them are rewritten first. Any bindings that are still not
     * consumed as arrays afterwards are applied by wrapping the whole expression in {@code map} (one binding) or
     * {@code cartesian_map} (several bindings).
     *
     * @param expr            expression to rewrite
     * @param bindingAnalysis analysis of {@code expr}; its required bindings select which of {@code bindingsToApply} matter
     * @param bindingsToApply names of the bindings to apply
     * @return the rewritten expression; {@code expr} itself when {@code bindingsToApply} is empty
     */
    public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply) {
        if (bindingsToApply.isEmpty()) {
            return expr;
        }
        List<String> unappliedBindingsInExpression = requiredSubset(bindingsToApply, bindingAnalysis);
        Expr newExpr = rewriteUnappliedSubExpressions(
                expr,
                unappliedBindingsInExpression,
                arg -> applyUnappliedBindings(arg, bindingAnalysis, bindingsToApply)
        );
        List<String> remainingUnappliedBindings = notUsedAsArrays(newExpr, unappliedBindingsInExpression);
        if (remainingUnappliedBindings.isEmpty()) {
            return newExpr;
        }
        return applyUnapplied(newExpr, remainingUnappliedBindings);
    }

    /**
     * Like {@link #applyUnappliedBindings}, but wraps any remaining unapplied bindings in {@code fold} (one binding)
     * or {@code cartesian_fold} (several bindings), using {@code accumulatorId} as the accumulator identifier.
     *
     * @param expr            expression to rewrite
     * @param bindingAnalysis analysis of {@code expr}; its required bindings select which of {@code bindingsToApply} matter
     * @param bindingsToApply names of the bindings to apply
     * @param accumulatorId   identifier of the fold accumulator
     * @return the rewritten expression; {@code expr} itself when {@code bindingsToApply} is empty
     */
    public static Expr foldUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply, String accumulatorId) {
        if (bindingsToApply.isEmpty()) {
            return expr;
        }
        List<String> unappliedBindingsInExpression = requiredSubset(bindingsToApply, bindingAnalysis);
        Expr newExpr = rewriteUnappliedSubExpressions(
                expr,
                unappliedBindingsInExpression,
                arg -> foldUnappliedBindings(arg, bindingAnalysis, bindingsToApply, accumulatorId)
        );
        List<String> remainingUnappliedBindings = notUsedAsArrays(newExpr, unappliedBindingsInExpression);
        if (remainingUnappliedBindings.isEmpty()) {
            return newExpr;
        }
        return foldUnapplied(newExpr, remainingUnappliedBindings, accumulatorId);
    }

    /** Returns the members of {@code bindings} (in list order) that {@code bindingAnalysis} reports as required. */
    private static List<String> requiredSubset(List<String> bindings, Expr.BindingAnalysis bindingAnalysis) {
        return bindings.stream()
                .filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
                .collect(Collectors.toList());
    }

    /** Returns the members of {@code bindings} (in list order) that {@code expr} does not consume as array variables. */
    private static List<String> notUsedAsArrays(Expr expr, List<String> bindings) {
        Set<String> expectedArrays = expr.analyzeInputs().getArrayVariables();
        return bindings.stream()
                .filter(x -> !expectedArrays.contains(x))
                .collect(Collectors.toList());
    }

    /**
     * Rewrites the sub-expressions of {@code expr} that can absorb unapplied bindings themselves. Apply functions get
     * their lambdas lifted via {@link #liftApplyLambda}. Non-identifier arguments that a function reports as array
     * inputs are passed through {@code applyUnappliedFn}.
     */
    private static Expr rewriteUnappliedSubExpressions(Expr expr, List<String> unappliedBindingsInExpression, UnaryOperator<Expr> applyUnappliedFn) {
        return expr.visit(childExpr -> {
            if (childExpr instanceof ApplyFunctionExpr) {
                return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression);
            }
            if (childExpr instanceof FunctionExpr) {
                FunctionExpr fnExpr = (FunctionExpr) childExpr;
                Set<Expr> arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args);
                List<Expr> newArgs = new ArrayList<>(fnExpr.args.size());
                for (Expr arg : fnExpr.args) {
                    boolean rewriteArg = arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg);
                    newArgs.add(rewriteArg ? applyUnappliedFn.apply(arg) : arg);
                }
                return new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs);
            }
            return childExpr;
        });
    }

    /**
     * Wraps {@code expr} in {@code map}/{@code cartesian_map} over its free variables whose binding is in
     * {@code unappliedBindings}. Returns {@code expr} unchanged when there are no such variables.
     */
    private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings) {
        Map<String, IdentifierExpr> toReplace = new HashMap<>();
        List<IdentifierExpr> lambdaArgs = collectLambdaIdentifiers(expr, unappliedBindings, toReplace);
        if (lambdaArgs.isEmpty()) {
            return expr;
        }
        LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, rebindIdentifiers(expr, toReplace));
        ApplyFunction.BaseMapFunction fn = lambdaArgs.size() == 1
                ? new ApplyFunction.MapFunction()
                : new ApplyFunction.CartesianMapFunction();
        return new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.<Expr>copyOf(lambdaArgs));
    }

    /**
     * Wraps {@code expr} in {@code fold}/{@code cartesian_fold} over its free variables whose binding is in
     * {@code unappliedBindings}. An identifier for {@code accumulatorId} is always appended as the last lambda
     * parameter and the last argument.
     */
    private static Expr foldUnapplied(Expr expr, List<String> unappliedBindings, String accumulatorId) {
        Map<String, IdentifierExpr> toReplace = new HashMap<>();
        List<IdentifierExpr> lambdaArgs = collectLambdaIdentifiers(expr, unappliedBindings, toReplace);
        // NOTE(review): unlike applyUnapplied there is no early return when nothing matched; the accumulator is
        // always appended, which would yield a single-identifier cartesian_fold — confirm callers never hit that case
        lambdaArgs.add(new IdentifierExpr(accumulatorId));
        LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, rebindIdentifiers(expr, toReplace));
        ApplyFunction.BaseFoldFunction fn = lambdaArgs.size() == 2
                ? new ApplyFunction.FoldFunction()
                : new ApplyFunction.CartesianFoldFunction();
        return new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.<Expr>copyOf(lambdaArgs));
    }

    /**
     * Creates one lambda identifier per distinct binding among the free variables of {@code expr} whose binding is in
     * {@code unappliedBindings}, in encounter order, and records each in {@code toReplace} keyed by binding name.
     * The new identifier uses the binding name as its name. Deduplicating by binding means repeated uses of the same
     * input do not produce an extra lambda parameter.
     */
    private static List<IdentifierExpr> collectLambdaIdentifiers(
            Expr expr,
            List<String> unappliedBindings,
            Map<String, IdentifierExpr> toReplace
    ) {
        List<IdentifierExpr> lambdaArgs = new ArrayList<>();
        for (IdentifierExpr freeVariable : expr.analyzeInputs().getFreeVariables()) {
            String binding = freeVariable.getBinding();
            if (unappliedBindings.contains(binding) && !toReplace.containsKey(binding)) {
                IdentifierExpr lambdaRewrite = new IdentifierExpr(binding);
                lambdaArgs.add(lambdaRewrite);
                toReplace.put(binding, lambdaRewrite);
            }
        }
        return lambdaArgs;
    }

    /** Replaces every identifier in {@code expr} whose binding is a key of {@code toReplace} with the mapped identifier. */
    private static Expr rebindIdentifiers(Expr expr, Map<String, IdentifierExpr> toReplace) {
        return expr.visit(childExpr -> {
            if (childExpr instanceof IdentifierExpr) {
                IdentifierExpr replacement = toReplace.get(((IdentifierExpr) childExpr).getBinding());
                if (replacement != null) {
                    return replacement;
                }
            }
            return childExpr;
        });
    }

    /**
     * Rewrites an apply function so that unapplied bindings are handled inside it.
     *
     * <p>Its arguments are recursively rewritten first. If its lambda body also references unapplied bindings, those
     * identifiers are added as extra lambda parameters and arguments, and the call is converted to a cartesian form:
     * <ul>
     *   <li>{@code map(x -> x + y, x) => cartesian_map((x, y) -> x + y, x, y)}</li>
     *   <li>{@code any(x -> x > y, x) => any(_ -> _, cartesian_map((x, y) -> x > y, x, y))} (likewise {@code all},
     *       {@code filter})</li>
     *   <li>{@code fold((x, acc) -> acc + x + y, x, acc) => cartesian_fold((x, y, acc) -> acc + x + y, x, y, acc)}</li>
     * </ul>
     *
     * @throws RE if the lambda must be lifted but the function is not one of the supported apply functions
     */
    private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List<String> unappliedArgs) {
        // bindings this apply does not already consume as arrays still need applying inside its arguments
        Set<String> unappliedInThisApply = unappliedArgs.stream()
                .filter(u -> !expr.bindingAnalysis.getArrayBindings().contains(u))
                .collect(Collectors.toSet());
        List<String> unappliedIdentifiers = expr.bindingAnalysis.getFreeVariables().stream()
                .filter(x -> unappliedInThisApply.contains(x.getBindingIfIdentifier()))
                .map(IdentifierExpr::getIdentifierIfIdentifier)
                .collect(Collectors.toList());
        List<Expr> liftedArgs = new ArrayList<>(expr.argsExpr.size());
        for (int i = 0; i < expr.argsExpr.size(); i++) {
            liftedArgs.add(applyUnappliedBindings(expr.argsExpr.get(i), expr.argsBindingAnalyses.get(i), unappliedIdentifiers));
        }

        // free variables of the lambda that are backed by an unapplied binding must become extra lambda parameters
        List<IdentifierExpr> unappliedLambdaBindings = expr.lambdaBindingAnalysis.getFreeVariables().stream()
                .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier()))
                .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding()))
                .collect(Collectors.toList());
        if (unappliedLambdaBindings.isEmpty()) {
            return new ApplyFunctionExpr(expr.function, expr.name, expr.lambdaExpr, liftedArgs);
        }

        List<Expr> cartesianArgs = new ArrayList<>(liftedArgs.size() + unappliedLambdaBindings.size());
        cartesianArgs.addAll(liftedArgs);
        cartesianArgs.addAll(unappliedLambdaBindings);

        switch (expr.function.name()) {
            case "map":
            case "cartesian_map": {
                ApplyFunction.CartesianMapFunction newFn = new ApplyFunction.CartesianMapFunction();
                return new ApplyFunctionExpr(newFn, newFn.name(), widenLambda(expr.lambdaExpr, unappliedLambdaBindings), cartesianArgs);
            }
            case "all":
            case "any":
            case "filter": {
                // these have no cartesian variants: evaluate the widened original lambda with cartesian_map, then
                // apply the original function to the resulting array through an identity lambda
                ApplyFunction.CartesianMapFunction arrayFn = new ApplyFunction.CartesianMapFunction();
                ApplyFunctionExpr arrayExpr = new ApplyFunctionExpr(
                        arrayFn,
                        arrayFn.name(),
                        widenLambda(expr.lambdaExpr, unappliedLambdaBindings),
                        cartesianArgs
                );
                IdentifierExpr identityExprIdentifier = new IdentifierExpr("_");
                LambdaExpr identityExpr = new LambdaExpr(ImmutableList.of(identityExprIdentifier), identityExprIdentifier);
                // NOTE(review): for "filter" this yields the predicate results rather than the matching input
                // elements — confirm that is the intended shape
                return new ApplyFunctionExpr(expr.function, expr.function.name(), identityExpr, ImmutableList.<Expr>of(arrayExpr));
            }
            case "fold":
            case "cartesian_fold":
                return liftFoldLambda(expr, liftedArgs, unappliedLambdaBindings);
            default:
                throw new RE("Unable to transform apply function:[%s]", expr.function.name());
        }
    }

    /** Returns a lambda with the same body as {@code lambda} whose parameters are followed by {@code extraIdentifiers}. */
    private static LambdaExpr widenLambda(LambdaExpr lambda, List<IdentifierExpr> extraIdentifiers) {
        List<IdentifierExpr> existing = lambda.getIdentifierExprs();
        List<IdentifierExpr> identifiers = new ArrayList<>(existing.size() + extraIdentifiers.size());
        identifiers.addAll(existing);
        identifiers.addAll(extraIdentifiers);
        return new LambdaExpr(identifiers, lambda.getExpr());
    }

    /**
     * Builds a {@code cartesian_fold} from a fold-style apply. The unapplied identifiers are inserted just before the
     * trailing accumulator, both in the lambda parameters and in the (already recursively rewritten) arguments.
     */
    private static ApplyFunctionExpr liftFoldLambda(
            ApplyFunctionExpr expr,
            List<Expr> liftedArgs,
            List<IdentifierExpr> unappliedLambdaBindings
    ) {
        // the accumulator is the last argument and the last lambda identifier
        int accumulatorIndex = expr.argsExpr.size() - 1;
        List<IdentifierExpr> existingIdentifiers = expr.lambdaExpr.getIdentifierExprs();

        List<Expr> newFoldArgs = new ArrayList<>(liftedArgs.subList(0, accumulatorIndex));
        newFoldArgs.addAll(unappliedLambdaBindings);
        newFoldArgs.add(liftedArgs.get(accumulatorIndex));

        List<IdentifierExpr> newFoldIdentifiers = new ArrayList<>(existingIdentifiers.subList(0, accumulatorIndex));
        newFoldIdentifiers.addAll(unappliedLambdaBindings);
        newFoldIdentifiers.add(existingIdentifiers.get(existingIdentifiers.size() - 1));

        LambdaExpr newFoldLambda = new LambdaExpr(newFoldIdentifiers, expr.lambdaExpr.getExpr());
        ApplyFunction.CartesianFoldFunction newFn = new ApplyFunction.CartesianFoldFunction();
        return new ApplyFunctionExpr(newFn, newFn.name(), newFoldLambda, newFoldArgs);
    }

    /**
     * Rejects an expression in which some binding is used both as a scalar and as an array.
     *
     * @throws RE listing the conflicting bindings
     */
    public static void validateExpr(Expr expression, Expr.BindingAnalysis bindingAnalysis) {
        Set<String> conflicted = Sets.intersection(bindingAnalysis.getScalarBindings(), bindingAnalysis.getArrayBindings());
        if (!conflicted.isEmpty()) {
            throw new RE("Invalid expression: %s; %s used as both scalar and array variables", expression, conflicted);
        }
    }
}

