/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.yarn.ropt;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.yarn.ropt.GridEnumeration;

/**
 * Grid enumeration that derives candidate points from the memory estimates of
 * all hops in a program.
 *
 * <p>The range {@code [_min, _max]} is divided into {@code _nsteps - 1} equal
 * intervals. Each hop estimate contributes the two grid points that bracket
 * it. Estimates below {@code _min} or above {@code _max} are clamped to that
 * bound.
 */
public class GridEnumerationMemory
extends GridEnumeration {
    /** Default number of grid steps between min and max (inclusive). */
    public static final int DEFAULT_NSTEPS = 20;
    /**
     * Constant added to every hop memory estimate before scaling
     * (0x100000 = 1048576; presumably 1MB of headroom if estimates are in
     * bytes -- confirm against Hop#getMemEstimate).
     */
    public static final int DEFAULT_MEM_ADD = 0x100000;
    private int _nsteps = DEFAULT_NSTEPS;

    /**
     * @param prog program blocks whose hops provide memory estimates
     * @param min  lower bound of the enumerated range
     * @param max  upper bound of the enumerated range
     */
    public GridEnumerationMemory(ArrayList<ProgramBlock> prog, long min, long max) {
        super(prog, min, max);
    }

    /**
     * Sets the number of grid steps. Values below 2 are tolerated: the range
     * is then treated as a single interval (see {@link #computeGap()}).
     *
     * @param steps number of grid steps
     */
    public void setNumSteps(int steps) {
        this._nsteps = steps;
    }

    /**
     * Enumerates grid points derived from the hop memory estimates.
     *
     * @return sorted list of distinct grid points, each within {@code [_min, _max]}
     *         when {@code _min <= _max}
     */
    @Override
    public ArrayList<Long> enumerateGridPoints() {
        long gap = computeGap();

        ArrayList<Long> mem = new ArrayList<Long>();
        getMemoryEstimates(_prog, mem);

        // Bin each estimate onto the equi-spaced grid; the set removes duplicates.
        HashSet<Long> points = new HashSet<Long>();
        for (Long val : mem) {
            if (val < _min) {
                points.add(_min); // only the right side of the bin exists
            } else if (val > _max) {
                points.add(_max); // only the left side of the bin exists
            } else {
                // val >= _min and gap >= 1, so bin is non-negative.
                long bin = (val - _min) / gap;
                points.add(filterMax(_min + bin * gap));
                points.add(filterMax(_min + (bin + 1L) * gap));
            }
        }

        ArrayList<Long> ret = new ArrayList<Long>(points);
        Collections.sort(ret);
        return ret;
    }

    /**
     * Computes the grid spacing. The result is always at least 1. This
     * prevents division by zero when {@code _nsteps <= 1} or when the range
     * is narrower than the number of intervals (including
     * {@code _min == _max}).
     */
    private long computeGap() {
        long intervals = Math.max(_nsteps - 1, 1);
        return Math.max((_max - _min) / intervals, 1L);
    }

    /** Clamps {@code val} to at most {@code _max}. */
    private long filterMax(long val) {
        return (val > _max) ? _max : val;
    }

    /** Collects memory estimates from every block in {@code pbs}, appending to {@code mem}. */
    private void getMemoryEstimates(ArrayList<ProgramBlock> pbs, ArrayList<Long> mem) {
        for (ProgramBlock pb : pbs) {
            getMemoryEstimates(pb, mem);
        }
    }

    /**
     * Recursively collects memory estimates from a single program block.
     * Control-flow blocks (function, while, if, for) are descended into via
     * their child blocks. Any other block contributes the hops of its
     * statement block.
     *
     * <p>NOTE(review): predicate hops of while/if/for blocks are not visited
     * here. Confirm whether that omission is intended.
     */
    private void getMemoryEstimates(ProgramBlock pb, ArrayList<Long> mem) {
        if (pb instanceof FunctionProgramBlock) {
            FunctionProgramBlock fpb = (FunctionProgramBlock)pb;
            getMemoryEstimates(fpb.getChildBlocks(), mem);
        } else if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock)pb;
            getMemoryEstimates(wpb.getChildBlocks(), mem);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock)pb;
            getMemoryEstimates(ipb.getChildBlocksIfBody(), mem);
            getMemoryEstimates(ipb.getChildBlocksElseBody(), mem);
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock)pb;
            getMemoryEstimates(fpb.getChildBlocks(), mem);
        } else {
            StatementBlock sb = pb.getStatementBlock();
            if (sb != null && sb.getHops() != null) {
                // Reset visit flags so shared sub-DAGs are counted once per block.
                Hop.resetVisitStatus(sb.getHops());
                for (Hop hop : sb.getHops()) {
                    getMemoryEstimates(hop, mem);
                }
            }
        }
    }

    /**
     * Post-order DAG traversal that appends one scaled memory estimate per
     * unvisited hop. Inputs are processed before the hop itself.
     */
    private void getMemoryEstimates(Hop hop, ArrayList<Long> mem) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop hi : hop.getInput()) {
            getMemoryEstimates(hi, mem);
        }
        // Add the fixed headroom, then divide by the utilization factor.
        // Presumably this converts operation memory into a required total
        // budget -- confirm the semantics of OptimizerUtils.MEM_UTIL_FACTOR.
        mem.add((long)((hop.getMemEstimate() + DEFAULT_MEM_ADD) / OptimizerUtils.MEM_UTIL_FACTOR));
        hop.setVisited();
    }
}

