/*
 * Decompiled with CFR 0.152.
 */
package org.apache.bifromq.baserpc.client.loadbalancer;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import lombok.Generated;
import org.apache.bifromq.baseenv.EnvProvider;
import org.apache.bifromq.baserpc.client.loadbalancer.ChannelList;
import org.apache.bifromq.baserpc.client.loadbalancer.Constants;
import org.apache.bifromq.baserpc.client.loadbalancer.IServerSelectorUpdateListener;
import org.apache.bifromq.baserpc.client.loadbalancer.SubChannelPicker;
import org.apache.bifromq.baserpc.client.loadbalancer.TenantAwareServerSelector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TrafficDirectiveLoadBalancer
extends LoadBalancer {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(TrafficDirectiveLoadBalancer.class);
    private final LoadBalancer.Helper helper;
    private final IServerSelectorUpdateListener updateListener;
    private final SubChannelPicker currentPicker;
    private final Map<String, ChannelList> serverChannels = Maps.newHashMap();
    private final AtomicBoolean balancingStateUpdateScheduled = new AtomicBoolean(false);
    private Map<String, Boolean> currentServers = Maps.newHashMap();
    private Map<String, Set<String>> currentServerGroupTags = Maps.newHashMap();
    private Map<String, Map<String, Integer>> currentTrafficDirective = Maps.newHashMap();

    TrafficDirectiveLoadBalancer(LoadBalancer.Helper helper, IServerSelectorUpdateListener updateListener) {
        this.helper = (LoadBalancer.Helper)Preconditions.checkNotNull((Object)helper, (Object)"helper");
        this.updateListener = updateListener;
        this.currentPicker = new SubChannelPicker();
    }

    private static <T> Set<T> difference(Set<T> a, Set<T> b) {
        HashSet<T> aCopy = new HashSet<T>(a);
        aCopy.removeAll(b);
        return aCopy;
    }

    public void handleResolvedAddresses(LoadBalancer.ResolvedAddresses resolvedAddresses) {
        log.debug("Handle traffic change: resolvedAddresses={}", (Object)resolvedAddresses);
        HashMap<String, Object> newResolved = new HashMap<String, Object>();
        HashMap<String, Boolean> newServers = new HashMap<String, Boolean>();
        for (Object addressGroup : resolvedAddresses.getAddresses()) {
            String serverId = (String)addressGroup.getAttributes().get(Constants.SERVER_ID_ATTR_KEY);
            newResolved.put(serverId, addressGroup);
            newServers.put(serverId, (Boolean)addressGroup.getAttributes().get(Constants.IN_PROC_SERVER_ATTR_KEY));
        }
        HashMap newServerGroupTags = Maps.newHashMap();
        for (EquivalentAddressGroup addressGroup : resolvedAddresses.getAddresses()) {
            newServerGroupTags.put((String)addressGroup.getAttributes().get(Constants.SERVER_ID_ATTR_KEY), (Set)addressGroup.getAttributes().get(Constants.SERVER_GROUP_TAG_ATTR_KEY));
        }
        Map newTrafficDirective = (Map)resolvedAddresses.getAttributes().get(Constants.TRAFFIC_DIRECTIVE_ATTR_KEY);
        boolean updatePicker = !this.currentTrafficDirective.equals(newTrafficDirective) || !this.currentServerGroupTags.equals(newServerGroupTags);
        this.currentServers = newServers;
        this.currentServerGroupTags = newServerGroupTags;
        this.currentTrafficDirective = newTrafficDirective;
        int requested = Math.min(5, EnvProvider.INSTANCE.availableProcessors());
        Set<String> currentServers = this.serverChannels.keySet();
        Set latestServers = newResolved.keySet();
        Set<String> addedServers = TrafficDirectiveLoadBalancer.difference(latestServers, currentServers);
        Set<String> removedServers = TrafficDirectiveLoadBalancer.difference(currentServers, latestServers);
        for (String serverId : currentServers) {
            int openNow;
            if (removedServers.contains(serverId) || requested <= (openNow = this.serverChannels.get((Object)serverId).subChannels.size())) continue;
            updatePicker = true;
            IntStream.range(0, requested - openNow).forEach(i -> this.serverChannels.get((Object)serverId).subChannels.add(this.setupSubchannel(serverId, (EquivalentAddressGroup)newResolved.get(serverId), (Boolean)newServers.get(serverId))));
        }
        for (String serverId : addedServers) {
            this.serverChannels.computeIfAbsent(serverId, k -> {
                ChannelList scList = new ChannelList((Boolean)newServers.get(serverId));
                IntStream.range(0, requested).forEach(i -> scList.subChannels.add(this.setupSubchannel(serverId, (EquivalentAddressGroup)newResolved.get(serverId), (Boolean)newServers.get(serverId))));
                return scList;
            });
        }
        ArrayList<ChannelList> removedSubchannels = new ArrayList<ChannelList>();
        for (String serverId : removedServers) {
            removedSubchannels.add(this.serverChannels.remove(serverId));
        }
        for (ChannelList subChannelList : removedSubchannels) {
            for (LoadBalancer.Subchannel subchannel : subChannelList.subChannels) {
                this.shutdownSubChannel(subchannel);
            }
        }
        if (updatePicker) {
            this.scheduleBalancingStateUpdate();
        }
    }

    public void handleNameResolutionError(Status status) {
        log.error("Name resolution error:{}", (Object)status.getDescription());
        this.helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, (LoadBalancer.SubchannelPicker)this.currentPicker);
    }

    public void shutdown() {
        log.debug("Shutting down all subchannels");
        for (ChannelList subChannelList : this.serverChannels.values()) {
            for (LoadBalancer.Subchannel subchannel : subChannelList.subChannels) {
                this.shutdownSubChannel(subchannel);
            }
        }
    }

    public boolean canHandleEmptyAddressListFromNameResolution() {
        return true;
    }

    private ConnectivityStateInfo getSubChannelState(LoadBalancer.Subchannel subchannel) {
        return (ConnectivityStateInfo)((AtomicReference)subchannel.getAttributes().get(Constants.STATE_INFO)).get();
    }

    private void scheduleBalancingStateUpdate() {
        if (this.balancingStateUpdateScheduled.compareAndSet(false, true)) {
            this.helper.getSynchronizationContext().schedule(this::updateBalancingState, 1L, TimeUnit.SECONDS, this.helper.getScheduledExecutorService());
        }
    }

    private void updateBalancingState() {
        ConnectivityState newState = this.determineChannelState();
        if (newState != ConnectivityState.SHUTDOWN) {
            log.debug("Update balancing state to {}", (Object)newState);
            this.currentPicker.refresh(this.serverChannels);
            this.helper.updateBalancingState(newState, (LoadBalancer.SubchannelPicker)this.currentPicker);
            if (newState == ConnectivityState.READY || newState == ConnectivityState.TRANSIENT_FAILURE && this.currentServers.isEmpty()) {
                this.updateListener.onUpdate(new TenantAwareServerSelector(this.currentServers, this.currentServerGroupTags, this.currentTrafficDirective));
            }
        }
        this.balancingStateUpdateScheduled.set(false);
    }

    private ConnectivityState determineChannelState() {
        ConnectivityState connectivityState = ConnectivityState.READY;
        if (this.serverChannels.isEmpty() || this.serverChannels.values().stream().map(scList -> scList.subChannels).flatMap(Collection::stream).map(this::getSubChannelState).allMatch(state -> state.getState() == ConnectivityState.TRANSIENT_FAILURE)) {
            connectivityState = ConnectivityState.TRANSIENT_FAILURE;
        } else if (this.serverChannels.values().stream().map(scList -> scList.subChannels).flatMap(Collection::stream).map(this::getSubChannelState).allMatch(state -> state.getState() == ConnectivityState.SHUTDOWN)) {
            connectivityState = ConnectivityState.SHUTDOWN;
        } else if (this.serverChannels.values().stream().map(scList -> scList.subChannels).flatMap(Collection::stream).map(this::getSubChannelState).allMatch(state -> state.getState() != ConnectivityState.READY)) {
            connectivityState = ConnectivityState.CONNECTING;
        }
        return connectivityState;
    }

    private LoadBalancer.Subchannel setupSubchannel(String serverId, EquivalentAddressGroup equivalentAddressGroup, boolean inProc) {
        LoadBalancer.Subchannel subchannel = (LoadBalancer.Subchannel)Preconditions.checkNotNull((Object)this.helper.createSubchannel(LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(equivalentAddressGroup).setAttributes(Attributes.newBuilder().set(Constants.STATE_INFO, new AtomicReference<ConnectivityStateInfo>(ConnectivityStateInfo.forNonError((ConnectivityState)ConnectivityState.IDLE))).set(Constants.IN_PROC_SERVER_ATTR_KEY, (Object)inProc).set(Constants.SERVER_ID_ATTR_KEY, (Object)serverId).build()).build()), (Object)"subchannel");
        subchannel.start(state -> this.handleSubchannelStateChange(subchannel, state));
        subchannel.requestConnection();
        return subchannel;
    }

    private void handleSubchannelStateChange(LoadBalancer.Subchannel subchannel, ConnectivityStateInfo state) {
        this.updateSubChannelState(subchannel, state);
        this.scheduleBalancingStateUpdate();
    }

    private void shutdownSubChannel(LoadBalancer.Subchannel subchannel) {
        log.trace("Shutdown sub-channel: {}", (Object)subchannel);
        subchannel.shutdown();
        this.updateSubChannelState(subchannel, ConnectivityStateInfo.forNonError((ConnectivityState)ConnectivityState.SHUTDOWN));
    }

    private void updateSubChannelState(LoadBalancer.Subchannel subchannel, ConnectivityStateInfo state) {
        log.trace("Sub-channel[{}] state change to {}", (Object)subchannel, (Object)state);
        ((AtomicReference)subchannel.getAttributes().get(Constants.STATE_INFO)).set(state);
        if (state.getState() == ConnectivityState.IDLE) {
            subchannel.requestConnection();
        }
    }
}

