/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.SpoofCompiler;
import org.apache.sysml.hops.codegen.template.PlanSelection;
import org.apache.sysml.hops.codegen.template.TemplateBase;

/**
 * Memoization table that maps hop IDs to the list of candidate fusion plans
 * ({@link MemoTableEntry}) recorded for that hop, together with a lookup table
 * from hop ID to the {@link Hop} instance and a blacklist of hop IDs that must
 * not be reported as top-level plans.
 *
 * <p>An input slot of a {@link MemoTableEntry} holds either a non-negative hop
 * ID (a reference to the plans of that input) or a negative value (no
 * reference); the convenience overloads in this class use {@code -1}.
 *
 * <p>This class performs no synchronization.
 */
public class CPlanMemoTable {
    private static final Log LOG = LogFactory.getLog(CPlanMemoTable.class.getName());

    /** Candidate plans per hop ID. */
    protected HashMap<Long, List<MemoTableEntry>> _plans = new HashMap<>();
    /** Hop instances per hop ID, registered by {@code addHop}, {@code add} and {@code addAll}. */
    protected HashMap<Long, Hop> _hopRefs = new HashMap<>();
    /** Hop IDs excluded from {@link #containsTopLevel(long)}; filled by {@link #pruneSuboptimal(ArrayList)}. */
    protected HashSet<Long> _plansBlacklist = new HashSet<>();

    /**
     * Registers a hop in the ID-to-hop lookup table without adding any plan.
     *
     * @param hop the hop to register
     */
    public void addHop(Hop hop) {
        _hopRefs.put(hop.getHopID(), hop);
    }

    /**
     * @param hop the hop to look up
     * @return true if the hop's ID is present in the ID-to-hop lookup table
     */
    public boolean containsHop(Hop hop) {
        return _hopRefs.containsKey(hop.getHopID());
    }

    /**
     * @param hopID the hop ID to look up
     * @return true if a plan list is stored for the given hop ID
     */
    public boolean contains(long hopID) {
        return _plans.containsKey(hopID);
    }

    /**
     * @param hopID the hop ID to look up
     * @param type  the template type to search for
     * @return true if at least one plan of the given type is stored for the hop ID
     */
    public boolean contains(long hopID, TemplateBase.TemplateType type) {
        return contains(hopID)
            && get(hopID).stream().anyMatch(p -> p.type == type);
    }

    /**
     * @param hopID the hop ID to look up
     * @return true if the hop ID is not blacklisted and {@link #getBest(long)}
     *         finds a plan for it
     */
    public boolean containsTopLevel(long hopID) {
        return !_plansBlacklist.contains(hopID) && getBest(hopID) != null;
    }

    /**
     * Adds a plan of the given type without any input plan references.
     *
     * @param hop  the hop the plan belongs to
     * @param type the template type of the plan
     */
    public void add(Hop hop, TemplateBase.TemplateType type) {
        add(hop, type, -1L, -1L, -1L);
    }

    /**
     * Adds a plan of the given type that references only the first input.
     *
     * @param hop  the hop the plan belongs to
     * @param type the template type of the plan
     * @param in1  hop ID referenced by input slot 0, or a negative value for none
     */
    public void add(Hop hop, TemplateBase.TemplateType type, long in1) {
        add(hop, type, in1, -1L, -1L);
    }

    /**
     * Adds a plan of the given type that references the first two inputs.
     *
     * @param hop  the hop the plan belongs to
     * @param type the template type of the plan
     * @param in1  hop ID referenced by input slot 0, or a negative value for none
     * @param in2  hop ID referenced by input slot 1, or a negative value for none
     */
    public void add(Hop hop, TemplateBase.TemplateType type, long in1, long in2) {
        add(hop, type, in1, in2, -1L);
    }

    /**
     * Adds a plan of the given type with up to three input references.
     *
     * @param hop  the hop the plan belongs to
     * @param type the template type of the plan
     * @param in1  hop ID referenced by input slot 0, or a negative value for none
     * @param in2  hop ID referenced by input slot 1, or a negative value for none
     * @param in3  hop ID referenced by input slot 2, or a negative value for none
     */
    public void add(Hop hop, TemplateBase.TemplateType type, long in1, long in2, long in3) {
        add(hop, new MemoTableEntry(type, in1, in2, in3));
    }

    /**
     * Registers the hop and appends the given entry to its plan list, creating
     * the list on first use.
     *
     * @param hop the hop the plan belongs to
     * @param me  the plan to append
     */
    public void add(Hop hop, MemoTableEntry me) {
        _hopRefs.put(hop.getHopID(), hop);
        _plans.computeIfAbsent(hop.getHopID(), k -> new ArrayList<>()).add(me);
    }

    /**
     * Registers the hop and appends all plans of the given set to its plan
     * list, creating the list on first use.
     *
     * @param hop the hop the plans belong to
     * @param P   the plan set whose entries are appended
     */
    public void addAll(Hop hop, MemoTableEntrySet P) {
        _hopRefs.put(hop.getHopID(), hop);
        _plans.computeIfAbsent(hop.getHopID(), k -> new ArrayList<>()).addAll(P.plans);
    }

    /**
     * Replaces the plan list of the hop with a new list that omits every entry
     * contained in {@code blackList}.
     *
     * @param hop       the hop whose plans are filtered; it must already have a plan list
     * @param blackList the entries to drop (compared via {@link MemoTableEntry#equals(Object)})
     */
    public void remove(Hop hop, HashSet<MemoTableEntry> blackList) {
        List<MemoTableEntry> kept = _plans.get(hop.getHopID()).stream()
            .filter(p -> !blackList.contains(p))
            .collect(Collectors.toList());
        _plans.put(hop.getHopID(), kept);
    }

    /**
     * Stores the de-duplicated version of {@code plans} as the plan list of the
     * given hop ID, replacing any previous list.
     *
     * @param hopID the hop ID to store the plans under
     * @param plans the plans to de-duplicate and store
     */
    public void setDistinct(long hopID, List<MemoTableEntry> plans) {
        _plans.put(hopID, plans.stream().distinct().collect(Collectors.toList()));
    }

    /**
     * De-duplicates the plans of the given hop and then drops every plan that
     * is subsumed by another plan of the same hop, provided the removal is safe.
     * Removal of {@code e2} in favour of {@code e1} is considered safe when, for
     * every input slot that {@code e1} references but {@code e2} does not, the
     * corresponding input hop has exactly one parent.
     *
     * <p>Does nothing if no plan list is stored for the hop ID.
     *
     * @param hopID the hop ID whose plans are pruned
     */
    public void pruneRedundant(long hopID) {
        if (!contains(hopID)) {
            return;
        }
        setDistinct(hopID, _plans.get(hopID));

        HashSet<MemoTableEntry> rmList = new HashSet<>();
        List<MemoTableEntry> list = _plans.get(hopID);
        Hop hop = _hopRefs.get(hopID);
        for (MemoTableEntry e1 : list) {
            for (MemoTableEntry e2 : list) {
                // the list is distinct at this point, so identity is enough to skip self-comparison
                if (e1 == e2 || !e1.subsumes(e2)) {
                    continue;
                }
                boolean rmSafe = true;
                for (int i = 0; i <= 2; ++i) {
                    if (e1.isPlanRef(i) && !e2.isPlanRef(i)) {
                        rmSafe &= hop.getInput().get(i).getParent().size() == 1;
                    }
                }
                if (rmSafe) {
                    rmList.add(e2);
                }
            }
        }
        remove(hop, rmList);
    }

    /**
     * Prunes the memo table and delegates the final plan selection to the
     * selector created by {@link SpoofCompiler#createPlanSelector()}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>collect all input IDs that appear in any stored plan;</li>
     *   <li>for every hop that is not referenced by any plan, keep only plans
     *       that have at least one plan reference and drop the hop entirely if
     *       no plan remains;</li>
     *   <li>blacklist every referenced input whose hop has exactly one parent;</li>
     *   <li>run the plan selector on this table and the given roots.</li>
     * </ol>
     *
     * @param roots the root hops handed to the plan selector
     */
    public void pruneSuboptimal(ArrayList<Hop> roots) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("#1: Memo before plan selection (" + size() + " plans)\n" + this);
        }

        // step 1: all input slot values used by any plan (includes negative "no reference" markers)
        HashSet<Long> ix = new HashSet<>();
        for (List<MemoTableEntry> entries : _plans.values()) {
            for (MemoTableEntry me : entries) {
                ix.add(me.input1);
                ix.add(me.input2);
                ix.add(me.input3);
            }
        }

        // step 2: unreferenced hops keep only plans with at least one plan reference
        Iterator<Map.Entry<Long, List<MemoTableEntry>>> iter = _plans.entrySet().iterator();
        while (iter.hasNext()) {
            Map.Entry<Long, List<MemoTableEntry>> e = iter.next();
            if (ix.contains(e.getKey())) {
                continue;
            }
            e.setValue(e.getValue().stream()
                .filter(MemoTableEntry::hasPlanRef)
                .collect(Collectors.toList()));
            if (e.getValue().isEmpty()) {
                iter.remove();
            }
        }

        // step 3: blacklist referenced inputs that have a single parent
        for (List<MemoTableEntry> entries : _plans.values()) {
            for (MemoTableEntry me : entries) {
                for (int i = 0; i <= 2; ++i) {
                    if (me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size() == 1) {
                        _plansBlacklist.add(me.input(i));
                    }
                }
            }
        }

        // step 4: final selection
        PlanSelection selector = SpoofCompiler.createPlanSelector();
        selector.selectPlans(this, roots);

        if (LOG.isTraceEnabled()) {
            LOG.trace("#2: Memo after plan selection (" + size() + " plans)\n" + this);
        }
    }

    /**
     * @param hopID the hop ID to look up
     * @return the stored (live, not copied) plan list, or null if none is stored
     */
    public List<MemoTableEntry> get(long hopID) {
        return _plans.get(hopID);
    }

    /**
     * Returns one input-free entry per distinct plan of the given hop, where
     * each entry keeps the type and {@code closed} flag of its source plan.
     * Distinctness follows {@link MemoTableEntry#equals(Object)}, which does
     * not look at {@code closed}, so the result holds one entry per type.
     *
     * @param hopID the hop ID to look up; it must have a stored plan list
     * @return a new list of input-free entries
     */
    public List<MemoTableEntry> getDistinct(long hopID) {
        return _plans.get(hopID).stream()
            .map(p -> new MemoTableEntry(p.type, -1L, -1L, -1L, p.closed))
            .distinct()
            .collect(Collectors.toList());
    }

    /**
     * Returns the plan with the smallest template rank among the plans that
     * {@link PlanSelection#isValid} accepts for the hop.
     *
     * @param hopID the hop ID to look up
     * @return the selected plan, or null if there is no plan list, the list is
     *         empty, or no plan is valid
     */
    public MemoTableEntry getBest(long hopID) {
        List<MemoTableEntry> tmp = get(hopID);
        if (tmp == null || tmp.isEmpty()) {
            return null;
        }
        Hop hop = _hopRefs.get(hopID);
        return tmp.stream()
            .filter(p -> PlanSelection.isValid(p, hop))
            .min(Comparator.comparing(p -> p.type.getRank()))
            .orElse(null);
    }

    /**
     * Returns the minimum plan under an ordering that favours a preferred
     * template type: plans of type {@code pref} are keyed by the negated number
     * of plan references (more references sort first), all other plans by
     * their template rank plus one. No validity check is applied.
     *
     * @param hopID the hop ID to look up
     * @param pref  the preferred template type
     * @return the selected plan, or null if there is no plan list or it is empty
     */
    public MemoTableEntry getBest(long hopID, TemplateBase.TemplateType pref) {
        List<MemoTableEntry> tmp = get(hopID);
        if (tmp == null || tmp.isEmpty()) {
            return null;
        }
        return Collections.min(tmp, Comparator.comparing(
            p -> p.type == pref ? -p.countPlanRefs() : p.type.getRank() + 1));
    }

    /**
     * Collects, per input slot, the hop IDs referenced by the plans of the
     * given hop. References found in the same slot are combined with a bitwise
     * OR; slots without any reference stay 0.
     *
     * <p>NOTE(review): the OR only yields a meaningful ID if all plans agree on
     * the referenced ID per slot - presumably guaranteed by the callers; verify.
     *
     * @param hopID the hop ID to look up; it must have a stored plan list
     * @return an array of length 3 with the combined reference per slot
     */
    public long[] getAllRefs(long hopID) {
        long[] refs = new long[3];
        for (MemoTableEntry me : get(hopID)) {
            for (int i = 0; i < 3; ++i) {
                if (me.isPlanRef(i)) {
                    refs[i] |= me.input(i);
                }
            }
        }
        return refs;
    }

    /**
     * @return the total number of plans stored over all hops
     */
    public int size() {
        return _plans.values().stream().mapToInt(List::size).sum();
    }

    /**
     * Renders all plan lists (one line per hop, with the hop's operator string)
     * followed by the blacklisted hop IDs.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("----------------------------------\n");
        sb.append("MEMO TABLE: \n");
        sb.append("----------------------------------\n");
        for (Map.Entry<Long, List<MemoTableEntry>> e : _plans.entrySet()) {
            sb.append(e.getKey()).append(" ")
              .append(_hopRefs.get(e.getKey()).getOpString()).append(": ");
            sb.append(Arrays.toString(e.getValue().toArray(new MemoTableEntry[0]))).append("\n");
        }
        sb.append("----------------------------------\n");
        sb.append("Blacklisted Plans: ");
        sb.append(Arrays.toString(_plansBlacklist.toArray(new Long[0]))).append("\n");
        sb.append("----------------------------------\n");
        return sb.toString();
    }

    /**
     * A mutable set of plans for a single hop that starts with one entry and
     * can be expanded via {@link #crossProduct(int, Long...)}.
     */
    public static class MemoTableEntrySet {
        public ArrayList<MemoTableEntry> plans = new ArrayList<>();

        /**
         * Creates a set holding one entry without input references.
         *
         * @param type  the template type of the entry
         * @param close the initial {@code closed} flag of the entry
         */
        public MemoTableEntrySet(TemplateBase.TemplateType type, boolean close) {
            plans.add(new MemoTableEntry(type, -1L, -1L, -1L, close));
        }

        /**
         * Creates a set holding one entry that references {@code hopID} in input
         * slot {@code pos}; for a {@code pos} outside 0..2 no slot is referenced.
         *
         * @param type  the template type of the entry
         * @param pos   the input slot (0, 1 or 2) that receives the reference
         * @param hopID the referenced hop ID
         * @param close the initial {@code closed} flag of the entry
         */
        public MemoTableEntrySet(TemplateBase.TemplateType type, int pos, long hopID, boolean close) {
            // pass the flag on, consistent with the two-argument constructor
            plans.add(new MemoTableEntry(type,
                pos == 0 ? hopID : -1L,
                pos == 1 ? hopID : -1L,
                pos == 2 ? hopID : -1L,
                close));
        }

        /**
         * Replaces every plan by one copy per given reference, where the copy
         * has input slot {@code pos} set to that reference and keeps the other
         * slots, the type and the {@code closed} flag of the source plan.
         * Calling this without references leaves the set empty.
         *
         * @param pos  the input slot (0, 1 or 2) to overwrite
         * @param refs the values to place into the slot; elements must not be null
         */
        public void crossProduct(int pos, Long ... refs) {
            ArrayList<MemoTableEntry> expanded = new ArrayList<>();
            for (MemoTableEntry me : plans) {
                for (Long ref : refs) {
                    expanded.add(new MemoTableEntry(me.type,
                        pos == 0 ? ref : me.input1,
                        pos == 1 ? ref : me.input2,
                        pos == 2 ? ref : me.input3,
                        me.closed));
                }
            }
            plans = expanded;
        }

        @Override
        public String toString() {
            return Arrays.toString(plans.toArray(new MemoTableEntry[0]));
        }
    }

    /**
     * A single candidate plan: a template type plus up to three input slots.
     * A slot value {@code >= 0} is a reference to the plans of the hop with
     * that ID; a negative value means the slot holds no reference.
     *
     * <p>Equality and hash code are based on type and the three input slots
     * only; the {@code closed} flag is not part of the identity.
     */
    public static class MemoTableEntry {
        public TemplateBase.TemplateType type;
        public final long input1;
        public final long input2;
        public final long input3;
        public boolean closed = false;

        /**
         * Creates an entry with {@code closed} set to false.
         *
         * @param t   the template type
         * @param in1 value of input slot 0
         * @param in2 value of input slot 1
         * @param in3 value of input slot 2
         */
        public MemoTableEntry(TemplateBase.TemplateType t, long in1, long in2, long in3) {
            this(t, in1, in2, in3, false);
        }

        /**
         * @param t     the template type
         * @param in1   value of input slot 0
         * @param in2   value of input slot 1
         * @param in3   value of input slot 2
         * @param close the initial {@code closed} flag
         */
        public MemoTableEntry(TemplateBase.TemplateType t, long in1, long in2, long in3, boolean close) {
            this.type = t;
            this.input1 = in1;
            this.input2 = in2;
            this.input3 = in3;
            this.closed = close;
        }

        /**
         * @param index the input slot (0, 1 or 2)
         * @return true if the slot holds a non-negative value; false for any
         *         index outside 0..2
         */
        public boolean isPlanRef(int index) {
            switch (index) {
                case 0:
                    return input1 >= 0L;
                case 1:
                    return input2 >= 0L;
                case 2:
                    return input3 >= 0L;
                default:
                    return false;
            }
        }

        /**
         * @return true if at least one of the three slots holds a plan reference
         */
        public boolean hasPlanRef() {
            return isPlanRef(0) || isPlanRef(1) || isPlanRef(2);
        }

        /**
         * @return the number of slots (0 to 3) that hold a plan reference
         */
        public int countPlanRefs() {
            return (input1 >= 0L ? 1 : 0)
                + (input2 >= 0L ? 1 : 0)
                + (input3 >= 0L ? 1 : 0);
        }

        /**
         * @param index the input slot
         * @return the value of slot 0 for index 0, of slot 1 for index 1, and
         *         of slot 2 for every other index
         */
        public long input(int index) {
            if (index == 0) {
                return input1;
            }
            return index == 1 ? input2 : input3;
        }

        /**
         * @param that the entry to compare against
         * @return true if both entries have the same type and every slot that
         *         {@code that} references is also referenced by this entry
         *         (the referenced IDs themselves are not compared)
         */
        public boolean subsumes(MemoTableEntry that) {
            return type == that.type
                && (isPlanRef(0) || !that.isPlanRef(0))
                && (isPlanRef(1) || !that.isPlanRef(1))
                && (isPlanRef(2) || !that.isPlanRef(2));
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(new long[]{type.ordinal(), input1, input2, input3});
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof MemoTableEntry)) {
                return false;
            }
            MemoTableEntry that = (MemoTableEntry)obj;
            return type == that.type
                && input1 == that.input1
                && input2 == that.input2
                && input3 == that.input3;
        }

        @Override
        public String toString() {
            return type.name() + "(" + input1 + "," + input2 + "," + input3 + ")";
        }
    }
}

