/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.hashtable;

import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.compression.BlockCompressionFactory;
import org.apache.flink.runtime.io.disk.ChannelReaderInputViewIterator;
import org.apache.flink.runtime.io.disk.iomanager.FileIOChannel;
import org.apache.flink.runtime.io.disk.iomanager.HeaderlessChannelReaderInputView;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.runtime.hashtable.BaseHybridHashTable;
import org.apache.flink.table.runtime.hashtable.LongHashPartition;
import org.apache.flink.table.runtime.hashtable.ProbeIterator;
import org.apache.flink.table.runtime.io.ChannelWithMeta;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
import org.apache.flink.table.runtime.util.FileChannelUtil;
import org.apache.flink.util.MathUtils;

public abstract class LongHybridHashTable
extends BaseHybridHashTable {
    private final BinaryRowDataSerializer buildSideSerializer;
    private final BinaryRowDataSerializer probeSideSerializer;
    private final ArrayList<LongHashPartition> partitionsBeingBuilt;
    private final ArrayList<LongHashPartition> partitionsPending;
    private ProbeIterator probeIterator;
    private LongHashPartition.MatchIterator matchIterator;
    private boolean denseMode = false;
    private long minKey;
    private long maxKey;
    private MemorySegment[] denseBuckets;
    private LongHashPartition densePartition;

    public LongHybridHashTable(Configuration conf, Object owner, BinaryRowDataSerializer buildSideSerializer, BinaryRowDataSerializer probeSideSerializer, MemoryManager memManager, long reservedMemorySize, IOManager ioManager, int avgRecordLen, long buildRowCount) {
        super(conf, owner, memManager, reservedMemorySize, ioManager, avgRecordLen, buildRowCount, false);
        this.buildSideSerializer = buildSideSerializer;
        this.probeSideSerializer = probeSideSerializer;
        this.partitionsBeingBuilt = new ArrayList();
        this.partitionsPending = new ArrayList();
        this.createPartitions(this.initPartitionFanOut, 0);
    }

    public void putBuildRow(BinaryRowData row) throws IOException {
        long key = this.getBuildLongKey(row);
        int hashCode = LongHybridHashTable.hashLong(key, 0);
        this.insertIntoTable(key, hashCode, row);
    }

    public void endBuild() throws IOException {
        int buildWriteBuffers = 0;
        for (LongHashPartition p : this.partitionsBeingBuilt) {
            buildWriteBuffers += p.finalizeBuildPhase(this.ioManager, this.currentEnumerator);
        }
        this.buildSpillRetBufferNumbers += buildWriteBuffers;
        this.probeIterator = new ProbeIterator(this.probeSideSerializer.createInstance());
        this.tryDenseMode();
    }

    public boolean tryProbe(RowData record) throws IOException {
        int hash;
        LongHashPartition p;
        long probeKey = this.getProbeLongKey(record);
        if (this.denseMode) {
            this.probeIterator.setInstance(record);
            if (probeKey >= this.minKey && probeKey <= this.maxKey) {
                long denseBucket = probeKey - this.minKey;
                long denseBucketOffset = denseBucket << 3;
                int denseSegIndex = (int)(denseBucketOffset >>> this.segmentSizeBits);
                int denseSegOffset = (int)(denseBucketOffset & (long)this.segmentSizeMask);
                long address = this.denseBuckets[denseSegIndex].getLong(denseSegOffset);
                this.matchIterator = this.densePartition.valueIter(address);
            } else {
                this.matchIterator = this.densePartition.valueIter(0xFFFFFFFFFL);
            }
            return true;
        }
        if (!this.probeIterator.hasSource()) {
            this.probeIterator.setInstance(record);
        }
        if ((p = this.partitionsBeingBuilt.get((hash = LongHybridHashTable.hashLong(probeKey, this.currentRecursionDepth)) % this.partitionsBeingBuilt.size())).isInMemory()) {
            this.matchIterator = p.get(probeKey, hash);
            return true;
        }
        p.insertIntoProbeBuffer(this.probeSideSerializer, this.probeToBinary(record));
        return false;
    }

    public boolean nextMatching() throws IOException {
        return !this.denseMode && (this.processProbeIter() || this.prepareNextPartition());
    }

    public RowData getCurrentProbeRow() {
        return this.probeIterator.current();
    }

    public LongHashPartition.MatchIterator getBuildSideIterator() {
        return this.matchIterator;
    }

    @Override
    public void close() {
        if (this.denseMode) {
            this.closed.compareAndSet(false, true);
        } else {
            super.close();
        }
    }

    @Override
    public void free() {
        if (this.denseMode) {
            this.returnAll(Arrays.asList(this.denseBuckets));
            this.returnAll(Arrays.asList(this.densePartition.getPartitionBuffers()));
        }
        super.free();
    }

    private void tryDenseMode() {
        if (this.numSpillFiles != 0L) {
            return;
        }
        long minKey = Long.MAX_VALUE;
        long maxKey = Long.MIN_VALUE;
        long recordCount = 0L;
        for (LongHashPartition p : this.partitionsBeingBuilt) {
            long partitionRecords = p.getBuildSideRecordCount();
            recordCount += partitionRecords;
            if (partitionRecords <= 0L) continue;
            if (p.getMinKey() < minKey) {
                minKey = p.getMinKey();
            }
            if (p.getMaxKey() <= maxKey) continue;
            maxKey = p.getMaxKey();
        }
        if (this.buildSpillRetBufferNumbers != 0) {
            throw new RuntimeException("buildSpillRetBufferNumbers should be 0: " + this.buildSpillRetBufferNumbers);
        }
        long range = maxKey - minKey + 1L;
        if (range > 0L && (range <= recordCount * 4L || range <= (long)(this.segmentSize / 8))) {
            int buffers = (int)Math.ceil((double)(range * 8L) / (double)this.segmentSize);
            MemorySegment[] denseBuckets = new MemorySegment[buffers];
            for (int i = 0; i < buffers; ++i) {
                MemorySegment seg = this.getNextBuffer();
                if (seg == null) {
                    this.returnAll(Arrays.asList(denseBuckets));
                    return;
                }
                denseBuckets[i] = seg;
                for (int j = 0; j < this.segmentSize; j += 8) {
                    seg.putLong(j, 0xFFFFFFFFFL);
                }
            }
            this.denseMode = true;
            LOG.info("LongHybridHashTable: Use dense mode!");
            this.minKey = minKey;
            this.maxKey = maxKey;
            ArrayList<MemorySegment> segments = new ArrayList<MemorySegment>();
            this.buildSpillReturnBuffers.drainTo(segments);
            this.returnAll(segments);
            ArrayList<MemorySegment> dataBuffers = new ArrayList<MemorySegment>();
            long addressOffset = 0L;
            for (LongHashPartition p : this.partitionsBeingBuilt) {
                p.iteratorToDenseBucket(denseBuckets, addressOffset, minKey);
                p.updateDenseAddressOffset(addressOffset);
                dataBuffers.addAll(Arrays.asList(p.getPartitionBuffers()));
                addressOffset += (long)(p.getPartitionBuffers().length << this.segmentSizeBits);
                this.returnAll(Arrays.asList(p.getBuckets()));
            }
            this.denseBuckets = denseBuckets;
            this.densePartition = new LongHashPartition(this, this.buildSideSerializer, dataBuffers.toArray(new MemorySegment[0]));
            this.freeCurrent();
        }
    }

    private void createPartitions(int numPartitions, int recursionLevel) {
        this.ensureNumBuffersReturned(numPartitions);
        this.currentEnumerator = this.ioManager.createChannelEnumerator();
        this.partitionsBeingBuilt.clear();
        double numRecordPerPartition = (double)this.buildRowCount / (double)numPartitions;
        int maxBuffer = this.maxInitBufferOfBucketArea(numPartitions);
        for (int i = 0; i < numPartitions; ++i) {
            LongHashPartition p = new LongHashPartition(this, i, this.buildSideSerializer, numRecordPerPartition, maxBuffer, recursionLevel);
            this.partitionsBeingBuilt.add(p);
        }
    }

    public abstract long getBuildLongKey(RowData var1);

    public abstract long getProbeLongKey(RowData var1);

    public abstract BinaryRowData probeToBinary(RowData var1);

    private void insertIntoTable(long key, int hashCode, BinaryRowData row) throws IOException {
        LongHashPartition p = this.partitionsBeingBuilt.get(hashCode % this.partitionsBeingBuilt.size());
        p.insertIntoTable(key, hashCode, row);
    }

    static int hashLong(long key, int level) {
        long h = key * 2654435769L;
        int hash = (int)(h ^ h >> 32);
        return BaseHybridHashTable.hash(hash, level);
    }

    private boolean processProbeIter() throws IOException {
        if (this.probeIterator.hasSource()) {
            BinaryRowData next;
            ProbeIterator probeIter = this.probeIterator;
            while ((next = probeIter.next()) != null) {
                long probeKey = this.getProbeLongKey(next);
                int hash = LongHybridHashTable.hashLong(probeKey, this.currentRecursionDepth);
                LongHashPartition p = this.partitionsBeingBuilt.get(hash % this.partitionsBeingBuilt.size());
                if (p.isInMemory()) {
                    this.matchIterator = p.get(probeKey, hash);
                    return true;
                }
                p.insertIntoProbeBuffer(this.probeSideSerializer, next);
            }
            return false;
        }
        return false;
    }

    private boolean prepareNextPartition() throws IOException {
        for (LongHashPartition p : this.partitionsBeingBuilt) {
            p.finalizeProbePhase(this.partitionsPending);
        }
        this.partitionsBeingBuilt.clear();
        if (this.currentSpilledProbeSide != null) {
            this.currentSpilledProbeSide.getChannel().closeAndDelete();
            this.currentSpilledProbeSide = null;
        }
        if (this.partitionsPending.isEmpty()) {
            return false;
        }
        LongHashPartition p = this.partitionsPending.get(0);
        LOG.info(String.format("Begin to process spilled partition [%d]", p.getPartitionNumber()));
        if (p.probeSideRecordCounter == 0L) {
            this.partitionsPending.remove(0);
            return this.prepareNextPartition();
        }
        this.buildTableFromSpilledPartition(p);
        ChannelWithMeta channelWithMeta = new ChannelWithMeta(p.probeSideBuffer.getChannel().getChannelID(), p.probeSideBuffer.getBlockCount(), p.probeNumBytesInLastSeg);
        this.currentSpilledProbeSide = FileChannelUtil.createInputView(this.ioManager, channelWithMeta, new ArrayList<FileIOChannel>(), this.compressionEnable, this.compressionCodecFactory, this.compressionBlockSize, this.segmentSize);
        ChannelReaderInputViewIterator<BinaryRowData> probeReader = new ChannelReaderInputViewIterator<BinaryRowData>(this.currentSpilledProbeSide, new ArrayList<MemorySegment>(), this.probeSideSerializer);
        this.probeIterator.set(probeReader);
        this.probeIterator.setReuse(this.probeSideSerializer.createInstance());
        this.partitionsPending.remove(0);
        this.currentRecursionDepth = p.getRecursionLevel() + 1;
        return this.nextMatching();
    }

    private void buildTableFromSpilledPartition(LongHashPartition p) throws IOException {
        int totalBuffersAvailable;
        int nextRecursionLevel = p.getRecursionLevel() + 1;
        if (nextRecursionLevel == 2) {
            LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
        } else if (nextRecursionLevel > 3) {
            throw new RuntimeException("Hash join exceeded maximum number of recursions, without reducing partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
        }
        if (p.getBuildSideBlockCount() > p.getProbeSideBlockCount()) {
            LOG.info(String.format("Hash join: Partition(%d) build side block [%d] more than probe side block [%d]", p.getPartitionNumber(), p.getBuildSideBlockCount(), p.getProbeSideBlockCount()));
        }
        if ((totalBuffersAvailable = this.internalPool.freePages() + this.buildSpillRetBufferNumbers) != this.totalNumBuffers) {
            throw new RuntimeException(String.format("Hash Join bug in memory management: Memory buffers leaked. availableMemory(%s), buildSpillRetBufferNumbers(%s), reservedNumBuffers(%s)", this.internalPool.freePages(), this.buildSpillRetBufferNumbers, this.totalNumBuffers));
        }
        int maxBucketAreaBuffers = MathUtils.roundUpToPowerOfTwo((int)Math.max(1.0, Math.ceil(Math.ceil((double)p.getBuildSideRecordCount() / 0.5) * 16.0 / (double)this.segmentSize)));
        long totalBuffersNeeded = maxBucketAreaBuffers + p.getBuildSideBlockCount() + 2;
        if (totalBuffersNeeded < (long)totalBuffersAvailable) {
            LOG.info(String.format("Build in memory hash table from spilled partition [%d]", p.getPartitionNumber()));
            List<MemorySegment> partitionBuffers = this.readAllBuffers(p.getBuildSideChannel().getChannelID(), p.getBuildSideBlockCount());
            LongHashPartition newPart = new LongHashPartition(this, 0, this.buildSideSerializer, maxBucketAreaBuffers, nextRecursionLevel, partitionBuffers, p.getLastSegmentLimit());
            this.partitionsBeingBuilt.add(newPart);
            LongHashPartition.PartitionIterator pIter = newPart.newPartitionIterator();
            while (pIter.advanceNext()) {
                long key = this.getBuildLongKey(pIter.getRow());
                int hashCode = LongHybridHashTable.hashLong(key, nextRecursionLevel);
                int pointer = (int)pIter.getPointer();
                newPart.insertIntoBucket(key, hashCode, pIter.getRow().getSizeInBytes(), pointer);
            }
        } else {
            int splits = (int)(totalBuffersNeeded / (long)totalBuffersAvailable) + 1;
            int partitionFanOut = Math.min(Math.min(10 * splits, 127), this.maxNumPartition());
            this.createPartitions(partitionFanOut, nextRecursionLevel);
            LOG.info(String.format("Build hybrid hash table from spilled partition [%d] with recursion level [%d]", p.getPartitionNumber(), nextRecursionLevel));
            HeaderlessChannelReaderInputView inView = this.createInputView(p.getBuildSideChannel().getChannelID(), p.getBuildSideBlockCount(), p.getLastSegmentLimit());
            BinaryRowData rec = this.buildSideSerializer.createInstance();
            try {
                while (true) {
                    LongHashPartition.deserializeFromPages(rec, inView, this.buildSideSerializer);
                    long key = this.getBuildLongKey(rec);
                    this.insertIntoTable(key, LongHybridHashTable.hashLong(key, nextRecursionLevel), rec);
                }
            }
            catch (EOFException e) {
                inView.getChannel().closeAndDelete();
                int buildWriteBuffers = 0;
                for (LongHashPartition part : this.partitionsBeingBuilt) {
                    buildWriteBuffers += part.finalizeBuildPhase(this.ioManager, this.currentEnumerator);
                }
                this.buildSpillRetBufferNumbers += buildWriteBuffers;
            }
        }
    }

    @Override
    public int spillPartition() throws IOException {
        MemorySegment currBuff;
        int largestNumBlocks = 0;
        int largestPartNum = -1;
        for (int i = 0; i < this.partitionsBeingBuilt.size(); ++i) {
            LongHashPartition p = this.partitionsBeingBuilt.get(i);
            if (!p.isInMemory() || p.getNumOccupiedMemorySegments() <= largestNumBlocks) continue;
            largestNumBlocks = p.getNumOccupiedMemorySegments();
            largestPartNum = i;
        }
        LongHashPartition p = this.partitionsBeingBuilt.get(largestPartNum);
        int numBuffersFreed = p.spillPartition(this.ioManager, this.currentEnumerator.next(), this.buildSpillReturnBuffers);
        p.releaseBuckets();
        this.buildSpillRetBufferNumbers += numBuffersFreed;
        LOG.info(String.format("Grace hash join: Ran out memory, choosing partition [%d] to spill, %d memory segments being freed", largestPartNum, numBuffersFreed));
        while (this.buildSpillRetBufferNumbers > 0 && (currBuff = (MemorySegment)this.buildSpillReturnBuffers.poll()) != null) {
            this.returnPage(currBuff);
            --this.buildSpillRetBufferNumbers;
        }
        ++this.numSpillFiles;
        this.spillInBytes += (long)(numBuffersFreed * this.segmentSize);
        return largestPartNum;
    }

    @Override
    protected void clearPartitions() {
        this.probeIterator = null;
        for (int i = this.partitionsBeingBuilt.size() - 1; i >= 0; --i) {
            LongHashPartition p = this.partitionsBeingBuilt.get(i);
            try {
                p.clearAllMemory(this.internalPool);
                continue;
            }
            catch (Exception e) {
                LOG.error("Error during partition cleanup.", (Throwable)e);
            }
        }
        this.partitionsBeingBuilt.clear();
        for (LongHashPartition p : this.partitionsPending) {
            p.clearAllMemory(this.internalPool);
        }
    }

    public boolean compressionEnable() {
        return this.compressionEnable;
    }

    public BlockCompressionFactory compressionCodecFactory() {
        return this.compressionCodecFactory;
    }

    public int compressionBlockSize() {
        return this.compressionBlockSize;
    }
}

