/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.cost;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.cost.VarStats;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;

public abstract class CostEstimator {
    protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());

    /**
     * Iteration count assumed for while loops; also passed as the default to
     * {@link OptimizerUtils#getNumIterations} when costing for loops.
     */
    private static final int DEFAULT_NUMITER = 15;

    /**
     * Placeholder statistics (1x1, -1 block size and nnz, not in memory) for variables whose
     * characteristics are not tracked. This instance is shared, so it must never be mutated.
     */
    protected static final VarStats _unknownStats = new VarStats(1L, 1L, -1, -1L, false);

    /**
     * Statistics used for scalars and, by convention in this class, for operands that are absent
     * from the statistics map. This instance is shared, so it must never be mutated.
     */
    protected static final VarStats _scalarStats = new VarStats(1L, 1L, 1, 1L, true);

    /**
     * Estimates the execution time of an entire runtime program.
     *
     * @param rtprog runtime program whose top-level program blocks are costed
     * @param vars   current live variables; their statistics are written into {@code stats} first
     * @param stats  variable statistics, updated in place while instructions are visited
     * @return the sum of the estimated costs of all top-level program blocks
     */
    public double getTimeEstimate(Program rtprog, LocalVariableMap vars, HashMap<String, VarStats> stats) {
        double costs = 0.0;
        maintainVariableStatistics(vars, stats);
        for (ProgramBlock pb : rtprog.getProgramBlocks()) {
            costs += rGetTimeEstimate(pb, stats, new HashSet<String>(), true);
        }
        return costs;
    }

    /**
     * Estimates the execution time of a single program block.
     *
     * @param pb        program block to cost
     * @param vars      current live variables; their statistics are written into {@code stats} first
     * @param stats     variable statistics, updated in place while instructions are visited
     * @param recursive whether to descend into the child blocks of control-flow and function blocks
     * @return the estimated cost of {@code pb}
     */
    public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, HashMap<String, VarStats> stats, boolean recursive) {
        maintainVariableStatistics(vars, stats);
        return rGetTimeEstimate(pb, stats, new HashSet<String>(), recursive);
    }

    /**
     * Recursively estimates the cost of one program block.
     *
     * <ul>
     * <li>while: body cost times {@link #DEFAULT_NUMITER}</li>
     * <li>if: if-body cost, or the mean of if-body and else-body cost when an else body exists</li>
     * <li>for: body cost times the iteration count from {@link #getNumIterations}</li>
     * <li>function: sum of its child blocks</li>
     * <li>basic: sum of its CP instruction costs, plus called functions' bodies</li>
     * </ul>
     * Child blocks are only visited when {@code recursive} is set; other block types cost 0.
     *
     * @param memoFunc keys of the functions currently on the call path, used to stop recursion
     */
    private double rGetTimeEstimate(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) {
        if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            double body = recursive ? rGetChildTimeEstimate(wpb.getChildBlocks(), stats, memoFunc, recursive) : 0.0;
            // the number of while iterations is not known statically; assume a default
            return body * DEFAULT_NUMITER;
        }
        if (pb instanceof IfProgramBlock) {
            if (!recursive) {
                return 0.0;
            }
            IfProgramBlock ipb = (IfProgramBlock) pb;
            double ifCost = rGetChildTimeEstimate(ipb.getChildBlocksIfBody(), stats, memoFunc, recursive);
            if (ipb.getChildBlocksElseBody() == null || ipb.getChildBlocksElseBody().isEmpty()) {
                return ifCost;
            }
            double elseCost = rGetChildTimeEstimate(ipb.getChildBlocksElseBody(), stats, memoFunc, recursive);
            // Weight both branches equally. The total else-body cost is averaged with the if-body
            // (previously the sum was halved after every else block).
            return (ifCost + elseCost) / 2.0;
        }
        if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            double body = recursive ? rGetChildTimeEstimate(fpb.getChildBlocks(), stats, memoFunc, recursive) : 0.0;
            return body * getNumIterations(fpb);
        }
        if (pb instanceof FunctionProgramBlock) {
            FunctionProgramBlock fpb = (FunctionProgramBlock) pb;
            return recursive ? rGetChildTimeEstimate(fpb.getChildBlocks(), stats, memoFunc, recursive) : 0.0;
        }
        if (pb instanceof BasicProgramBlock) {
            return getBasicBlockTimeEstimate((BasicProgramBlock) pb, stats, memoFunc, recursive);
        }
        return 0.0;
    }

    /** Sums the recursive cost estimates of a sequence of child program blocks. */
    private double rGetChildTimeEstimate(Iterable<? extends ProgramBlock> blocks, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) {
        double sum = 0.0;
        for (ProgramBlock child : blocks) {
            sum += rGetTimeEstimate(child, stats, memoFunc, recursive);
        }
        return sum;
    }

    /**
     * Costs the CP instructions of a basic block in order, updating {@code stats} before each
     * instruction is priced. Non-CP instructions are skipped.
     */
    private double getBasicBlockTimeEstimate(BasicProgramBlock bpb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) {
        double ret = 0.0;
        for (Instruction inst : bpb.getInstructions()) {
            if (!(inst instanceof CPInstruction)) {
                continue;
            }
            maintainCPInstVariableStatistics((CPInstruction) inst, stats);
            Object[] o = extractCPInstStatistics(inst, stats);
            ret += getCPInstTimeEstimate(inst, (VarStats[]) o[0], (String[]) o[1]);
            if (inst instanceof FunctionCallCPInstruction) {
                ret += getFunctionCallTimeEstimate(bpb, (FunctionCallCPInstruction) inst, stats, memoFunc, recursive);
            }
        }
        return ret;
    }

    /**
     * Costs the body of the function invoked by {@code finst}.
     *
     * Returns 0 when the function is already on the current call path (recursion guard) or when
     * the enclosing program is unavailable.
     */
    private double getFunctionCallTimeEstimate(ProgramBlock pb, FunctionCallCPInstruction finst, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) {
        String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
        if (memoFunc.contains(fkey) || pb.getProgram() == null) {
            return 0.0;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Begin Function " + fkey);
        }
        memoFunc.add(fkey);
        Program prog = pb.getProgram();
        FunctionProgramBlock fpb = prog.getFunctionProgramBlock(finst.getNamespace(), finst.getFunctionName());
        double cost = rGetTimeEstimate(fpb, stats, memoFunc, recursive);
        memoFunc.remove(fkey);
        if (LOG.isDebugEnabled()) {
            LOG.debug("End Function " + fkey);
        }
        return cost;
    }

    /**
     * Seeds {@code stats} from the live variables.
     *
     * Each matrix object contributes its data characteristics, with in-memory meaning cache status
     * CACHED. Every other kind of data is recorded as {@link #_scalarStats}.
     */
    private static void maintainVariableStatistics(LocalVariableMap vars, HashMap<String, VarStats> stats) {
        for (String varname : vars.keySet()) {
            Data dat = vars.get(varname);
            if (!(dat instanceof MatrixObject)) {
                stats.put(varname, _scalarStats);
                continue;
            }
            MatrixObject mo = (MatrixObject) dat;
            DataCharacteristics dc = mo.getDataCharacteristics();
            boolean inmem = mo.getStatus() == CacheableData.CacheStatus.CACHED;
            stats.put(varname, new VarStats(dc.getRows(), dc.getCols(), dc.getBlocksize(), dc.getNonZeros(), inmem));
        }
    }

    /**
     * Updates {@code stats} to reflect the variable effects of one CP instruction.
     *
     * <ul>
     * <li>createvar: records the declared characteristics</li>
     * <li>cpvar / mvvar: copy / move the source's statistics to the target</li>
     * <li>rmvar: removes the variable's statistics</li>
     * <li>data generation and string init: record their output characteristics as in memory</li>
     * <li>function calls: mark all bound outputs as unknown</li>
     * </ul>
     */
    private static void maintainCPInstVariableStatistics(CPInstruction inst, HashMap<String, VarStats> stats) {
        if (inst instanceof VariableCPInstruction) {
            String optype = inst.getOpcode();
            String[] parts = InstructionUtils.getInstructionParts(inst.toString());
            switch (optype) {
                case "createvar": {
                    if (parts.length < 10) {
                        return; // createvar without size information
                    }
                    long rlen = Long.parseLong(parts[6]);
                    long clen = Long.parseLong(parts[7]);
                    int blen = Integer.parseInt(parts[8]);
                    long nnz = Long.parseLong(parts[9]);
                    stats.put(parts[1], new VarStats(rlen, clen, blen, nnz, false));
                    break;
                }
                case "cpvar":
                    putOrClear(stats, parts[2], stats.get(parts[1]));
                    break;
                case "mvvar":
                    putOrClear(stats, parts[2], stats.remove(parts[1]));
                    break;
                case "rmvar":
                    stats.remove(parts[1]);
                    break;
                default:
                    break;
            }
        } else if (inst instanceof DataGenCPInstruction) {
            DataGenCPInstruction randInst = (DataGenCPInstruction) inst;
            long rlen = randInst.getRows();
            long clen = randInst.getCols();
            long nnz = (long) (randInst.getSparsity() * (double) rlen * (double) clen);
            stats.put(randInst.output.getName(), new VarStats(rlen, clen, randInst.getBlocksize(), nnz, true));
        } else if (inst instanceof StringInitCPInstruction) {
            StringInitCPInstruction iinst = (StringInitCPInstruction) inst;
            long rlen = iinst.getRows();
            long clen = iinst.getCols();
            stats.put(iinst.output.getName(), new VarStats(rlen, clen, ConfigurationManager.getBlocksize(), rlen * clen, true));
        } else if (inst instanceof FunctionCallCPInstruction) {
            FunctionCallCPInstruction finst = (FunctionCallCPInstruction) inst;
            for (String varname : finst.getBoundOutputParamNames()) {
                stats.put(varname, _unknownStats);
            }
        }
    }

    /**
     * Stores {@code vs} under {@code target}, or removes {@code target} when {@code vs} is null.
     * This avoids null-valued entries that {@code containsKey} would report as present.
     */
    private static void putOrClear(HashMap<String, VarStats> stats, String target, VarStats vs) {
        if (vs != null) {
            stats.put(target, vs);
        } else {
            stats.remove(target);
        }
    }

    /**
     * Replaces every {@code ¶...¶}-delimited placeholder in {@code inst} with the literal "1".
     *
     * Placeholders are matched literally, so regex metacharacters inside them are harmless. An
     * unmatched trailing {@code ¶} is left in place.
     *
     * @param inst instruction string to patch
     * @return the patched instruction string
     */
    protected String replaceInstructionPatch(String inst) {
        final String marker = "\u00b6";
        String ret = inst;
        int start;
        while ((start = ret.indexOf(marker)) >= 0) {
            int end = ret.indexOf(marker, start + 1);
            if (end < 0) {
                break; // no closing marker: nothing further can be patched
            }
            String placeholder = ret.substring(start, end + 1);
            // literal replacement; the replacement contains no marker, so the loop terminates
            ret = ret.replace(placeholder, "1");
        }
        return ret;
    }

    /**
     * Collects the operand statistics and extra attributes needed to price one CP instruction.
     *
     * @return a two-element array:
     *         <ul>
     *         <li>{@code [0]}: a {@code VarStats[3]} of (input1, input2, output)</li>
     *         <li>{@code [1]}: a possibly-null {@code String[]} of opcode-specific attributes</li>
     *         </ul>
     */
    private static Object[] extractCPInstStatistics(Instruction inst, HashMap<String, VarStats> stats) {
        VarStats[] vs = new VarStats[3];
        String[] attr = null;
        if (inst instanceof UnaryCPInstruction) {
            if (inst instanceof DataGenCPInstruction) {
                DataGenCPInstruction rinst = (DataGenCPInstruction) inst;
                vs[0] = _unknownStats;
                vs[1] = _unknownStats;
                vs[2] = stats.get(rinst.output.getName());
                // generator type: 0 = all zeros, 1 = dense constant, 2 = general random
                int type = 2;
                if (rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0) {
                    type = 0;
                } else if (rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue()) {
                    type = 1;
                }
                attr = new String[] {String.valueOf(type)};
            } else if (inst instanceof StringInitCPInstruction) {
                StringInitCPInstruction rinst = (StringInitCPInstruction) inst;
                vs[0] = _unknownStats;
                vs[1] = _unknownStats;
                vs[2] = stats.get(rinst.output.getName());
            } else {
                UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
                vs[0] = orScalar(stats.get(uinst.input1.getName()));
                vs[1] = _unknownStats;
                vs[2] = orScalar(stats.get(uinst.output.getName()));
                if (inst instanceof MMTSJCPInstruction) {
                    attr = new String[] {((MMTSJCPInstruction) inst).getMMTSJType().toString()};
                } else if (inst instanceof AggregateUnaryCPInstruction) {
                    String[] parts = InstructionUtils.getInstructionParts(inst.toString());
                    if (parts[0].equals("cm")) {
                        attr = new String[] {parts[parts.length - 2]};
                    }
                }
            }
        } else if (inst instanceof BinaryCPInstruction) {
            BinaryCPInstruction binst = (BinaryCPInstruction) inst;
            vs[0] = orScalar(stats.get(binst.input1.getName()));
            vs[1] = orScalar(stats.get(binst.input2.getName()));
            vs[2] = orScalar(stats.get(binst.output.getName()));
        } else if (inst instanceof AggregateTernaryCPInstruction) {
            AggregateTernaryCPInstruction tinst = (AggregateTernaryCPInstruction) inst;
            vs[0] = orScalar(stats.get(tinst.input1.getName()));
            vs[1] = orScalar(stats.get(tinst.input2.getName()));
            vs[2] = orScalar(stats.get(tinst.output.getName()));
        } else if (inst instanceof ParameterizedBuiltinCPInstruction) {
            String[] parts = InstructionUtils.getInstructionParts(inst.toString());
            String opcode = parts[0];
            if (opcode.equals("groupedagg")) {
                LinkedHashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
                CMOperator.AggregateOperationTypes type = CMOperator.getAggOpType(paramsMap.get("fn"), paramsMap.get("order"));
                attr = new String[] {String.valueOf(type.ordinal())};
            } else if (opcode.equals("rmempty")) {
                LinkedHashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
                attr = new String[] {String.valueOf(paramsMap.get("margin").equals("rows") ? 0 : 1)};
            }
            // NOTE(review): substring(7) presumably strips a fixed "name=" prefix from the
            // first parameter - confirm against the instruction format.
            vs[0] = orScalar(stats.get(parts[1].substring(7).replaceAll("\u00b6", "")));
            vs[1] = _unknownStats;
            vs[2] = orScalar(stats.get(parts[parts.length - 1]));
        } else if (inst instanceof MultiReturnBuiltinCPInstruction) {
            MultiReturnBuiltinCPInstruction minst = (MultiReturnBuiltinCPInstruction) inst;
            // same untracked-operand fallback as the other branches (avoids a null vs[2] below)
            vs[0] = orScalar(stats.get(minst.input1.getName()));
            vs[1] = orScalar(stats.get(minst.getOutput(0).getName()));
            vs[2] = orScalar(stats.get(minst.getOutput(1).getName()));
        } else if (inst instanceof VariableCPInstruction) {
            setUnknownStats(vs);
            VariableCPInstruction varinst = (VariableCPInstruction) inst;
            if (varinst.getOpcode().equals("write")) {
                if (stats.containsKey(varinst.getInput1().getName())) {
                    vs[0] = stats.get(varinst.getInput1().getName());
                }
                attr = new String[] {varinst.getInput3().getName()};
            }
        } else {
            setUnknownStats(vs);
        }
        // CP outputs end up in memory. Never flip the shared placeholder constants, though:
        // doing so would silently change the stats of every variable that uses them.
        if (vs[2] != null && vs[2] != _unknownStats && vs[2] != _scalarStats) {
            vs[2]._inmem = true;
        }
        return new Object[] {vs, attr};
    }

    /** Returns {@code vs}, or {@link #_scalarStats} when {@code vs} is null. */
    private static VarStats orScalar(VarStats vs) {
        return vs != null ? vs : _scalarStats;
    }

    /** Fills all three operand slots with {@link #_unknownStats}. */
    private static void setUnknownStats(VarStats[] vs) {
        vs[0] = _unknownStats;
        vs[1] = _unknownStats;
        vs[2] = _unknownStats;
    }

    /** Returns the iteration count of a for loop, defaulting to {@link #DEFAULT_NUMITER}. */
    private static long getNumIterations(ForProgramBlock pb) {
        return OptimizerUtils.getNumIterations(pb, DEFAULT_NUMITER);
    }

    /**
     * Estimates the execution time of a single CP instruction.
     *
     * @param inst the instruction to price
     * @param vs   operand statistics (input1, input2, output), never null entries from this class
     * @param args opcode-specific attributes, or null
     * @return the estimated time of {@code inst}
     */
    protected abstract double getCPInstTimeEstimate(Instruction inst, VarStats[] vs, String[] args);
}

