/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.StatementBlock;

public class RewriteAddChkpointLop
extends LopRewriteRule {
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        if (!ConfigurationManager.isCheckpointEnabled()) {
            return List.of(sb);
        }
        ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
        if (lops == null) {
            return List.of(sb);
        }
        HashSet<Lop> sparkRoots = new HashSet<Lop>();
        HashMap sparkOpCount = new HashMap();
        List<Lop> roots = lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
        roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
        if (sparkRoots.isEmpty()) {
            return List.of(sb);
        }
        HashMap<Long, Integer> operatorJobCount = new HashMap<Long, Integer>();
        OperatorOrderingUtils.markSharedSparkOps(sparkRoots, operatorJobCount);
        this.addChkpointLop(lops, operatorJobCount);
        this.placeCompiledCheckpoints(lops, sb);
        return List.of(sb);
    }

    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> operatorJobCount) {
        for (Lop l : nodes) {
            if (!operatorJobCount.containsKey(l.getID()) || operatorJobCount.get(l.getID()) <= 1 || !OperatorOrderingUtils.isPersistableSparkOp(l)) continue;
            ArrayList<Lop> oldOuts = new ArrayList<Lop>(l.getOutputs());
            Checkpoint chkpoint = new Checkpoint(l, l.getDataType(), l.getValueType(), Checkpoint.getDefaultStorageLevelString(), false);
            for (Lop out : oldOuts) {
                chkpoint.addOutput(out);
                out.replaceInput(l, chkpoint);
                l.removeOutput(out);
            }
        }
    }

    private void placeCompiledCheckpoints(List<Lop> nodes, StatementBlock sb) {
        if (sb.getCheckpointPositions() == null) {
            return;
        }
        for (Lop l : nodes) {
            if (!this.isCheckpointed(l, sb)) continue;
            ArrayList<Lop> oldOuts = new ArrayList<Lop>(l.getOutputs());
            Checkpoint chkpoint = new Checkpoint(l, l.getDataType(), l.getValueType(), Checkpoint.getDefaultStorageLevelString(), false);
            for (Lop out : oldOuts) {
                chkpoint.addOutput(out);
                out.replaceInput(l, chkpoint);
                l.removeOutput(out);
            }
        }
    }

    private boolean isCheckpointed(Lop lop, StatementBlock sb) {
        HashMap<Lop.Type, List<Lop.Type>> cpPositions = sb.getCheckpointPositions();
        if (cpPositions == null) {
            return false;
        }
        if (cpPositions.containsKey((Object)lop.getType())) {
            List<Lop.Type> outputsT = cpPositions.get((Object)lop.getType());
            ArrayList<Lop> outputs = new ArrayList<Lop>(lop.getOutputs());
            if (outputs.size() != outputsT.size()) {
                return false;
            }
            for (int i = 0; i < outputs.size(); ++i) {
                if (((Lop)outputs.get(i)).getType() == outputsT.get(i) && ((Lop)outputs.get(i)).isExecSpark()) continue;
                return false;
            }
        } else {
            return false;
        }
        return true;
    }
}

