]> gerrit.simantics Code Review - simantics/platform.git/blobdiff - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/query/QDisjunction.java
Migrated source code from Simantics SVN
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / query / QDisjunction.java
diff --git a/bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/query/QDisjunction.java b/bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/query/QDisjunction.java
new file mode 100644 (file)
index 0000000..f279264
--- /dev/null
@@ -0,0 +1,223 @@
+package org.simantics.scl.compiler.elaboration.query;
+
+import gnu.trove.map.hash.THashMap;
+import gnu.trove.map.hash.TLongObjectHashMap;
+import gnu.trove.set.hash.TIntHashSet;
+
+import java.util.ArrayList;
+import java.util.Set;
+
+import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
+import org.simantics.scl.compiler.elaboration.expressions.EApply;
+import org.simantics.scl.compiler.elaboration.expressions.ESimpleLambda;
+import org.simantics.scl.compiler.elaboration.expressions.ESimpleLet;
+import org.simantics.scl.compiler.elaboration.expressions.EVariable;
+import org.simantics.scl.compiler.elaboration.expressions.Expression;
+import org.simantics.scl.compiler.elaboration.expressions.QueryTransformer;
+import org.simantics.scl.compiler.elaboration.expressions.Variable;
+import org.simantics.scl.compiler.elaboration.query.compilation.ConstraintCollectionContext;
+import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
+import org.simantics.scl.compiler.elaboration.query.compilation.QueryCompilationContext;
+import org.simantics.scl.compiler.elaboration.query.compilation.QueryConstraint;
+import org.simantics.scl.compiler.elaboration.query.compilation.UnsolvableQueryException;
+import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
+import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
+import org.simantics.scl.compiler.errors.Locations;
+import org.simantics.scl.compiler.types.Types;
+
+
+public class QDisjunction extends QAbstractCombiner {
+
+    public QDisjunction(Query ... queries) {
+        super(queries);
+    }
+    
+    private static class CachedPlan {
+        Variable[] variables;
+        QueryCompilationContext[] subplans;
+        double totalBranching;
+        double totalCost;
+        
+        public CachedPlan(Variable[] variables, QueryCompilationContext[] subplans,
+                double totalBranching, double totalCost) {
+            this.variables = variables;
+            this.subplans = subplans;
+            this.totalBranching = totalBranching;
+            this.totalCost = totalCost;
+        }
+    }
+
+    @Override
+    public void collectConstraints(final ConstraintCollectionContext context) {
+        TIntHashSet vars = new TIntHashSet();
+        collectVars(context.getVariableMap(), vars);
+        
+        final Variable continuationFunction = new Variable("continuation");
+        int[] variables = vars.toArray();
+        long variableMask_ = 0L;
+        for(int v : variables)
+            variableMask_ |= 1L << v;
+        final long variableMask = variableMask_;
+        
+        context.addConstraint(new QueryConstraint(variables) {
+            
+            TLongObjectHashMap<CachedPlan> cache = new TLongObjectHashMap<CachedPlan>();
+            
+            private CachedPlan create(long boundVariables) {
+                QueryCompilationContext[] subplans = new QueryCompilationContext[queries.length];
+                double totalBranching = 1.0;
+                double totalCost = 0.0;
+                ArrayList<Variable> solvedVariablesList = new ArrayList<Variable>();
+                for(int v : variables)
+                    if( ((boundVariables >> v)&1) == 0 )
+                        solvedVariablesList.add(context.getVariable(v));
+                Variable[] solvedVariables = solvedVariablesList.toArray(new Variable[solvedVariablesList.size()]);
+                for(int i=0;i<queries.length;++i) {
+                    Expression[] parameters = new Expression[solvedVariables.length];
+                    for(int j=0;j<solvedVariables.length;++j)
+                        parameters[j] = new EVariable(solvedVariables[j]);
+                    EApply cont = new EApply(Locations.NO_LOCATION, Types.PROC,
+                            new EVariable(continuationFunction), parameters);
+                    cont.setType(context.getQueryCompilationContext().getContinuation().getType());
+                    subplans[i] = context.getQueryCompilationContext().createSubcontext(cont);
+                    try {
+                        new QExists(solvedVariables, queries[i]).generate(subplans[i]);
+                    } catch (UnsolvableQueryException e) {
+                        return new CachedPlan(null, null, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
+                    }
+                    totalBranching += subplans[i].getBranching();
+                    totalCost += subplans[i].getCost();
+                }
+                return new CachedPlan(solvedVariables, subplans, totalBranching, totalCost);
+            }
+            
+            private CachedPlan get(long boundVariables) {
+                boundVariables &= variableMask;
+                CachedPlan plan = cache.get(boundVariables);
+                if(plan == null) {
+                    plan = create(boundVariables);
+                    cache.put(boundVariables, plan);
+                }
+                return plan;
+            }
+            
+            @Override
+            public double getSolutionCost(long boundVariables) {
+                return get(boundVariables).totalCost;
+            }
+            
+            @Override
+            public double getSolutionBranching(long boundVariables) {
+                return get(boundVariables).totalBranching;
+            }
+            
+            @Override
+            public boolean canBeSolvedFrom(long boundVariables) {
+                return get(boundVariables).totalCost != Double.POSITIVE_INFINITY;
+            }
+            
+            @Override
+            public void generate(QueryCompilationContext context) {
+                CachedPlan plan = get(finalBoundVariables);
+                
+                Expression[] disjuncts = new Expression[plan.subplans.length];
+                for(int i=0;i<plan.subplans.length;++i)
+                    disjuncts[i] = plan.subplans[i].getContinuation().copy(context.getTypingContext());
+                Expression result = context.disjunction(disjuncts);
+                
+                ReplaceContext replaceContext = new ReplaceContext(context.getTypingContext());
+                Variable[] newVariables = new Variable[plan.variables.length];
+                for(int i=0;i<newVariables.length;++i) {
+                    Variable oldVariable = plan.variables[i];
+                    Variable newVariable = new Variable(oldVariable.getName(), oldVariable.getType());
+                    newVariables[i] = newVariable;
+                    oldVariable.setName(oldVariable.getName() + "_temp");
+                    replaceContext.varMap.put(oldVariable, new EVariable(newVariable));
+                }
+                
+                Expression functionDefinition = context.getContinuation().replace(replaceContext);
+                boolean first = true;
+                for(int i=plan.variables.length-1;i>=0;--i) {
+                    functionDefinition = new ESimpleLambda(
+                            first ? Types.PROC /* FIXME */ : Types.NO_EFFECTS,
+                            newVariables[i], functionDefinition);
+                    first = false;
+                }
+                continuationFunction.setType(functionDefinition.getType());
+                
+                context.setContinuation(new ESimpleLet(
+                        continuationFunction, 
+                        functionDefinition, 
+                        result));
+            }
+        });
+    }
+
+    @Override
+    public Diff[] derivate(THashMap<LocalRelation, Diffable> diffables) throws DerivateException {
+        Diff[][] diffs = new Diff[queries.length][];
+        int totalDiffCount = 0;
+        for(int i=0;i<queries.length;++i) {
+            Diff[] ds = queries[i].derivate(diffables);
+            diffs[i] = ds;
+            totalDiffCount += ds.length;
+        }
+        if(totalDiffCount == 0)
+            return NO_DIFF;
+        else {
+            Diff[] result = new Diff[totalDiffCount];
+            int i=0;
+            for(Diff[] ds : diffs)
+                for(Diff diff : ds)
+                    result[i++] = diff;
+            return result;
+        }
+    }
+    
+    @Override
+    public Query replace(ReplaceContext context) {
+        Query[] newQueries = new Query[queries.length];
+        for(int i=0;i<queries.length;++i)
+            newQueries[i] = queries[i].replace(context);
+        return new QDisjunction(newQueries);
+    }
+
+    @Override
+    public Query removeRelations(Set<SCLRelation> relations) {
+        for(int i=0;i<queries.length;++i) {
+            Query query = queries[i];
+            Query newQuery = query.removeRelations(relations);
+            if(query != newQuery) {
+                ArrayList<Query> newQueries = new ArrayList<Query>(queries.length);
+                for(int j=0;j<i;++j)
+                    newQueries.add(queries[j]);
+                if(newQuery != EMPTY_QUERY)
+                    newQueries.add(newQuery);
+                for(++i;i<queries.length;++i) {
+                    query = queries[i];
+                    newQuery = query.removeRelations(relations);
+                    if(newQuery != EMPTY_QUERY)
+                        newQueries.add(newQuery);
+                }
+                if(newQueries.isEmpty())
+                    return EMPTY_QUERY;
+                else if(newQueries.size()==1)
+                    return newQueries.get(0);
+                else
+                    return new QDisjunction(newQueries.toArray(new Query[newQueries.size()]));
+            }
+        }
+        return this;
+    }
+    
+    @Override
+    public void accept(QueryVisitor visitor) {
+        visitor.visit(this);
+    }
+    
+    @Override
+    public Query accept(QueryTransformer transformer) {
+        return transformer.transform(this);
+    }
+
+}