--- /dev/null
+package org.simantics.scl.compiler.elaboration.query.compilation;\r
+\r
+import gnu.trove.map.hash.THashMap;\r
+import gnu.trove.procedure.TObjectObjectProcedure;\r
+import gnu.trove.set.hash.TIntHashSet;\r
+\r
+import java.util.ArrayList;\r
+\r
+import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;\r
+import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;\r
+import org.simantics.scl.compiler.elaboration.expressions.EVariable;\r
+import org.simantics.scl.compiler.elaboration.expressions.Expression;\r
+import org.simantics.scl.compiler.elaboration.expressions.Variable;\r
+import org.simantics.scl.compiler.errors.Locations;\r
+import org.simantics.scl.compiler.types.TVar;\r
+import org.simantics.scl.compiler.types.Type;\r
+\r
+public class ExpressionConstraint extends QueryConstraint {\r
+ Variable variable;\r
+ Expression expression;\r
+ boolean isPattern;\r
+ \r
+ long forwardVariableMask;\r
+ long backwardVariableMask;\r
+ \r
+ ArrayList<Variable> globalVariables;\r
+ \r
+ public ExpressionConstraint(final ConstraintCollectionContext context, Variable variable,\r
+ Expression expression, boolean isPattern) {\r
+ this.variable = variable;\r
+ this.expression = expression;\r
+ this.isPattern = isPattern;\r
+ \r
+ final TIntHashSet vars = new TIntHashSet();\r
+ expression.collectVars(context.getVariableMap(), vars);\r
+ \r
+ int var1 = context.variableMap.get(variable);\r
+ vars.add(var1);\r
+ backwardVariableMask = 1L << var1;\r
+ \r
+ variables = vars.toArray();\r
+ \r
+ for(int v : variables)\r
+ forwardVariableMask |= 1L << v;\r
+ forwardVariableMask ^= backwardVariableMask;\r
+ \r
+ this.globalVariables = context.variables;\r
+ }\r
+ \r
+ private boolean canBeSolvedForwards(long boundVariables) {\r
+ return (forwardVariableMask & boundVariables) == forwardVariableMask;\r
+ }\r
+ \r
+ private boolean canBeSolvedBackwards(long boundVariables) {\r
+ return (backwardVariableMask & boundVariables) == backwardVariableMask;\r
+ }\r
+ \r
+ @Override\r
+ public boolean canBeSolvedFrom(long boundVariables) {\r
+ return canBeSolvedForwards(boundVariables) || (isPattern && canBeSolvedBackwards(boundVariables)); \r
+ }\r
+ \r
+ @Override\r
+ public double getSolutionCost(long boundVariables) {\r
+ return 1.0;\r
+ }\r
+ \r
+ @Override\r
+ public double getSolutionBranching(long boundVariables) {\r
+ if(canBeSolvedForwards(boundVariables))\r
+ return (boundVariables&1)==0 ? 1.0 : 0.95;\r
+ else if(isPattern && canBeSolvedBackwards(boundVariables))\r
+ return 0.95;\r
+ else\r
+ return Double.POSITIVE_INFINITY;\r
+ }\r
+ \r
+ @Override\r
+ public void generate(final QueryCompilationContext context) {\r
+ if(canBeSolvedForwards(finalBoundVariables)) {\r
+ if(canBeSolvedBackwards(finalBoundVariables))\r
+ context.equalityCondition(expression.location, new EVariable(variable), expression);\r
+ else\r
+ context.let(variable, expression);\r
+ }\r
+ else if(canBeSolvedBackwards(finalBoundVariables)) {\r
+ Expression pattern = expression;\r
+ \r
+ long mask = forwardVariableMask & finalBoundVariables;\r
+ THashMap<Variable, Expression> map = new THashMap<Variable, Expression>();\r
+ if(mask != 0L) {\r
+ for(int variableId : variables)\r
+ if( ((mask >> variableId)&1L) == 1L ) {\r
+ Variable original = globalVariables.get(variableId);\r
+ Variable newVariable = new Variable(original.getName() + "_temp", original.getType());\r
+ map.put(original, new EVariable(newVariable));\r
+ }\r
+ \r
+ ReplaceContext replaceContext = new ReplaceContext(new THashMap<TVar,Type>(0), map, context.getTypingContext());\r
+ pattern = pattern.replace(replaceContext);\r
+ }\r
+ context.match(pattern, new EVariable(variable), true);\r
+ map.forEachEntry(new TObjectObjectProcedure<Variable, Expression>() {\r
+ @Override\r
+ public boolean execute(Variable a, Expression b) {\r
+ context.equalityCondition(Locations.NO_LOCATION, new EVariable(a), b);\r
+ return true;\r
+ }\r
+ });\r
+ }\r
+ else\r
+ throw new InternalCompilerError(expression.location, "Error happened when tried to solve the query.");\r
+ }\r
+}\r