/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Federated matrix multiplication instruction for the {@code mapmm}, {@code pmm}, {@code cpmm}
 * and {@code rmm} opcodes.
 *
 * <p>Depending on how the two inputs are federated, the multiplication is either executed
 * directly on aligned federated partitions, or one input is broadcast (whole or sliced) to the
 * federated sites of the other. The result is either kept federated (when the federated output
 * is requested) or fetched from the sites and combined into a local matrix.
 */
public class MMFEDInstruction
extends BinaryFEDInstruction {
    /**
     * Creates the instruction.
     *
     * <p>{@code type}, {@code outputEmpty} and {@code aggtype} are parsed from the instruction
     * string but are neither stored nor used by this class.
     */
    private MMFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, MapMult.CacheType type, boolean outputEmpty, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(FEDInstruction.FEDType.MAPMM, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a matrix multiplication instruction string.
     *
     * <p>Expected fields after the opcode: lhs operand, rhs operand, output operand,
     * {@link MapMult.CacheType} name, output-empty flag and {@link AggBinaryOp.SparkAggType} name.
     *
     * @param str the full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not one of mapmm, pmm, cpmm or rmm
     */
    public static MMFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!ArrayUtils.contains(new String[]{"mapmm", "pmm", "cpmm", "rmm"}, opcode)) {
            throw new DMLRuntimeException("MMFEDInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        MapMult.CacheType type = MapMult.CacheType.valueOf(parts[4]);
        boolean outputEmpty = Boolean.parseBoolean(parts[5]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[6]);
        AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
        return new MMFEDInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str);
    }

    /**
     * Executes the federated matrix multiplication {@code input1 %*% input2}.
     *
     * <p>Supported input layouts, checked in this order:
     * <ol>
     *   <li>lhs column-federated and rhs row-federated with aligned inner dimension
     *       ({@code AlignType.COL_T}): each site multiplies its local blocks; partial results
     *       are either kept federated or fetched and summed.</li>
     *   <li>lhs row- or part-federated: rhs is broadcast to the lhs sites. For a single-column
     *       rhs the result stays federated only if federated output is forced; otherwise it
     *       stays federated unless local output is forced.</li>
     *   <li>rhs row-federated: lhs is either column-federated with matching ranges, or it is
     *       broadcast sliced to the rhs sites; partial results are kept federated or summed.</li>
     *   <li>lhs column-federated: rhs is broadcast sliced to the lhs sites; partial results are
     *       kept federated or summed.</li>
     * </ol>
     *
     * @param ec the execution context holding the inputs and receiving the output
     * @throws DMLRuntimeException if none of the supported federation layouts applies
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(input1);
        MatrixObject mo2 = ec.getMatrixObject(input2);
        // Output variable id on the workers; frEmpty registers it (with unknown -1 x -1
        // dimensions) before the instructions that use this id run.
        long id = FederationUtils.getNextFedDataID();
        FederatedRequest frEmpty = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new Object[]{new MatrixCharacteristics(-1L, -1L), Types.DataType.MATRIX});

        if (mo1.isFederated(FederationMap.FType.COL) && mo2.isFederated(FederationMap.FType.ROW) && mo1.getFedMapping().isAligned(mo2.getFedMapping(), FederationMap.AlignType.COL_T)) {
            // Case 1: aligned COL x ROW, no data movement needed.
            FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, Types.ExecType.SPARK, false);
            if (_fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), frEmpty, fr1);
                setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr1.getID(), ec);
            } else {
                FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
                FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
                Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), frEmpty, fr1, fr2, fr3);
                MatrixBlock ret = FederationUtils.aggAdd(tmp);
                ec.setMatrixOutput(output.getName(), ret);
            }
        } else if (mo1.isFederated(FederationMap.FType.ROW) || mo1.isFederated(FederationMap.FType.PART)) {
            // Case 2: broadcast the whole rhs to the lhs sites.
            FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
            FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()}, Types.ExecType.SPARK, false);
            if (mo2.getNumColumns() == 1L) {
                if (_fedOut.isForcedFederated()) {
                    mo1.getFedMapping().execute(getTID(), frEmpty, fr1, fr2);
                    if (mo1.isFederated(FederationMap.FType.PART)) {
                        setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
                    } else {
                        setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
                    }
                } else {
                    fetchBroadcastResult(mo1, frEmpty, fr1, fr2, ec);
                }
            } else if (!_fedOut.isForcedLocal()) {
                mo1.getFedMapping().execute(getTID(), true, frEmpty, fr1, fr2);
                if (mo1.isFederated(FederationMap.FType.PART) || mo2.isFederated(FederationMap.FType.PART)) {
                    setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
                } else {
                    setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
                }
            } else {
                fetchBroadcastResult(mo1, frEmpty, fr1, fr2, ec);
            }
        } else if (mo2.isFederated(FederationMap.FType.ROW)) {
            // Case 3: rhs row-federated.
            if (mo1.isFederated(FederationMap.FType.COL) && MMFEDInstruction.isAggBinaryFedAligned(mo1, mo2)) {
                FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
                if (_fedOut.isForcedFederated()) {
                    mo2.getFedMapping().execute(getTID(), true, fr2);
                    setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
                } else {
                    FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
                    // Remove the fetched intermediate on the workers, as the other fetch paths do;
                    // previously this variable was left behind on every site.
                    FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr2.getID());
                    Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr2, fr3, fr4);
                    MatrixBlock ret = FederationUtils.aggAdd(tmp);
                    ec.setMatrixOutput(output.getName(), ret);
                }
            } else {
                FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
                FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
                if (_fedOut.isForcedFederated()) {
                    mo2.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[]{fr2});
                    setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
                } else {
                    FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
                    FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr2.getID());
                    Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[]{fr2, fr3, fr4});
                    MatrixBlock ret = FederationUtils.aggAdd(tmp);
                    ec.setMatrixOutput(output.getName(), ret);
                }
            }
        } else if (mo1.isFederated(FederationMap.FType.COL)) {
            // Case 4: lhs column-federated, slice the rhs to match.
            FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, true);
            FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, true);
            if (_fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[]{fr2});
                setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
            } else {
                FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
                FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr2.getID());
                Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, new FederatedRequest[]{fr2, fr3, fr4});
                MatrixBlock ret = FederationUtils.aggAdd(tmp);
                ec.setMatrixOutput(output.getName(), ret);
            }
        } else {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: " + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
    }

    /**
     * Runs a broadcast-based multiplication on the lhs sites, fetches the result, and sets it as
     * the local output.
     *
     * <p>Site results are summed ({@code aggAdd}) when the lhs is part-federated. Otherwise they
     * are combined with {@code FederationUtils.bind(..., false)}. The intermediate result
     * variable is cleaned up on the workers afterwards.
     */
    private void fetchBroadcastResult(MatrixObject mo1, FederatedRequest frEmpty, FederatedRequest frBroadcast, FederatedRequest frCall, ExecutionContext ec) {
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCall.getID());
        FederatedRequest frCleanup = mo1.getFedMapping().cleanup(getTID(), frCall.getID());
        Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), frEmpty, frBroadcast, frCall, frGet, frCleanup);
        MatrixBlock ret = mo1.isFederated(FederationMap.FType.PART) ? FederationUtils.aggAdd(tmp) : FederationUtils.bind(tmp, false);
        ec.setMatrixOutput(output.getName(), ret);
    }

    /**
     * Checks whether the column ranges of the column-federated {@code mo1} match the row ranges
     * of the row-federated {@code mo2}, site by site.
     *
     * @return true iff both maps have the same number of ranges and, for every index i, range i
     *         of mo1 spans the same columns as range i of mo2 spans rows
     */
    private static boolean isAggBinaryFedAligned(MatrixObject mo1, MatrixObject mo2) {
        FederatedRange[] mo1FederatedRanges = mo1.getFedMapping().getFederatedRanges();
        FederatedRange[] mo2FederatedRanges = mo2.getFedMapping().getFederatedRanges();
        // Different partition counts can never be aligned; without this check a shorter mo2
        // caused an ArrayIndexOutOfBoundsException and a longer mo2 a false positive.
        if (mo1FederatedRanges.length != mo2FederatedRanges.length) {
            return false;
        }
        for (int i = 0; i < mo1FederatedRanges.length; ++i) {
            FederatedRange mo1FedRange = mo1FederatedRanges[i];
            FederatedRange mo2FedRange = mo2FederatedRanges[i];
            if (mo1FedRange.getBeginDims()[1] != mo2FedRange.getBeginDims()[0]
                || mo1FedRange.getEndDims()[1] != mo2FedRange.getEndDims()[0]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets the output as a federated object whose per-site values are the partial results
     * stored under {@code outputID}. The non-federated paths sum such partials with
     * {@code aggAdd}.
     *
     * <p>The output is sized {@code mo1.rows x mo2.cols}. The federation map is copied from
     * {@code federationMap} via {@code copyWithNewIDAndRange} (presumably setting each site's
     * range to the full output dimensions; confirm against FederationMap).
     */
    private void setPartialOutput(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
        FederationMap outputFedMap = federationMap.copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID);
        out.setFedMapping(outputFedMap);
    }

    /**
     * Sets the output as a federated object sized {@code mo1.rows x mo2.cols}, reusing
     * {@code federationMap} with the new data id {@code outputID} and {@code mo2}'s column count
     * (via {@code copyWithNewID}).
     */
    private void setOutputFedMapping(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
        out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns()));
    }
}

