]> gerrit.simantics Code Review - simantics/platform.git/blobdiff - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/internal/elaboration/subsumption2/SubSolver2.java
New solver for SCL effects inequalities
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / internal / elaboration / subsumption2 / SubSolver2.java
diff --git a/bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/internal/elaboration/subsumption2/SubSolver2.java b/bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/internal/elaboration/subsumption2/SubSolver2.java
new file mode 100644 (file)
index 0000000..3acd0ca
--- /dev/null
@@ -0,0 +1,581 @@
+package org.simantics.scl.compiler.internal.elaboration.subsumption2;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.simantics.scl.compiler.errors.ErrorLog;
+import org.simantics.scl.compiler.internal.elaboration.subsumption.Subsumption;
+import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.LowerBoundSource;
+import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.Node;
+import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.PartOfUnion;
+import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.Sub;
+import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.UnionNode;
+import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.VarNode;
+import org.simantics.scl.compiler.internal.types.effects.EffectIdMap;
+import org.simantics.scl.compiler.types.TMetaVar;
+import org.simantics.scl.compiler.types.util.Polarity;
+
+import gnu.trove.map.hash.THashMap;
+import gnu.trove.map.hash.TIntIntHashMap;
+import gnu.trove.set.hash.THashSet;
+
+public class SubSolver2 {
+    public static final boolean DEBUG = false;
+    
+    // Input
+    private final ErrorLog errorLog;
+    private final ArrayList<Subsumption> subsumptions;
+
+    //
+    private final EffectIdMap effectIds = new EffectIdMap();
+    private final THashMap<TMetaVar, VarNode> varNodeMap = new THashMap<TMetaVar, VarNode>();
+    private final ArrayList<UnionNode> unionNodes = new ArrayList<UnionNode>(); 
+
+    private static TIntIntHashMap STATISTICS = new TIntIntHashMap();
+    
+    private SubSolver2(ErrorLog errorLog, ArrayList<Subsumption> subsumptions) {
+        this.errorLog = errorLog;
+        this.subsumptions = subsumptions;
+        /*if(subsumptions.size() == 1) {
+            TypeUnparsingContext tuc = new TypeUnparsingContext();
+            Subsumption sub = subsumptions.get(0);
+            if(sub.a instanceof TCon && sub.b instanceof TCon)
+                System.out.println("caseCC");
+            else if(sub.a instanceof TMetaVar && sub.b instanceof TCon)
+                System.out.println("caseMC");
+            else if(sub.a instanceof TVar && sub.b instanceof TCon)
+                System.out.println("caseVC");
+            System.out.println("    " + sub.a.toString(tuc) + " < " + sub.b.toString(tuc));
+        }
+        synchronized(STATISTICS) {
+            STATISTICS.adjustOrPutValue(subsumptions.size(), 1, 1);
+            showStatistics();
+        }*/
+    }
+    
+    public static void showStatistics() {
+        System.out.println("---");
+        int[] keys = STATISTICS.keys();
+        Arrays.sort(keys);
+        int sum = 0;
+        for(int key : keys)
+            sum += STATISTICS.get(key);
+        for(int key : keys) {
+            int value = STATISTICS.get(key);
+            System.out.println(key + ": " + value + " (" + (value*100.0/sum) + "%)");
+        }
+    }
+
+    private static boolean subsumes(int a, int b) {
+        return (a&b) == a;
+    }
+
+    private void processSubsumptions() {
+        ArrayList<TMetaVar> aVars = new ArrayList<TMetaVar>(2);
+        ArrayList<TMetaVar> bVars = new ArrayList<TMetaVar>(2);
+        for(Subsumption subsumption : subsumptions) {
+            int aCons = effectIds.toId(subsumption.a, aVars);
+            int bCons = effectIds.toId(subsumption.b, bVars);
+
+            if(bVars.isEmpty()) {
+                if(!subsumes(aCons, bCons)) {
+                    reportSubsumptionFailure(subsumption.loc, aCons, bCons);
+                    continue;
+                }
+                for(TMetaVar aVar : aVars)
+                    getOrCreateNode(aVar).upperBound &= bCons;
+            }
+            else {
+                Node bNode;
+                if(bVars.size() == 1 && bCons == 0)
+                    bNode = getOrCreateNode(bVars.get(0));
+                else
+                    bNode = createUnion(subsumption.loc, bCons, bVars);
+                if(aCons != 0)
+                    setLowerBound(subsumption.loc, aCons, bNode);
+                for(TMetaVar aVar : aVars)
+                    new Sub(getOrCreateNode(aVar), bNode);
+                bVars.clear();
+            }
+            aVars.clear();
+        }
+    }
+
+    private void setLowerBound(long location, int lower, Node node) {
+        node.lowerBound |= lower;
+        node.addLowerBoundSource(location, lower);
+    }
+
+    private UnionNode createUnion(long location, int cons, ArrayList<TMetaVar> vars) {
+        UnionNode node = new UnionNode(location, cons);
+        for(TMetaVar var : vars)
+            new PartOfUnion(getOrCreateNode(var), node);
+        unionNodes.add(node);
+        return node;
+    }
+
+    private VarNode getOrCreateNode(TMetaVar var) {
+        VarNode node = varNodeMap.get(var);
+        if(node == null) {
+            node = new VarNode(var);
+            varNodeMap.put(var, node);
+        }
+        return node;
+    }
+
+    public boolean solve() {
+        //System.out.println("------------------------------------------------------");
+        int errorCount = errorLog.getErrorCount();
+
+        // Check errors
+        processSubsumptions();
+        propagateUpperBounds();
+        checkLowerBounds();
+        
+        if(DEBUG)
+            print();
+
+        if(errorLog.getErrorCount() != errorCount)
+            return false;
+
+        // Simplify constraints
+        stronglyConnectedComponents();
+        propagateLowerBounds();
+        simplify();
+
+        if(DEBUG)
+            print();
+        
+        return true;
+    }
+    
+    private void touchNeighborhood(VarNode node) {
+        for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
+            touch(cur.a);
+        for(Sub cur=node.upper;cur!=null;cur=cur.aNext)
+            touch(cur.b);
+        for(PartOfUnion cur=node.partOf;cur!=null;cur=cur.aNext)
+            touch(cur.b);
+    }
+    
+    private void touchNeighborhood(UnionNode node) {
+        for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
+            touch(cur.a);
+        for(PartOfUnion cur=node.parts;cur!=null;cur=cur.bNext)
+            touch(cur.a);
+    }
+
+    THashSet<Node> set = new THashSet<Node>(); 
+    private void simplify() {
+        for(VarNode node : sortedNodes) {
+            if(node.index == SubsumptionGraph.REMOVED)
+                continue;
+            activeSet.add(node);
+            queue.addLast(node);
+        }
+        for(UnionNode node : unionNodes) {
+            if(node.constPart == SubsumptionGraph.REMOVED)
+                continue;
+            activeSet.add(node);
+            queue.addLast(node);
+        }
+        
+        while(!queue.isEmpty()) {
+            Node node_ = queue.removeFirst();
+            activeSet.remove(node_);
+            if(node_ instanceof VarNode) {
+                VarNode node = (VarNode)node_;
+                if(node.index == SubsumptionGraph.REMOVED)
+                    continue;
+                if(node.lowerBound == node.upperBound) {
+                    if(DEBUG)
+                        System.out.println("replace " + toName(node) + " by " + effectIds.toType(node.lowerBound) + ", node.lowerBound == node.upperBound");
+                    touchNeighborhood(node);
+                    node.removeConstantNode(effectIds, node.lowerBound);
+                    continue;
+                }
+                for(Sub cur=node.upper;cur!=null;cur=cur.aNext)
+                    if(cur.b == node)
+                        cur.remove();
+                if(node.upper != null && node.upper.aNext != null) {
+                    for(Sub cur=node.upper;cur!=null;cur=cur.aNext)
+                        if(!set.add(cur.b) || subsumes(node.upperBound, cur.a.lowerBound)) {
+                            touch(cur.b);
+                            cur.remove();
+                        }
+                    set.clear();
+                }
+                if(node.lower != null && node.lower.bNext != null) {
+                    for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
+                        if(!set.add(cur.a) || subsumes(cur.a.upperBound, node.lowerBound)) {
+                            touch(cur.a);
+                            cur.remove();
+                        }
+                    set.clear();
+                }
+                Polarity polarity = node.getPolarity();
+                if(!polarity.isNegative()) { 
+                    if(node.partOf == null) {
+                        if(node.lower == null) {
+                            // No low nodes
+                            if(DEBUG)
+                                System.out.println("replace " + toName(node) + " by " + effectIds.toType(node.lowerBound) + ", polarity=" + polarity + ", no low nodes");
+                            touchNeighborhood(node);
+                            node.removeConstantNode(effectIds, node.lowerBound);
+                            continue;
+                        }
+                        else if(node.lower.bNext == null) {
+                            // Exactly one low node
+                            VarNode low = node.lower.a;
+
+                            if(low.lowerBound == node.lowerBound) {
+                                node.lower.remove();
+                                if(DEBUG)
+                                    System.out.println("replace " + toName(node) + " by " + toName(low) + ", polarity=" + polarity + ", just one low node");
+                                touchNeighborhood(node);
+                                node.replaceBy(low);
+                                continue;
+                            }
+                        }
+                    }
+                }
+                else if(polarity == Polarity.NEGATIVE) {
+                    if(node.upper != null && node.upper.aNext == null) {
+                        Node high = node.upper.b;
+                        if(node.upperBound == high.upperBound && high instanceof VarNode) {
+                            VarNode varHigh = (VarNode)high;
+                            
+                            node.upper.remove();
+                            if(DEBUG)
+                                System.out.println("replace " + toName(node) + " by " + toName(varHigh) + ", polarity=" + polarity + ", just one high node");
+                            touchNeighborhood(node);
+                            node.replaceBy(varHigh);
+                            continue;
+                        }
+                    }
+                }
+            }
+            else {
+                UnionNode union = (UnionNode)node_;
+                if(union.constPart == SubsumptionGraph.REMOVED)
+                    continue;
+                if(union.lower == null) {
+                    int low = union.constPart;
+                    for(PartOfUnion partOf=union.parts;partOf!=null;partOf=partOf.bNext)
+                        low |= partOf.a.lowerBound;
+
+                    if(subsumes(union.lowerBound, low)) {
+                        if(DEBUG) {
+                            System.out.print("remove union, " + constToString(union.lowerBound) + " < " + constToString(low));
+                            printUnion(union);
+                        }
+                        touchNeighborhood(union);
+                        union.remove();
+                        continue;
+                    }
+                }
+                else {
+                    for(Sub cur=union.lower;cur!=null;cur=cur.bNext) {
+                        VarNode lowNode = union.lower.a;
+                        for(PartOfUnion partOf=union.parts;partOf!=null;partOf=partOf.bNext)
+                            if(partOf.a == lowNode) {
+                                cur.remove();
+                                touch(union);
+                                break;
+                            }
+                    }
+                }
+            }
+        }
+    }
+
+    private void checkLowerBounds() {
+        for(VarNode node : varNodeMap.values())
+            checkLowerBound(node);
+        for(UnionNode node : unionNodes) 
+            checkLowerBound(node);
+    }
+
+    private void checkLowerBound(Node node) {
+        int upperBound = node.upperBound;
+        if(!subsumes(node.lowerBound, upperBound))
+            for(LowerBoundSource source=node.lowerBoundSource;source!=null;source=source.next)
+                if(!subsumes(source.lower, upperBound))
+                    reportSubsumptionFailure(source.location, source.lower, upperBound);
+        node.lowerBoundSource = null;
+    }
+
+    private void propagateLowerBounds() {
+        for(VarNode node : sortedNodes) {
+            for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
+                node.lowerBound |= cur.a.lowerBound;
+        }
+        if(!unionNodes.isEmpty()) {
+            for(UnionNode node : unionNodes) {
+                if(node.parts != null && node.parts.bNext != null) {
+                    // Remove duplicate parts of the union, might be there because of merging of strongly connected components
+                    THashSet<VarNode> varSet = new THashSet<VarNode>(); 
+                    for(PartOfUnion cur=node.parts;cur!=null;cur=cur.bNext)
+                        if(!varSet.add(cur.a))
+                            cur.remove();
+                }
+                
+                for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
+                    node.lowerBound |= cur.a.lowerBound;
+
+                activeSet.add(node);
+                queue.addLast(node);
+            }
+            while(!queue.isEmpty()) {
+                Node node = queue.removeFirst();
+                activeSet.remove(node);
+                int lowerBound = node.lowerBound;
+
+                if(node instanceof VarNode) {
+                    VarNode var = (VarNode)node;
+                    for(Sub cur=var.upper;cur!=null;cur=cur.aNext) {
+                        Node highNode = cur.b;
+                        int newLowerBound = highNode.lowerBound & lowerBound;
+                        if(newLowerBound != highNode.lowerBound) {
+                            highNode.lowerBound = newLowerBound;
+                            touch(highNode);
+                        }
+                    }
+                }
+                else {
+                    UnionNode union = (UnionNode)node;
+                    for(PartOfUnion cur=union.parts;cur!=null;cur=cur.bNext) {
+                        int residual = lowerBound & (~union.constPart);
+                        for(PartOfUnion cur2=union.parts;cur2!=null;cur2=cur2.bNext)
+                            if(cur2 != cur)
+                                residual = lowerBound & (~cur2.a.upperBound);
+                        VarNode partNode = cur.a;
+                        int newLowerBound = partNode.lowerBound | residual;
+                        if(newLowerBound != partNode.lowerBound) {
+                            partNode.lowerBound = newLowerBound;
+                            touch(partNode);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    private void reportSubsumptionFailure(long location, int lowerBound, int upperBound) {
+        errorLog.log(location, "Side-effect " + effectIds.toType(lowerBound & (~upperBound)) + " is forbidden here.");        
+    }
+
+    private final THashSet<Node> activeSet = new THashSet<>();
+    private final ArrayDeque<Node> queue = new ArrayDeque<>(); 
+
+    private void touch(Node node) {
+        if(activeSet.add(node))
+            queue.addLast(node);
+    }
+    
+    private void propagateUpperBounds() {
+        for(VarNode node : varNodeMap.values())
+            if(node.upperBound != EffectIdMap.MAX) {
+                activeSet.add(node);
+                queue.addLast(node);
+            }
+
+        while(!queue.isEmpty()) {
+            Node node = queue.removeFirst();
+            activeSet.remove(node);
+            int upperBound = node.upperBound;
+
+            if(node instanceof VarNode) {
+                // Upper bounds for unions are not calculated immediately
+                for(PartOfUnion cur=((VarNode)node).partOf;cur!=null;cur=cur.aNext) {
+                    UnionNode union = cur.b;
+                    touch(union);
+                }
+            }
+            else {
+                // New upper bound for union is calculated here
+                UnionNode union = (UnionNode)node;
+                int newUpperBound = union.constPart;
+                for(PartOfUnion cur=union.parts;cur!=null;cur=cur.bNext)
+                    newUpperBound |= cur.a.upperBound;
+                if(newUpperBound != upperBound)
+                    node.upperBound = upperBound = newUpperBound;
+                else
+                    continue; // No changes in upper bound, no need to propagate
+            }
+
+            // Propagate upper bound to smaller variables
+            for(Sub cur=node.lower;cur!=null;cur=cur.bNext) {
+                VarNode lowNode = cur.a;
+                int newUpperBound = lowNode.upperBound & upperBound;
+                if(newUpperBound != lowNode.upperBound) {
+                    lowNode.upperBound = newUpperBound;
+                    touch(lowNode);
+                }
+            }
+        }
+    }
+
+    int curIndex;
+    private void stronglyConnectedComponents() {
+        sortedNodes = new ArrayList<VarNode>(varNodeMap.size());
+        for(VarNode node : varNodeMap.values())
+            node.index = -1;
+        for(VarNode node : varNodeMap.values())
+            if(node.index == -1) {
+                curIndex = 0;
+                stronglyConnectedComponents(node);
+            }
+    }
+
+    ArrayList<VarNode> sortedNodes;
+    ArrayList<VarNode> stack = new ArrayList<VarNode>(); 
+    private int stronglyConnectedComponents(VarNode node) {
+        int lowindex = node.index = curIndex++;
+        stack.add(node);
+        for(Sub sub=node.lower;sub != null;sub=sub.bNext) {
+            VarNode child = sub.a;
+            int childIndex = child.index;
+            if(childIndex == -1)
+                childIndex = stronglyConnectedComponents(child);
+            lowindex = Math.min(lowindex, childIndex);
+        }
+        if(node.index == lowindex) {
+            // root of strongly connected component
+            VarNode stackNode = stack.remove(stack.size()-1);
+            if(stackNode != node) {
+                ArrayList<VarNode> otherInComponent = new ArrayList<VarNode>(4);
+                while(stackNode != node) {
+                    otherInComponent.add(stackNode);
+                    stackNode = stack.remove(stack.size()-1);
+                }
+                mergeComponent(node, otherInComponent);
+            }
+            node.index = Integer.MAX_VALUE;
+            sortedNodes.add(node);
+        }
+        return lowindex;
+    }
+
+    private void mergeComponent(VarNode root, ArrayList<VarNode> otherInComponent) {
+        // There is no need to merge upper bounds, because they have been propagated
+        int lowerBound = root.lowerBound;
+        for(VarNode node : otherInComponent)
+            lowerBound |= node.lowerBound;
+        root.lowerBound = lowerBound;
+
+        for(VarNode node : otherInComponent) {
+            if(DEBUG)
+                System.out.println("replace " + toName(node) + " by " + toName(root));
+            node.replaceBy(root);
+        }
+    }
+
+    // Dummy debugging functions
+    private String toName(Node node) {
+        return "";
+    }
+    private void printUnion(UnionNode union) {
+    }
+    private void print() {
+    }
+    private String constToString(int cons) {
+        return "";
+    }
+    /*
+    private TypeUnparsingContext tuc = new TypeUnparsingContext();
+    private THashMap<Node, String> nameMap = new THashMap<Node, String>();
+    private char nextChar = 'a';
+    
+    private String toName(Node node) {
+        String name = nameMap.get(node);
+        if(name == null) {
+            name = new String(new char[] {'?', nextChar++});
+            nameMap.put(node, name);
+        }
+        return name;
+    }
+    
+    private String constToString(int cons) {
+        return effectIds.toType(cons).toString(tuc);
+    }
+    
+    private boolean hasContent() {
+        for(VarNode node : varNodeMap.values())
+            if(node.index != SubsumptionGraph.REMOVED)
+//                if(node.lower != null)
+                return true;
+        for(UnionNode node : unionNodes)
+            if(node.constPart != SubsumptionGraph.REMOVED)
+                return true;
+        return false;
+    }
+    
+    private void print() {
+        if(!hasContent())
+            return;
+        System.out.println("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv");
+        TypeUnparsingContext tuc = new TypeUnparsingContext();
+        for(VarNode node : varNodeMap.values()) {
+            if(node.index == SubsumptionGraph.REMOVED) {
+                //System.out.println(toName(node) + " removed");
+                continue;
+            }
+            System.out.print(toName(node));
+            if(node.lowerBound != EffectIdMap.MIN || node.upperBound != EffectIdMap.MAX) {
+                System.out.print(" in [");
+                if(node.lowerBound != EffectIdMap.MIN)
+                    System.out.print(constToString(node.lowerBound));
+                System.out.print("..");
+                if(node.upperBound != EffectIdMap.MAX) {
+                    if(node.upperBound == 0)
+                        System.out.print("Pure");
+                    else
+                        System.out.print(constToString(node.upperBound));
+                }
+                System.out.print("]");
+            }
+            System.out.println(" (" + node.getPolarity() + ")");
+            
+            for(Sub cur=node.upper;cur!=null;cur=cur.aNext) {
+                System.out.print("    < ");
+                Node highNode = cur.b;
+                if(highNode instanceof VarNode) {
+                    System.out.println(toName(highNode));
+                }
+                else
+                    printUnion((UnionNode)highNode);
+            }
+        }
+        for(UnionNode node : unionNodes) {
+            if(node.lower != null)
+                continue;
+            System.out.print(constToString(node.lowerBound) + " < ");
+            printUnion(node);
+        }
+        System.out.println("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^");
+    }
+    
+    private void printUnion(UnionNode union) {
+        System.out.print("union(");
+        boolean first = true;
+        if(union.constPart != EffectIdMap.MIN) {
+            System.out.print(constToString(union.constPart));
+            first = false;
+        }
+        for(PartOfUnion part=union.parts;part!=null;part=part.bNext) {
+            if(first)
+                first = false;
+            else
+                System.out.print(", ");
+            System.out.print(toName(part.a));
+        }
+        System.out.println(")");
+    }
+    */
+    
+    public static void solve(ErrorLog errorLog, ArrayList<Subsumption> subsumptions) {
+        new SubSolver2(errorLog, subsumptions).solve();
+    }
+}