--- /dev/null
+package org.simantics.scl.compiler.elaboration.query;
+
+import gnu.trove.map.hash.THashMap;
+import gnu.trove.map.hash.TLongObjectHashMap;
+import gnu.trove.set.hash.TIntHashSet;
+
+import java.util.ArrayList;
+import java.util.Set;
+
+import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
+import org.simantics.scl.compiler.elaboration.expressions.EApply;
+import org.simantics.scl.compiler.elaboration.expressions.ESimpleLambda;
+import org.simantics.scl.compiler.elaboration.expressions.ESimpleLet;
+import org.simantics.scl.compiler.elaboration.expressions.EVariable;
+import org.simantics.scl.compiler.elaboration.expressions.Expression;
+import org.simantics.scl.compiler.elaboration.expressions.QueryTransformer;
+import org.simantics.scl.compiler.elaboration.expressions.Variable;
+import org.simantics.scl.compiler.elaboration.query.compilation.ConstraintCollectionContext;
+import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
+import org.simantics.scl.compiler.elaboration.query.compilation.QueryCompilationContext;
+import org.simantics.scl.compiler.elaboration.query.compilation.QueryConstraint;
+import org.simantics.scl.compiler.elaboration.query.compilation.UnsolvableQueryException;
+import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
+import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
+import org.simantics.scl.compiler.errors.Locations;
+import org.simantics.scl.compiler.types.Types;
+
+
+public class QDisjunction extends QAbstractCombiner {
+
+ public QDisjunction(Query ... queries) {
+ super(queries);
+ }
+
+ private static class CachedPlan {
+ Variable[] variables;
+ QueryCompilationContext[] subplans;
+ double totalBranching;
+ double totalCost;
+
+ public CachedPlan(Variable[] variables, QueryCompilationContext[] subplans,
+ double totalBranching, double totalCost) {
+ this.variables = variables;
+ this.subplans = subplans;
+ this.totalBranching = totalBranching;
+ this.totalCost = totalCost;
+ }
+ }
+
+ @Override
+ public void collectConstraints(final ConstraintCollectionContext context) {
+ TIntHashSet vars = new TIntHashSet();
+ collectVars(context.getVariableMap(), vars);
+
+ final Variable continuationFunction = new Variable("continuation");
+ int[] variables = vars.toArray();
+ long variableMask_ = 0L;
+ for(int v : variables)
+ variableMask_ |= 1L << v;
+ final long variableMask = variableMask_;
+
+ context.addConstraint(new QueryConstraint(variables) {
+
+ TLongObjectHashMap<CachedPlan> cache = new TLongObjectHashMap<CachedPlan>();
+
+ private CachedPlan create(long boundVariables) {
+ QueryCompilationContext[] subplans = new QueryCompilationContext[queries.length];
+ double totalBranching = 1.0;
+ double totalCost = 0.0;
+ ArrayList<Variable> solvedVariablesList = new ArrayList<Variable>();
+ for(int v : variables)
+ if( ((boundVariables >> v)&1) == 0 )
+ solvedVariablesList.add(context.getVariable(v));
+ Variable[] solvedVariables = solvedVariablesList.toArray(new Variable[solvedVariablesList.size()]);
+ for(int i=0;i<queries.length;++i) {
+ Expression[] parameters = new Expression[solvedVariables.length];
+ for(int j=0;j<solvedVariables.length;++j)
+ parameters[j] = new EVariable(solvedVariables[j]);
+ EApply cont = new EApply(Locations.NO_LOCATION, Types.PROC,
+ new EVariable(continuationFunction), parameters);
+ cont.setType(context.getQueryCompilationContext().getContinuation().getType());
+ subplans[i] = context.getQueryCompilationContext().createSubcontext(cont);
+ try {
+ new QExists(solvedVariables, queries[i]).generate(subplans[i]);
+ } catch (UnsolvableQueryException e) {
+ return new CachedPlan(null, null, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
+ }
+ totalBranching += subplans[i].getBranching();
+ totalCost += subplans[i].getCost();
+ }
+ return new CachedPlan(solvedVariables, subplans, totalBranching, totalCost);
+ }
+
+ private CachedPlan get(long boundVariables) {
+ boundVariables &= variableMask;
+ CachedPlan plan = cache.get(boundVariables);
+ if(plan == null) {
+ plan = create(boundVariables);
+ cache.put(boundVariables, plan);
+ }
+ return plan;
+ }
+
+ @Override
+ public double getSolutionCost(long boundVariables) {
+ return get(boundVariables).totalCost;
+ }
+
+ @Override
+ public double getSolutionBranching(long boundVariables) {
+ return get(boundVariables).totalBranching;
+ }
+
+ @Override
+ public boolean canBeSolvedFrom(long boundVariables) {
+ return get(boundVariables).totalCost != Double.POSITIVE_INFINITY;
+ }
+
+ @Override
+ public void generate(QueryCompilationContext context) {
+ CachedPlan plan = get(finalBoundVariables);
+
+ Expression[] disjuncts = new Expression[plan.subplans.length];
+ for(int i=0;i<plan.subplans.length;++i)
+ disjuncts[i] = plan.subplans[i].getContinuation().copy(context.getTypingContext());
+ Expression result = context.disjunction(disjuncts);
+
+ ReplaceContext replaceContext = new ReplaceContext(context.getTypingContext());
+ Variable[] newVariables = new Variable[plan.variables.length];
+ for(int i=0;i<newVariables.length;++i) {
+ Variable oldVariable = plan.variables[i];
+ Variable newVariable = new Variable(oldVariable.getName(), oldVariable.getType());
+ newVariables[i] = newVariable;
+ oldVariable.setName(oldVariable.getName() + "_temp");
+ replaceContext.varMap.put(oldVariable, new EVariable(newVariable));
+ }
+
+ Expression functionDefinition = context.getContinuation().replace(replaceContext);
+ boolean first = true;
+ for(int i=plan.variables.length-1;i>=0;--i) {
+ functionDefinition = new ESimpleLambda(
+ first ? Types.PROC /* FIXME */ : Types.NO_EFFECTS,
+ newVariables[i], functionDefinition);
+ first = false;
+ }
+ continuationFunction.setType(functionDefinition.getType());
+
+ context.setContinuation(new ESimpleLet(
+ continuationFunction,
+ functionDefinition,
+ result));
+ }
+ });
+ }
+
+ @Override
+ public Diff[] derivate(THashMap<LocalRelation, Diffable> diffables) throws DerivateException {
+ Diff[][] diffs = new Diff[queries.length][];
+ int totalDiffCount = 0;
+ for(int i=0;i<queries.length;++i) {
+ Diff[] ds = queries[i].derivate(diffables);
+ diffs[i] = ds;
+ totalDiffCount += ds.length;
+ }
+ if(totalDiffCount == 0)
+ return NO_DIFF;
+ else {
+ Diff[] result = new Diff[totalDiffCount];
+ int i=0;
+ for(Diff[] ds : diffs)
+ for(Diff diff : ds)
+ result[i++] = diff;
+ return result;
+ }
+ }
+
+ @Override
+ public Query replace(ReplaceContext context) {
+ Query[] newQueries = new Query[queries.length];
+ for(int i=0;i<queries.length;++i)
+ newQueries[i] = queries[i].replace(context);
+ return new QDisjunction(newQueries);
+ }
+
+ @Override
+ public Query removeRelations(Set<SCLRelation> relations) {
+ for(int i=0;i<queries.length;++i) {
+ Query query = queries[i];
+ Query newQuery = query.removeRelations(relations);
+ if(query != newQuery) {
+ ArrayList<Query> newQueries = new ArrayList<Query>(queries.length);
+ for(int j=0;j<i;++j)
+ newQueries.add(queries[j]);
+ if(newQuery != EMPTY_QUERY)
+ newQueries.add(newQuery);
+ for(++i;i<queries.length;++i) {
+ query = queries[i];
+ newQuery = query.removeRelations(relations);
+ if(newQuery != EMPTY_QUERY)
+ newQueries.add(newQuery);
+ }
+ if(newQueries.isEmpty())
+ return EMPTY_QUERY;
+ else if(newQueries.size()==1)
+ return newQueries.get(0);
+ else
+ return new QDisjunction(newQueries.toArray(new Query[newQueries.size()]));
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public void accept(QueryVisitor visitor) {
+ visitor.visit(this);
+ }
+
+ @Override
+ public Query accept(QueryTransformer transformer) {
+ return transformer.transform(this);
+ }
+
+}