/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysml.runtime.matrix.data.DnnParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.DnnUtils;
import org.apache.sysml.utils.NativeHelper;

/**
 * CP instruction that executes the deep-learning ("Dnn") operators on
 * {@link MatrixBlock} inputs.
 *
 * <p>The opcodes handled by this class are:
 * <ul>
 *   <li>{@code maxpooling}, {@code relu_maxpooling}, {@code avgpooling}</li>
 *   <li>{@code maxpooling_backward}, {@code relu_maxpooling_backward}, {@code avgpooling_backward}</li>
 *   <li>{@code conv2d}, {@code conv2d_bias_add}, {@code conv2d_backward_filter}, {@code conv2d_backward_data}</li>
 *   <li>{@code bias_add}, {@code bias_multiply}, {@code relu_backward}</li>
 *   <li>{@code batch_norm2d}, {@code batch_norm2d_backward}</li>
 * </ul>
 *
 * <p>Opcodes are matched case-insensitively everywhere in this class.
 */
public class DnnCPInstruction extends UnaryCPInstruction {
    private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName());

    /** Set once the under-utilization warning has been logged, so it is emitted at most once. */
    private static boolean warnedUnderUtilitization = false;

    // Additional inputs/outputs beyond input1/output held by the super class.
    // Which of them are non-null depends on the opcode (see the constructors).
    private final CPOperand _in2;
    private final CPOperand _in3;
    private final CPOperand _in4;
    private final CPOperand _in5;
    private final CPOperand _in6;
    private final CPOperand _in7;
    private final CPOperand _in8;
    private final CPOperand _out2;
    private final CPOperand _out3;
    private final CPOperand _out4;
    private final CPOperand _out5;

    // Scalar operands; read in processInstruction as
    // [N, C, H, W], [K, (unused), R, S], [stride_h, stride_w] and [pad_h, pad_w].
    private final ArrayList<CPOperand> _input_shape;
    private final ArrayList<CPOperand> _filter_shape;
    private final ArrayList<CPOperand> _stride;
    private final ArrayList<CPOperand> _padding;

    private final int _numThreads;
    private final double _intermediateMemoryBudget;

    /**
     * Primary constructor for all operators that take at most three matrix inputs.
     * The batch-norm specific operands ({@code _in4.._in8}, {@code _out2.._out5}) are set to null.
     *
     * @param in first input operand
     * @param in2 second input operand (may be null)
     * @param in3 third input operand (may be null)
     * @param out output operand
     * @param stride stride operands (may be null)
     * @param padding padding operands (may be null)
     * @param input_shape input shape operands (may be null)
     * @param filter_shape filter shape operands (may be null)
     * @param numThreads number of threads handed to the library calls
     * @param intermediateMemoryBudget memory budget used by {@code resetNumThreads}
     * @param opcode instruction opcode
     * @param istr full instruction string
     */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget, String opcode, String istr) {
        super(CPInstruction.CPType.Dnn, null, in, out, opcode, istr);
        this._in2 = in2;
        this._in3 = in3;
        this._in4 = null;
        this._in5 = null;
        this._in6 = null;
        this._in7 = null;
        this._in8 = null;
        this._out2 = null;
        this._out3 = null;
        this._out4 = null;
        this._out5 = null;
        this._stride = stride;
        this._padding = padding;
        this._input_shape = input_shape;
        this._filter_shape = filter_shape;
        this._numThreads = numThreads;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Constructor for the two-input operators without shape information:
     * {@code bias_add}, {@code bias_multiply} and {@code relu_backward}.
     *
     * @throws DMLRuntimeException if {@code opcode} is none of the three opcodes above
     */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, null, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr);
        // Case-insensitive, consistent with parseInstruction and processInstruction.
        boolean supported = opcode.equalsIgnoreCase("bias_add")
            || opcode.equalsIgnoreCase("relu_backward")
            || opcode.equalsIgnoreCase("bias_multiply");
        if (!supported) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
        }
    }

    /** Constructor for single-input operators with shape information (forward pooling). */
    private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, null, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /** Constructor for two-input operators with shape information. */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /** Constructor for three-input operators with shape information. */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, in3, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /**
     * Constructor for the batch-norm operators, which take up to eight inputs and
     * produce up to five outputs. Stride, padding and shape operands are set to null
     * and the thread count is set to 0.
     */
    public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand in7, CPOperand in8, CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(CPInstruction.CPType.Dnn, null, in1, out, opcode, istr);
        this._in2 = in2;
        this._in3 = in3;
        this._in4 = in4;
        this._in5 = in5;
        this._in6 = in6;
        this._in7 = in7;
        this._in8 = in8;
        this._out2 = out2;
        this._out3 = out3;
        this._out4 = out4;
        this._out5 = out5;
        this._stride = null;
        this._padding = null;
        this._input_shape = null;
        this._filter_shape = null;
        this._numThreads = 0;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Builds a list of operands from {@code count} consecutive instruction parts.
     *
     * @param parts instruction parts
     * @param from index of the first part to convert
     * @param count number of consecutive parts to convert
     * @return operands for {@code parts[from] .. parts[from + count - 1]}, in that order
     */
    private static ArrayList<CPOperand> parseOperands(String[] parts, int from, int count) {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int i = from; i < from + count; i++) {
            operands.add(new CPOperand(parts[i]));
        }
        return operands;
    }

    /**
     * Parses an instruction string into a {@code DnnCPInstruction}.
     *
     * <p>For the shape-carrying operators the parts after the matrix inputs are laid out as
     * two stride operands, two padding operands, four input-shape operands and four
     * filter-shape operands, followed by the output, the thread count and the
     * intermediate memory budget.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not one of the opcodes handled by this class
     */
    public static DnnCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];

        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            ArrayList<CPOperand> stride = parseOperands(parts, 2, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 6, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 10, 4);
            int k = Integer.parseInt(parts[15]);
            return new DnnCPInstruction(in, out, opcode, str, stride, padding, input_shape, filter_shape, k, Double.parseDouble(parts[16]));
        }

        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") || opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 17);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            ArrayList<CPOperand> stride = parseOperands(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 5, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 7, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 11, 4);
            int k = Integer.parseInt(parts[16]);
            return new DnnCPInstruction(in, in2, out, opcode, str, stride, padding, input_shape, filter_shape, k, Double.parseDouble(parts[17]));
        }

        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 18);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            ArrayList<CPOperand> stride = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 6, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 8, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 12, 4);
            int k = Integer.parseInt(parts[17]);
            return new DnnCPInstruction(in, in2, in3, out, opcode, str, stride, padding, input_shape, filter_shape, k, Double.parseDouble(parts[18]));
        }

        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
            InstructionUtils.checkNumFields(parts, 5);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            int k = Integer.parseInt(parts[4]);
            return new DnnCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5]));
        }

        if (opcode.equalsIgnoreCase("batch_norm2d")) {
            InstructionUtils.checkNumFields(parts, 13);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand in7 = new CPOperand(parts[7]);
            CPOperand in8 = new CPOperand(parts[8]);
            CPOperand out = new CPOperand(parts[9]);
            CPOperand out2 = new CPOperand(parts[10]);
            CPOperand out3 = new CPOperand(parts[11]);
            CPOperand out4 = new CPOperand(parts[12]);
            CPOperand out5 = new CPOperand(parts[13]);
            return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0.0);
        }

        if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
            InstructionUtils.checkNumFields(parts, 9);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand out = new CPOperand(parts[7]);
            CPOperand out2 = new CPOperand(parts[8]);
            CPOperand out3 = new CPOperand(parts[9]);
            return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0.0);
        }

        throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
    }

    /**
     * Reads the scalar operand at {@code index} of {@code aL} from the execution context.
     * The value is read as a long and narrowed to int.
     */
    private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) {
        CPOperand operand = aL.get(index);
        return (int)ec.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral()).getLongValue();
    }

    /**
     * Executes {@code relu_backward} on input1 and the upstream gradient {@code _in2}.
     * The output has the dimensions of input1; the library call is skipped (leaving an
     * unallocated output block) when either input is empty.
     *
     * @param ec execution context providing the inputs and receiving the output
     */
    public void processReluBackwardInstruction(ExecutionContext ec) {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        boolean sparseOutput = input.isInSparseFormat() || dout.isInSparseFormat();
        MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), sparseOutput);
        if (!input.isEmpty() && !dout.isEmpty()) {
            outputBlock.allocateBlock();
            LibMatrixDNN.reluBackward(input, dout, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Executes {@code bias_add} on input1 and the bias {@code _in2}.
     * If both are empty the output is an empty sparse block; if only the bias is empty
     * the output is a copy of the input; otherwise a dense output is computed.
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the bias does not have exactly one column
     */
    public void processBiasAddInstruction(ExecutionContext ec) {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock bias = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock outputBlock = null;
        if (bias.getNumColumns() != 1) {
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns());
        }
        if (input.isEmpty() && bias.isEmpty()) {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
        } else if (bias.isEmpty()) {
            outputBlock = new MatrixBlock(input);
        } else {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
            outputBlock.allocateDenseBlock();
            LibMatrixDNN.biasAdd(input, bias, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Executes {@code bias_multiply} on input1 and the bias {@code _in2}.
     * An empty bias yields an empty sparse output; otherwise the output uses the
     * input's sparse/dense format.
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the bias does not have exactly one column
     */
    public void processBiasMultiplyInstruction(ExecutionContext ec) {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock bias = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock outputBlock = null;
        if (bias.getNumColumns() != 1) {
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns());
        }
        if (bias.isEmpty()) {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
        } else {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), input.isInSparseFormat()).allocateBlock();
            LibMatrixDNN.biasMultiply(input, bias, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Executes {@code batch_norm2d}.
     * Inputs: image (input1), scale, bias, running mean, running variance (matrices),
     * phase (string scalar), epsilon and mu (double scalars).
     * Outputs: the result, the updated running mean/variance, and the saved mean and
     * inverse variance. All outputs are allocated dense.
     *
     * @param ec execution context providing the inputs and receiving the outputs
     */
    public void processBatchNorm2dInstruction(ExecutionContext ec) {
        MatrixBlock image = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock scale = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock bias = ec.getMatrixInput(this._in3.getName(), this.getExtendedOpcode());
        MatrixBlock runningMean = ec.getMatrixInput(this._in4.getName(), this.getExtendedOpcode());
        MatrixBlock runningVar = ec.getMatrixInput(this._in5.getName(), this.getExtendedOpcode());
        String phase = ec.getScalarInput(this._in6.getName(), this._in6.getValueType(), this._in6.isLiteral()).getStringValue();
        double epsilon = ec.getScalarInput(this._in7.getName(), this._in7.getValueType(), this._in7.isLiteral()).getDoubleValue();
        double mu = ec.getScalarInput(this._in8.getName(), this._in8.getValueType(), this._in8.isLiteral()).getDoubleValue();

        MatrixBlock ret = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
        MatrixBlock retRunningMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
        MatrixBlock retRunningVar = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
        MatrixBlock resultSaveMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
        MatrixBlock resultSaveInvVariance = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();

        LibMatrixDNN.batchNorm2D(image, scale, bias, runningMean, runningVar, phase, epsilon, mu, ret, retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance);

        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in3.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in4.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in5.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.output.getName(), ret, this.getExtendedOpcode());
        ec.setMatrixOutput(this._out2.getName(), retRunningMean, this.getExtendedOpcode());
        ec.setMatrixOutput(this._out3.getName(), retRunningVar, this.getExtendedOpcode());
        ec.setMatrixOutput(this._out4.getName(), resultSaveMean, this.getExtendedOpcode());
        ec.setMatrixOutput(this._out5.getName(), resultSaveInvVariance, this.getExtendedOpcode());
    }

    /**
     * Executes {@code batch_norm2d_backward}.
     * Inputs: image (input1), upstream gradient, scale, epsilon (double scalar), and
     * the saved mean and inverse variance. Outputs: dX (shaped like the image) and
     * dScale/dBias (shaped like scale). All outputs are allocated dense.
     *
     * @param ec execution context providing the inputs and receiving the outputs
     */
    public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) {
        MatrixBlock image = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock scale = ec.getMatrixInput(this._in3.getName(), this.getExtendedOpcode());
        double epsilon = ec.getScalarInput(this._in4.getName(), this._in4.getValueType(), this._in4.isLiteral()).getDoubleValue();
        MatrixBlock resultSaveMean = ec.getMatrixInput(this._in5.getName(), this.getExtendedOpcode());
        MatrixBlock resultSaveInvVariance = ec.getMatrixInput(this._in6.getName(), this.getExtendedOpcode());

        MatrixBlock dX = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
        MatrixBlock dScale = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
        MatrixBlock dBias = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();

        LibMatrixDNN.batchNorm2DBackward(image, dout, scale, epsilon, resultSaveMean, resultSaveInvVariance, dX, dScale, dBias);

        // _in4 is a scalar, so only the matrix inputs are released.
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in3.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in5.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in6.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.output.getName(), dX, this.getExtendedOpcode());
        ec.setMatrixOutput(this._out2.getName(), dScale, this.getExtendedOpcode());
        ec.setMatrixOutput(this._out3.getName(), dBias, this.getExtendedOpcode());
    }

    /**
     * Reports whether {@code filter} is in sparse format, after first converting it to
     * dense in place when it is sparse and has fewer than 1e7 cells.
     *
     * <p>Note the side effect: the passed block may be modified by this call.
     *
     * @param filter the block to inspect (and possibly densify)
     * @return true if the block is still in sparse format afterwards
     */
    private static boolean isFilterSparse(MatrixBlock filter) {
        // Widen before multiplying so that large dimensions cannot overflow int.
        long numElems = (long)filter.getNumRows() * (long)filter.getNumColumns();
        if (filter.isInSparseFormat() && (double)numElems < 1.0E7) {
            filter.sparseToDense();
        }
        return filter.isInSparseFormat();
    }

    /**
     * Computes the fraction of non-zero cells of {@code mb} as a floating-point ratio.
     *
     * <p>The division is done in double precision (an integer division would truncate
     * every ratio below 1 to 0) and the cell count is computed without int overflow.
     * The result is clamped to [0, 1]; a block without cells is reported as 1.0 so that
     * callers estimating memory from it err on the dense side.
     *
     * @param mb the block to inspect
     * @return the non-zero ratio in [0, 1]
     */
    private static double computeSparsity(MatrixBlock mb) {
        double cells = (double)mb.getNumRows() * (double)mb.getNumColumns();
        if (cells <= 0.0) {
            return 1.0;
        }
        double ratio = (double)mb.getNonZeros() / cells;
        return Math.max(0.0, Math.min(1.0, ratio));
    }

    /**
     * Runs conv2d through the native library when it is enabled and both the filter
     * (after {@link #isFilterSparse(MatrixBlock)}) and the input are dense; otherwise
     * through the Java implementation.
     */
    private static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, DnnParameters params) {
        if (params.enableNative && !DnnCPInstruction.isFilterSparse(filter) && !input.isInSparseFormat()) {
            LibMatrixNative.conv2d(input, filter, outputBlock, params);
        } else {
            LibMatrixDNN.conv2d(input, filter, outputBlock, params);
        }
    }

    /**
     * Executes this instruction.
     *
     * <p>The bias, relu-backward and batch-norm opcodes are delegated to their dedicated
     * methods. For the pooling and convolution opcodes the scalar shape operands are read,
     * the output dimensions P and Q are derived via {@link DnnUtils}, and the matching
     * library routine is invoked. Whenever an input that makes the result trivially
     * empty is empty, an empty sparse output of the right shape is produced instead.
     *
     * @param ec execution context providing the inputs and receiving the output
     * @throws DMLRuntimeException on an unsupported opcode or a wrongly shaped bias
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        if (this.instOpcode.equalsIgnoreCase("bias_add")) {
            this.processBiasAddInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("bias_multiply")) {
            this.processBiasMultiplyInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("relu_backward")) {
            this.processReluBackwardInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d")) {
            this.processBatchNorm2dInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
            this.processBatchNorm2dBackwardInstruction(ec);
            return;
        }

        // avgpooling_backward does not read its first input, so it is never acquired
        // (and consequently never released below).
        boolean isAvgPoolingBackward = this.instOpcode.equalsIgnoreCase("avgpooling_backward");
        MatrixBlock outputBlock = null;
        MatrixBlock matBlock = isAvgPoolingBackward ? null : ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());

        int pad_h = DnnCPInstruction.getScalarInput(ec, this._padding, 0);
        int pad_w = DnnCPInstruction.getScalarInput(ec, this._padding, 1);
        int stride_h = DnnCPInstruction.getScalarInput(ec, this._stride, 0);
        int stride_w = DnnCPInstruction.getScalarInput(ec, this._stride, 1);
        int N = DnnCPInstruction.getScalarInput(ec, this._input_shape, 0);
        int C = DnnCPInstruction.getScalarInput(ec, this._input_shape, 1);
        int H = DnnCPInstruction.getScalarInput(ec, this._input_shape, 2);
        int W = DnnCPInstruction.getScalarInput(ec, this._input_shape, 3);
        int K = DnnCPInstruction.getScalarInput(ec, this._filter_shape, 0);
        // _filter_shape index 1 is not read here.
        int R = DnnCPInstruction.getScalarInput(ec, this._filter_shape, 2);
        int S = DnnCPInstruction.getScalarInput(ec, this._filter_shape, 3);
        int P = (int)DnnUtils.getP(H, R, stride_h, pad_h);
        int Q = (int)DnnUtils.getQ(W, S, stride_w, pad_w);

        DnnParameters params = new DnnParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, this._numThreads);
        params.enableNative = NativeHelper.isNativeLibraryLoaded();

        if (this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("relu_maxpooling") || this.instOpcode.equalsIgnoreCase("avgpooling")) {
            if (matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(N, C * P * Q, true);
            } else {
                outputBlock = new MatrixBlock(N, C * P * Q, false).allocateBlock();
                boolean isMaxPooling = this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("relu_maxpooling");
                LibMatrixDNN.PoolingType poolType = isMaxPooling ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
                if (this.instOpcode.equalsIgnoreCase("relu_maxpooling")) {
                    params.minValForMaxPoolOperations = 0.0;
                }
                LibMatrixDNN.pooling(matBlock, outputBlock, params, poolType);
            }
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling_backward") || this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward") || isAvgPoolingBackward) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            // For avgpooling_backward matBlock is null, so only dout decides emptiness.
            boolean isEmpty = isAvgPoolingBackward ? dout.isEmpty() : (matBlock.isEmpty() || dout.isEmpty());
            if (isEmpty) {
                outputBlock = new MatrixBlock(N, C * H * W, true);
            } else {
                outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
                boolean isMaxPooling = this.instOpcode.equalsIgnoreCase("maxpooling_backward") || this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
                LibMatrixDNN.PoolingType poolType = isMaxPooling ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
                boolean performReLUBackward = this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
                if (performReLUBackward) {
                    params.minValForMaxPoolOperations = 0.0;
                }
                LibMatrixDNN.poolingBackward(matBlock, dout, outputBlock, params, performReLUBackward, poolType);
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d")) {
            this.resetNumThreads(params, C * R * S, P * Q, DnnCPInstruction.computeSparsity(matBlock));
            MatrixBlock filter = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (filter.isEmpty() || matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(N, K * P * Q, true);
            } else {
                boolean sparse = matBlock.isUltraSparse(false) && params.bias == null && matBlock.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K * P * Q);
                outputBlock = new MatrixBlock(N, K * P * Q, sparse).allocateBlock();
                DnnCPInstruction.conv2d(matBlock, filter, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            this.resetNumThreads(params, C * R * S, P * Q, DnnCPInstruction.computeSparsity(matBlock));
            // For this opcode the filter is the third input and the bias the second.
            MatrixBlock filter = ec.getMatrixInput(this._in3.getName(), this.getExtendedOpcode());
            MatrixBlock bias = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (bias.getNumRows() != params.K || bias.getNumColumns() != 1) {
                throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. Expected: [" + params.K + ", 1]");
            }
            boolean isOutputConvEmpty = filter.isEmpty() || matBlock.isEmpty();
            if (isOutputConvEmpty && bias.isEmpty()) {
                outputBlock = new MatrixBlock(N, K * P * Q, true);
            } else if (isOutputConvEmpty) {
                // The convolution contributes nothing, so the output is just the bias.
                outputBlock = new MatrixBlock(N, K * P * Q, false).allocateBlock();
                for (int n = 0; n < params.N; ++n) {
                    DnnUtils.fillBias(bias, outputBlock.getDenseBlockValues(), n, n + 1, params.N, params.K, params.P * params.Q);
                }
            } else {
                outputBlock = new MatrixBlock(N, K * P * Q, false).allocateBlock();
                if (!bias.isEmpty()) {
                    params.bias = bias;
                }
                DnnCPInstruction.conv2d(matBlock, filter, outputBlock, params);
            }
            ec.releaseMatrixInput(this._in3.getName(), this.getExtendedOpcode());
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (dout.isEmpty() || matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(K, C * R * S, true);
            } else {
                outputBlock = new MatrixBlock(K, C * R * S, false).allocateBlock();
                if (params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat()) {
                    LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
                } else {
                    LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
                }
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (dout.isEmpty() || matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(N, C * H * W, true);
            } else {
                outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
                // The first input is passed through the filter sparsity check for this opcode.
                if (params.enableNative && !DnnCPInstruction.isFilterSparse(matBlock) && !dout.isInSparseFormat()) {
                    LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
                } else {
                    LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
                }
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else {
            throw new DMLRuntimeException("Unsupported op code " + this.instOpcode);
        }

        if (!isAvgPoolingBackward) {
            ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        }
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Lowers {@code params.numThreads} so that the per-thread intermediates fit into the
     * intermediate memory budget. This only applies when {@code DMLScript.USE_ACCELERATOR}
     * is set; a warning is logged the first time the thread count is reduced.
     *
     * @param params parameters whose {@code numThreads} may be reduced
     * @param numRows rows of the per-thread intermediate used for the size estimate
     * @param numCols columns of the per-thread intermediate used for the size estimate
     * @param sparsity sparsity used for the size estimate
     */
    private void resetNumThreads(DnnParameters params, int numRows, int numCols, double sparsity) {
        if (!DMLScript.USE_ACCELERATOR) {
            return;
        }
        double memBudget1Thread = (double)OptimizerUtils.estimateSizeExactSparsity((long)numRows, (long)numCols, sparsity);
        int limitedDegreeOfParallelism = (int)Math.floor(this._intermediateMemoryBudget / memBudget1Thread);
        if (params.numThreads > limitedDegreeOfParallelism) {
            // NOTE(review): this can become 0 when the budget is below the estimate for a
            // single thread; confirm how the library routines interpret a thread count of 0.
            params.numThreads = limitedDegreeOfParallelism;
            if (!warnedUnderUtilitization) {
                LOG.warn("CPU Under-utilization to respect the intermediate memory budget. To avoid this, please try reducing the mini-batch or forcing gpu execution.");
            }
            warnedUnderUtilitization = true;
        }
    }
}

