/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;

/**
 * Spark instruction for the {@code mapmmchain} opcode.
 *
 * <p>Execution reads the first input as an RDD of (index, block) pairs and the
 * second input (and, for chain types other than {@code XtXv}, the third input)
 * as a partitioned broadcast. Every RDD block is mapped through
 * {@code MatrixBlock.chainMatrixMultOperations} and the per-block results are
 * combined with {@code RDDAggregateUtils.sumStable} into a single output block.
 */
public class MapmmChainSPInstruction
extends SPInstruction
implements LineageTraceable {
    private final MapMultChain.ChainType _chainType;
    private final CPOperand _input1;
    private final CPOperand _input2;
    /** Third operand; {@code null} when the instruction was parsed in its two-input form. */
    private final CPOperand _input3;
    private final CPOperand _output;

    /**
     * Two-input form: delegates to the three-input constructor with no third operand.
     */
    private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, MapMultChain.ChainType type, String opcode, String istr) {
        this(op, in1, in2, null, out, type, opcode, istr);
    }

    /**
     * Three-input form.
     *
     * @param op     operator handed to the {@link SPInstruction} super constructor
     * @param in1    operand read as the block RDD at execution time
     * @param in2    operand read as a broadcast at execution time
     * @param in3    operand read as a second broadcast for non-{@code XtXv} chains; may be {@code null}
     * @param out    output operand
     * @param type   chain type selecting the map function
     * @param opcode instruction opcode
     * @param istr   original instruction string
     */
    private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, MapMultChain.ChainType type, String opcode, String istr) {
        super(SPInstruction.SPType.MAPMMCHAIN, op, opcode, istr);
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this._output = out;
        this._chainType = type;
    }

    /**
     * Parses an instruction string into a {@code MapmmChainSPInstruction}.
     *
     * <p>After the opcode in {@code parts[0]}, the layout is either
     * {@code in1, in2, out, chainType} (when {@code parts.length == 5}) or
     * {@code in1, in2, in3, out, chainType} (otherwise).
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code mapmmchain} (case-insensitive)
     */
    public static MapmmChainSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // NOTE(review): presumably checkNumFields counts the fields after the opcode
        // (i.e. parts.length is 5 or 6 here) -- confirm against InstructionUtils.
        InstructionUtils.checkNumFields(parts, 4, 5);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("mapmmchain")) {
            throw new DMLRuntimeException("MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        if (parts.length == 5) {
            // two-input form: in1, in2, out, chainType
            CPOperand out = new CPOperand(parts[3]);
            MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[4]);
            return new MapmmChainSPInstruction(null, in1, in2, out, type, opcode, str);
        }
        // three-input form: in1, in2, in3, out, chainType
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[5]);
        return new MapmmChainSPInstruction(null, in1, in2, in3, out, type, opcode, str);
    }

    /**
     * Executes the chain: maps every block of the first input through the
     * chain-type-specific function, sums the partial results and registers the
     * sum as the matrix output.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the chain type is not {@code XtXv} but no third input was parsed
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inX = sec.getBinaryMatrixBlockRDDHandleForVariable(this._input1.getName());
        PartitionedBroadcast<MatrixBlock> inV = sec.getBroadcastForVariable(this._input2.getName());

        JavaRDD<MatrixBlock> partials;
        if (this._chainType == MapMultChain.ChainType.XtXv) {
            // only the block values are needed; the block index is irrelevant here
            partials = inX.values().map(new RDDMapMMChainFunction(inV));
        } else {
            // fail with context instead of a bare NullPointerException when the
            // instruction was parsed in its two-input form
            if (this._input3 == null) {
                throw new DMLRuntimeException("MapmmChainSPInstruction.processInstruction():: Chain type " + this._chainType + " requires a third input operand.");
            }
            PartitionedBroadcast<MatrixBlock> inW = sec.getBroadcastForVariable(this._input3.getName());
            partials = inX.map(new RDDMapMMChainFunction2(inV, inW, this._chainType));
        }

        MatrixBlock out = RDDAggregateUtils.sumStable(partials);
        sec.setMatrixOutput(this._output.getName(), out);
    }

    /**
     * Builds the lineage entry for the output variable from the inputs and the
     * chain type (wrapped as a literal scalar operand).
     *
     * @param ec execution context used to look up input lineage
     * @return pair of output variable name and its lineage item
     */
    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        CPOperand chainT = new CPOperand(this._chainType.name(), Types.ValueType.INT64, Types.DataType.SCALAR, true);
        // NOTE(review): _input3 is null for the two-input form; presumably
        // LineageItemUtils.getLineage tolerates null operands -- confirm.
        LineageItem[] inputs = LineageItemUtils.getLineage(ec, this._input1, this._input2, this._input3, chainT);
        // No (Object) casts on the arguments: they would make Pair.of infer
        // Pair<Object, Object>, which is not assignable to the declared return type.
        return Pair.of(this._output.getName(), new LineageItem(this.getOpcode(), inputs));
    }

    /**
     * Map function for chain types that take a third operand: for each
     * (index, block) pair, runs the chain on the block with block (1,1) of V
     * and the block of W located at the input block's row index, column 1.
     */
    private static class RDDMapMMChainFunction2
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -7926980450209760212L;
        private final PartitionedBroadcast<MatrixBlock> _pmV;
        private final PartitionedBroadcast<MatrixBlock> _pmW;
        private final MapMultChain.ChainType _chainType;

        public RDDMapMMChainFunction2(PartitionedBroadcast<MatrixBlock> bV, PartitionedBroadcast<MatrixBlock> bW, MapMultChain.ChainType chain) {
            this._pmV = bV;
            this._pmW = bW;
            this._chainType = chain;
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            MatrixBlock pmV = this._pmV.getBlock(1, 1);
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            // the row-block index of the input block selects the matching block of W
            int rowIx = (int)ixIn.getRowIndex();
            return blkIn.chainMatrixMultOperations(pmV, this._pmW.getBlock(rowIx, 1), new MatrixBlock(), this._chainType);
        }
    }

    /**
     * Map function for the {@code XtXv} chain: runs the chain on each input
     * block with block (1,1) of V and no third operand.
     */
    private static class RDDMapMMChainFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;
        private final PartitionedBroadcast<MatrixBlock> _pmV;

        public RDDMapMMChainFunction(PartitionedBroadcast<MatrixBlock> bV) {
            this._pmV = bV;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) {
            MatrixBlock pmV = this._pmV.getBlock(1, 1);
            return arg0.chainMatrixMultOperations(pmV, null, new MatrixBlock(), MapMultChain.ChainType.XtXv);
        }
    }
}

