/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolsExtractor;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.GroupReference;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Lookup;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.LimitNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.SymbolMapper;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeCoercion;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeManager;

/**
 * Pulls correlated predicates out of a subquery plan so that the remaining plan no longer
 * references the outer (correlation) symbols.
 *
 * <p>The rewrite walks down through {@link FilterNode}, {@link ProjectNode}, {@link AggregationNode}
 * and {@link LimitNode}. Nodes dispatched to {@code visitPlan} are accepted only when they do not
 * reference any correlation symbol, and {@link TopKNode} is rejected with a {@link
 * SemanticException}. Conjuncts of filter predicates that reference correlation symbols are removed
 * from the plan and returned to the caller in a {@link DecorrelatedNode}. The local symbols those
 * conjuncts use are kept in the outputs of the projections and aggregations above them.
 */
public class PlanNodeDecorrelator {
    private final PlannerContext plannerContext;
    // NOTE(review): not read anywhere in this class; kept for constructor compatibility.
    private final SymbolAllocator symbolAllocator;
    private final Lookup lookup;
    // NOTE(review): not read anywhere in this class. It was presumably meant for an injectivity
    // check in DecorrelatingVisitor#isSimpleInjectiveCast — confirm.
    private final TypeCoercion typeCoercion;

    public PlanNodeDecorrelator(PlannerContext plannerContext, SymbolAllocator symbolAllocator, Lookup lookup) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType);
    }

    /**
     * Decorrelates {@code node} with respect to the given outer-query symbols.
     *
     * @param node root of the possibly correlated subquery plan
     * @param correlation outer-query symbols that {@code node} may reference
     * @return the rewritten plan together with the extracted correlated predicates. Returns empty
     *     if the plan could not be rewritten so that it is free of correlation symbols.
     * @throws SemanticException if the traversal reaches a TopK node, or a LIMIT with row count
     *     greater than 1 whose input is not known to yield at most one row
     */
    public Optional<DecorrelatedNode> decorrelateFilters(PlanNode node, List<Symbol> correlation) {
        if (correlation.isEmpty()) {
            // Nothing to pull out: the plan is returned unchanged with no correlated predicates.
            return Optional.of(new DecorrelatedNode(ImmutableList.of(), node));
        }
        Optional<DecorrelationResult> decorrelationResult =
                node.accept(new DecorrelatingVisitor(this.plannerContext.getTypeManager(), correlation), null);
        return decorrelationResult.flatMap(
                result -> this.decorrelatedNode(result.correlatedPredicates, result.node, correlation));
    }

    /** Wraps the rewrite result, or returns empty if {@code node} still references correlation symbols. */
    private Optional<DecorrelatedNode> decorrelatedNode(List<Expression> correlatedPredicates, PlanNode node, List<Symbol> correlation) {
        if (this.containsCorrelation(node, correlation)) {
            return Optional.empty();
        }
        return Optional.of(new DecorrelatedNode(correlatedPredicates, node));
    }

    /**
     * Returns true if any correlation symbol occurs among the symbols extracted from {@code node}
     * (group references are resolved through {@link #lookup}), or among its output symbols.
     */
    private boolean containsCorrelation(PlanNode node, List<Symbol> correlation) {
        return Sets.union(
                        SymbolsExtractor.extractUnique(node, this.lookup),
                        SymbolsExtractor.extractOutputSymbols(node, this.lookup))
                .stream()
                .anyMatch(correlation::contains);
    }

    /** A correlation-free plan plus the correlated conjuncts that were removed from it. */
    public static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        public DecorrelatedNode(List<Expression> correlatedPredicates, PlanNode node) {
            Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates);
            this.node = Objects.requireNonNull(node, "node is null");
        }

        /** Returns the AND of all extracted correlated conjuncts, or empty if there were none. */
        public Optional<Expression> getCorrelatedPredicates() {
            if (this.correlatedPredicates.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(IrUtils.and(this.correlatedPredicates));
        }

        /** Returns the rewritten plan, which contains no correlation symbols. */
        public PlanNode getNode() {
            return this.node;
        }
    }

    /** Recursively rewrites a plan subtree, collecting correlated conjuncts as it goes. */
    private class DecorrelatingVisitor extends PlanVisitor<Optional<DecorrelationResult>, Void> {
        // NOTE(review): not read in this visitor; kept for constructor compatibility.
        private final TypeManager typeManager;
        private final List<Symbol> correlation;

        DecorrelatingVisitor(TypeManager typeManager, List<Symbol> correlation) {
            this.typeManager = Objects.requireNonNull(typeManager, "typeManager is null");
            this.correlation = Objects.requireNonNull(correlation, "correlation is null");
        }

        /** Handles any node without a dedicated visitor: it is accepted only if it is correlation-free. */
        @Override
        public Optional<DecorrelationResult> visitPlan(PlanNode node, Void context) {
            if (PlanNodeDecorrelator.this.containsCorrelation(node, this.correlation)) {
                return Optional.empty();
            }
            return Optional.of(DecorrelationResult.unchanged(node));
        }

        /** Resolves a memo group reference and visits the resolved node. */
        @Override
        public Optional<DecorrelationResult> visitGroupReference(GroupReference node, Void context) {
            return PlanNodeDecorrelator.this.lookup.resolve(node).accept(this, null);
        }

        /**
         * Splits the filter predicate into correlated and uncorrelated conjuncts.
         *
         * <p>The uncorrelated conjuncts stay in the rewritten filter. The correlated conjuncts are
         * removed from it and accumulated in the result.
         */
        @Override
        public Optional<DecorrelationResult> visitFilter(FilterNode node, Void context) {
            PlanNode child = node.getChild();
            Optional<DecorrelationResult> childResultOptional;
            if (PlanNodeDecorrelator.this.containsCorrelation(child, this.correlation)) {
                // Correlation also occurs further down, so try to decorrelate the child too.
                childResultOptional = child.accept(this, null);
            } else {
                childResultOptional = Optional.of(DecorrelationResult.unchanged(child));
            }
            if (!childResultOptional.isPresent()) {
                return Optional.empty();
            }
            DecorrelationResult childResult = childResultOptional.get();

            Map<Boolean, List<Expression>> partitioned = IrUtils.extractConjuncts(node.getPredicate()).stream()
                    .collect(Collectors.partitioningBy(this::isCorrelated));
            List<Expression> correlatedPredicates = ImmutableList.copyOf(partitioned.get(true));
            List<Expression> uncorrelatedPredicates = ImmutableList.copyOf(partitioned.get(false));

            FilterNode newFilterNode = new FilterNode(
                    node.getPlanNodeId(), childResult.node, IrUtils.combineConjuncts(uncorrelatedPredicates));

            // The extracted conjuncts may use local symbols. Those must remain in the outputs of the
            // nodes above this one, so the caller can still evaluate the conjuncts.
            Set<Symbol> symbolsToPropagate = Sets.difference(
                    SymbolsExtractor.extractUnique(correlatedPredicates), ImmutableSet.copyOf(this.correlation));

            return Optional.of(new DecorrelationResult(
                    newFilterNode,
                    Sets.union(childResult.symbolsToPropagate, symbolsToPropagate),
                    ImmutableList.<Expression>builder()
                            .addAll(childResult.correlatedPredicates)
                            .addAll(correlatedPredicates)
                            .build(),
                    ImmutableMultimap.<Symbol, Symbol>builder()
                            .putAll(childResult.correlatedSymbolsMapping)
                            .putAll(this.extractCorrelatedSymbolsMapping(correlatedPredicates))
                            .build(),
                    ImmutableSet.<Symbol>builder()
                            .addAll(childResult.constantSymbols)
                            .addAll(this.extractConstantSymbols(correlatedPredicates))
                            .build(),
                    childResult.atMostSingleRow));
        }

        /**
         * Decorrelates a LIMIT.
         *
         * <ul>
         *   <li>LIMIT 0 and WITH TIES are not handled.</li>
         *   <li>The LIMIT is dropped when the child already yields at most one row.</li>
         *   <li>LIMIT 1 is rewritten as described on {@link #rewriteLimitWithRowCountOne}.</li>
         *   <li>Larger limits are rejected.</li>
         * </ul>
         */
        @Override
        public Optional<DecorrelationResult> visitLimit(LimitNode node, Void context) {
            if (node.getCount() == 0L || node.isWithTies()) {
                return Optional.empty();
            }
            Optional<DecorrelationResult> childResultOptional = node.getChild().accept(this, null);
            if (!childResultOptional.isPresent()) {
                return Optional.empty();
            }
            DecorrelationResult childResult = childResultOptional.get();
            if (childResult.atMostSingleRow) {
                // The LIMIT cannot remove anything, so it is omitted from the rewritten plan.
                return childResultOptional;
            }
            if (node.getCount() == 1L) {
                return this.rewriteLimitWithRowCountOne(childResult, node.getPlanNodeId());
            }
            throw new SemanticException("Decorrelation for LIMIT with row count greater than 1 is not supported yet");
        }

        /**
         * Replaces LIMIT 1 with a grouping (an aggregation with no aggregates) on all child outputs.
         *
         * <p>This is only done when every output of the decorrelated child is a constant symbol,
         * i.e. one equated to an expression over correlation symbols only. The child must also have
         * at least one constant symbol. The result is marked as yielding at most one row.
         */
        private Optional<DecorrelationResult> rewriteLimitWithRowCountOne(DecorrelationResult childResult, PlanNodeId nodeId) {
            PlanNode decorrelatedChild = childResult.node;
            Set<Symbol> constantSymbols = childResult.getConstantSymbols();
            if (constantSymbols.isEmpty() || !constantSymbols.containsAll(decorrelatedChild.getOutputSymbols())) {
                return Optional.empty();
            }
            AggregationNode aggregationNode = AggregationNode.singleAggregation(
                    nodeId,
                    decorrelatedChild,
                    ImmutableMap.<Symbol, AggregationNode.Aggregation>of(),
                    AggregationNode.singleGroupingSet(decorrelatedChild.getOutputSymbols()));
            return Optional.of(new DecorrelationResult(
                    aggregationNode,
                    childResult.symbolsToPropagate,
                    childResult.correlatedPredicates,
                    childResult.correlatedSymbolsMapping,
                    childResult.constantSymbols,
                    true));
        }

        /** TopK is not supported inside a correlated subquery. */
        @Override
        public Optional<DecorrelationResult> visitTopK(TopKNode node, Void context) {
            throw new SemanticException("TopK is not supported in correlated subquery for now");
        }

        /**
         * Decorrelates a single-grouping-set aggregation that has non-empty grouping keys.
         *
         * <p>Symbols that must be propagated are appended as extra grouping keys. This is only
         * allowed when every such symbol is constant, because then the original groups are not
         * split.
         */
        @Override
        public Optional<DecorrelationResult> visitAggregation(AggregationNode node, Void context) {
            if (node.hasEmptyGroupingSet()) {
                return Optional.empty();
            }
            if (node.getGroupingSetCount() != 1) {
                return Optional.empty();
            }
            Optional<DecorrelationResult> childResultOptional = node.getChild().accept(this, null);
            if (!childResultOptional.isPresent()) {
                return Optional.empty();
            }
            DecorrelationResult childResult = childResultOptional.get();
            Set<Symbol> constantSymbols = childResult.getConstantSymbols();

            // Rewrite references to correlation symbols into their local equivalents.
            AggregationNode decorrelatedAggregation =
                    childResult.getCorrelatedSymbolMapper().map(node, childResult.node);

            Set<Symbol> groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys());
            Preconditions.checkState(
                    ImmutableSet.copyOf(decorrelatedAggregation.getGroupingKeys()).equals(groupingKeys),
                    "grouping keys were correlated");

            List<Symbol> symbolsToAdd = childResult.symbolsToPropagate.stream()
                    .filter(symbol -> !groupingKeys.contains(symbol))
                    .collect(ImmutableList.toImmutableList());
            if (!constantSymbols.containsAll(symbolsToAdd)) {
                return Optional.empty();
            }

            AggregationNode newAggregation = AggregationNode.builderFrom(decorrelatedAggregation)
                    .setGroupingSets(AggregationNode.singleGroupingSet(ImmutableList.<Symbol>builder()
                            .addAll(node.getGroupingKeys())
                            .addAll(symbolsToAdd)
                            .build()))
                    .setPreGroupedSymbols(ImmutableList.<Symbol>of())
                    .build();

            return Optional.of(new DecorrelationResult(
                    newAggregation,
                    childResult.symbolsToPropagate,
                    childResult.correlatedPredicates,
                    childResult.correlatedSymbolsMapping,
                    childResult.constantSymbols,
                    constantSymbols.containsAll(newAggregation.getGroupingKeys())));
        }

        /** Keeps the projection and adds identity assignments for propagated symbols it would otherwise drop. */
        @Override
        public Optional<DecorrelationResult> visitProject(ProjectNode node, Void context) {
            Optional<DecorrelationResult> childResultOptional = node.getChild().accept(this, null);
            if (!childResultOptional.isPresent()) {
                return Optional.empty();
            }
            DecorrelationResult childResult = childResultOptional.get();

            Set<Symbol> nodeOutputSymbols = ImmutableSet.copyOf(node.getOutputSymbols());
            List<Symbol> symbolsToAdd = childResult.symbolsToPropagate.stream()
                    .filter(symbol -> !nodeOutputSymbols.contains(symbol))
                    .collect(ImmutableList.toImmutableList());

            Assignments assignments = Assignments.builder()
                    .putAll(node.getAssignments())
                    .putIdentities(symbolsToAdd)
                    .build();

            return Optional.of(new DecorrelationResult(
                    new ProjectNode(node.getPlanNodeId(), childResult.node, assignments),
                    childResult.symbolsToPropagate,
                    childResult.correlatedPredicates,
                    childResult.correlatedSymbolsMapping,
                    childResult.constantSymbols,
                    childResult.atMostSingleRow));
        }

        /**
         * Builds the mapping from correlation symbols to local symbols.
         *
         * <p>The mapping is read from {@code a = b} conjuncts where both sides are symbol references
         * and exactly one side is a correlation symbol. Each entry is correlation symbol to local
         * symbol.
         */
        private Multimap<Symbol, Symbol> extractCorrelatedSymbolsMapping(List<Expression> correlatedConjuncts) {
            ImmutableMultimap.Builder<Symbol, Symbol> mapping = ImmutableMultimap.builder();
            for (Expression conjunct : correlatedConjuncts) {
                if (!(conjunct instanceof ComparisonExpression)) {
                    continue;
                }
                ComparisonExpression comparison = (ComparisonExpression) conjunct;
                if (!(comparison.getLeft() instanceof SymbolReference)
                        || !(comparison.getRight() instanceof SymbolReference)
                        || comparison.getOperator() != ComparisonExpression.Operator.EQUAL) {
                    continue;
                }
                Symbol left = Symbol.from(comparison.getLeft());
                Symbol right = Symbol.from(comparison.getRight());
                if (this.correlation.contains(left) && !this.correlation.contains(right)) {
                    mapping.put(left, right);
                }
                if (this.correlation.contains(right) && !this.correlation.contains(left)) {
                    mapping.put(right, left);
                }
            }
            return mapping.build();
        }

        /**
         * Collects local symbols that are equated to a constant expression.
         *
         * <p>A side qualifies when it is uncorrelated and is either a symbol reference or a cast of
         * one. The other side must be deterministic and built only from correlation symbols.
         */
        private Set<Symbol> extractConstantSymbols(List<Expression> correlatedConjuncts) {
            ImmutableSet.Builder<Symbol> constants = ImmutableSet.builder();
            correlatedConjuncts.stream()
                    .filter(ComparisonExpression.class::isInstance)
                    .map(ComparisonExpression.class::cast)
                    .filter(comparison -> comparison.getOperator() == ComparisonExpression.Operator.EQUAL)
                    .forEach(comparison -> {
                        Expression left = comparison.getLeft();
                        Expression right = comparison.getRight();
                        if (!this.isCorrelated(left)
                                && (left instanceof SymbolReference || this.isSimpleInjectiveCast(left))
                                && this.isConstant(right)) {
                            constants.add(this.getSymbol(left));
                        }
                        if (!this.isCorrelated(right)
                                && (right instanceof SymbolReference || this.isSimpleInjectiveCast(right))
                                && this.isConstant(left)) {
                            constants.add(this.getSymbol(right));
                        }
                    });
            return constants.build();
        }

        /** True if {@code expression} is deterministic and references only correlation symbols. */
        private boolean isConstant(Expression expression) {
            return DeterminismEvaluator.isDeterministic(expression)
                    && ImmutableSet.copyOf(this.correlation).containsAll(SymbolsExtractor.extractUnique(expression));
        }

        /**
         * True if {@code expression} is a CAST applied directly to a symbol reference.
         *
         * <p>NOTE(review): despite the name, this does not check that the cast is injective; any
         * cast of a symbol is accepted. A non-injective cast, such as a narrowing conversion, would
         * wrongly mark its source symbol as constant. Consider checking the source and target types
         * with {@code typeCoercion} — confirm the available APIs.
         */
        private boolean isSimpleInjectiveCast(Expression expression) {
            if (!(expression instanceof Cast)) {
                return false;
            }
            return ((Cast) expression).getExpression() instanceof SymbolReference;
        }

        /** Returns the symbol of a symbol reference, or of the operand of a CAST over a symbol reference. */
        private Symbol getSymbol(Expression expression) {
            if (expression instanceof SymbolReference) {
                return Symbol.from(expression);
            }
            return Symbol.from(((Cast) expression).getExpression());
        }

        /** True if {@code expression} references at least one correlation symbol. */
        private boolean isCorrelated(Expression expression) {
            Set<Symbol> referenced = SymbolsExtractor.extractUnique(expression);
            return this.correlation.stream().anyMatch(referenced::contains);
        }
    }

    /** Intermediate state produced while decorrelating a subtree. */
    private static class DecorrelationResult {
        /** The rewritten subtree, with correlated conjuncts removed. */
        final PlanNode node;
        /** Local symbols used by extracted conjuncts; they must stay in the outputs of nodes above. */
        final Set<Symbol> symbolsToPropagate;
        /** Conjuncts referencing correlation symbols, extracted from filters in the subtree. */
        final List<Expression> correlatedPredicates;
        /** Correlation symbol to local symbol(s), taken from symbol-to-symbol equalities. */
        final Multimap<Symbol, Symbol> correlatedSymbolsMapping;
        /** Local symbols equated to a deterministic expression over correlation symbols only. */
        final Set<Symbol> constantSymbols;
        /** Set by the LIMIT 1 rewrite, and by aggregations grouped only on constant symbols. */
        final boolean atMostSingleRow;

        DecorrelationResult(PlanNode node, Set<Symbol> symbolsToPropagate, List<Expression> correlatedPredicates, Multimap<Symbol, Symbol> correlatedSymbolsMapping, Set<Symbol> constantSymbols, boolean atMostSingleRow) {
            this.node = node;
            this.symbolsToPropagate = symbolsToPropagate;
            this.correlatedPredicates = correlatedPredicates;
            this.atMostSingleRow = atMostSingleRow;
            this.correlatedSymbolsMapping = correlatedSymbolsMapping;
            this.constantSymbols = constantSymbols;
            Preconditions.checkState(
                    constantSymbols.containsAll(correlatedSymbolsMapping.values()),
                    "Expected constant symbols to contain all correlated symbols local equivalents");
            Preconditions.checkState(
                    symbolsToPropagate.containsAll(constantSymbols),
                    "Expected symbols to propagate to contain all constant symbols");
        }

        /** Result for a subtree kept as-is: nothing extracted, nothing propagated. */
        static DecorrelationResult unchanged(PlanNode node) {
            return new DecorrelationResult(
                    node,
                    ImmutableSet.<Symbol>of(),
                    ImmutableList.<Expression>of(),
                    ImmutableMultimap.<Symbol, Symbol>of(),
                    ImmutableSet.<Symbol>of(),
                    false);
        }

        /**
         * Builds a mapper that replaces each correlation symbol with a local equivalent. When
         * several local symbols map to the same correlation symbol, the last one is used.
         */
        SymbolMapper getCorrelatedSymbolMapper() {
            Map<Symbol, Symbol> mapping = this.correlatedSymbolsMapping.asMap().entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getKey,
                            entry -> Iterables.getLast(entry.getValue())));
            return SymbolMapper.symbolMapper(mapping);
        }

        Set<Symbol> getConstantSymbols() {
            return this.constantSymbols;
        }
    }
}

