gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/expressions/ERuleset.java
migrated to svn revision 33108
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;\r
2 \r
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;\r
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;\r
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;\r
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;\r
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;\r
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;\r
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;\r
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;\r
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;\r
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;\r
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;\r
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;\r
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;\r
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;\r
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;\r
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;\r
19 \r
20 import java.util.ArrayList;\r
21 import java.util.Set;\r
22 \r
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;\r
24 import org.simantics.scl.compiler.common.names.Name;\r
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;\r
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;\r
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;\r
28 import org.simantics.scl.compiler.elaboration.query.Query;\r
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;\r
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;\r
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;\r
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;\r
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;\r
34 import org.simantics.scl.compiler.errors.Locations;\r
35 import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;\r
36 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;\r
37 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;\r
38 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;\r
39 import org.simantics.scl.compiler.types.TCon;\r
40 import org.simantics.scl.compiler.types.Type;\r
41 import org.simantics.scl.compiler.types.Types;\r
42 import org.simantics.scl.compiler.types.exceptions.MatchException;\r
43 import org.simantics.scl.compiler.types.kinds.Kinds;\r
44 \r
45 import gnu.trove.impl.Constants;\r
46 import gnu.trove.map.hash.THashMap;\r
47 import gnu.trove.map.hash.TObjectIntHashMap;\r
48 import gnu.trove.set.hash.THashSet;\r
49 import gnu.trove.set.hash.TIntHashSet;\r
50 \r
51 public class ERuleset extends SimplifiableExpression {\r
52     LocalRelation[] relations;\r
53     DatalogRule[] rules;\r
54     Expression in;\r
55     \r
56     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {\r
57         this.relations = relations;\r
58         this.rules = rules;\r
59         this.in = in;\r
60     }\r
61 \r
62     public static class DatalogRule {\r
63         public long location;\r
64         public LocalRelation headRelation;\r
65         public Expression[] headParameters;\r
66         public Query body;\r
67         public Variable[] variables;\r
68         \r
69         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,\r
70                 Query body) {\r
71             this.headRelation = headRelation;\r
72             this.headParameters = headParameters;\r
73             this.body = body;\r
74         }\r
75         \r
76         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,\r
77                 Query body, Variable[] variables) {\r
78             this.location = location;\r
79             this.headRelation = headRelation;\r
80             this.headParameters = headParameters;\r
81             this.body = body;\r
82             this.variables = variables;\r
83         }\r
84 \r
85         public void setLocationDeep(long loc) {\r
86             this.location = loc;\r
87             for(Expression parameter : headParameters)\r
88                 parameter.setLocationDeep(loc);\r
89             body.setLocationDeep(loc);\r
90         }\r
91         \r
92         @Override\r
93         public String toString() {\r
94             StringBuilder b = new StringBuilder();\r
95             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);\r
96             visitor.visit(this);\r
97             return b.toString();\r
98         }\r
99 \r
100         public void forVariables(VariableProcedure procedure) {\r
101             for(Expression headParameter : headParameters)\r
102                 headParameter.forVariables(procedure);\r
103             body.forVariables(procedure);\r
104         }\r
105     }\r
106     \r
107     private void checkRuleTypes(TypingContext context) {\r
108         // Create relation variables\r
109         for(DatalogRule rule : rules) {\r
110             Type[] parameterTypes =  rule.headRelation.getParameterTypes();\r
111             Expression[] parameters = rule.headParameters;\r
112             for(Variable variable : rule.variables)\r
113                 variable.setType(Types.metaVar(Kinds.STAR));\r
114             for(int i=0;i<parameters.length;++i)\r
115                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);\r
116             rule.body.checkType(context);\r
117         }\r
118     }\r
119     \r
120     @Override\r
121     public Expression checkBasicType(TypingContext context, Type requiredType) {\r
122         checkRuleTypes(context);\r
123         in = in.checkBasicType(context, requiredType);\r
124         return compile(context);\r
125     }\r
126     \r
127     @Override\r
128     public Expression inferType(TypingContext context) {\r
129         checkRuleTypes(context);\r
130         in = in.inferType(context);\r
131         return compile(context);\r
132     }\r
133     \r
134     @Override\r
135     public Expression checkIgnoredType(TypingContext context) {\r
136         checkRuleTypes(context);\r
137         in = in.checkIgnoredType(context);\r
138         return compile(context);\r
139     }\r
140     \r
141     @Override\r
142     public void collectFreeVariables(THashSet<Variable> vars) {\r
143         for(DatalogRule rule : rules) {\r
144             for(Expression parameter : rule.headParameters)\r
145                 parameter.collectFreeVariables(vars);\r
146             rule.body.collectFreeVariables(vars);\r
147             for(Variable var : rule.variables)\r
148                 vars.remove(var);\r
149         }\r
150         in.collectFreeVariables(vars);\r
151     }\r
152     \r
153     @Override\r
154     public void collectRefs(TObjectIntHashMap<Object> allRefs,\r
155             TIntHashSet refs) {\r
156         for(DatalogRule rule : rules) {\r
157             for(Expression parameter : rule.headParameters)\r
158                 parameter.collectRefs(allRefs, refs);\r
159             rule.body.collectRefs(allRefs, refs);\r
160         }\r
161         in.collectRefs(allRefs, refs);\r
162     }\r
163     \r
164     @Override\r
165     public void collectVars(TObjectIntHashMap<Variable> allVars,\r
166             TIntHashSet vars) {\r
167         for(DatalogRule rule : rules) {\r
168             for(Expression parameter : rule.headParameters)\r
169                 parameter.collectVars(allVars, vars);\r
170             rule.body.collectVars(allVars, vars);\r
171         }\r
172         in.collectVars(allVars, vars);\r
173     }\r
174     \r
175     @Override\r
176     public void collectEffects(THashSet<Type> effects) {\r
177         throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");\r
178     }\r
179     \r
180     @Override\r
181     public Expression decorate(ExpressionDecorator decorator) {\r
182         return decorator.decorate(this);\r
183     }\r
184     \r
185     @Override\r
186     public Expression resolve(TranslationContext context) {\r
187         throw new InternalCompilerError();\r
188     }\r
189     \r
190     static class LocalRelationAux {\r
191         Variable handleFunc;\r
192     }\r
193     \r
194     public static final TCon MSet = Types.con("MSet", "T");\r
195     private static final Name MSet_add = Name.create("MSet", "add");\r
196     private static final Name MSet_create = Name.create("MSet", "create");\r
197 \r
198     private static final TCon MList = Types.con("MList", "T");\r
199     private static final Name MList_add = Name.create("MList", "add");\r
200     private static final Name MList_create = Name.create("MList", "create");\r
201     private static final Name MList_removeLast = Name.create("MList", "removeLast");\r
202     \r
203     public Expression compile(TypingContext context) {\r
204         // Create a map from relations to their ids\r
205         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,\r
206                 Constants.DEFAULT_LOAD_FACTOR, -1);\r
207         for(int i=0;i<relations.length;++i)\r
208             relationsToIds.put(relations[i], i);\r
209         \r
210         // Create a table from relations to the other relations they depend on\r
211         TIntHashSet[] refsSets = new TIntHashSet[relations.length];\r
212         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);\r
213         for(int i=0;i<relations.length;++i)\r
214             refsSets[i] = new TIntHashSet(setCapacity);\r
215         \r
216         for(DatalogRule rule : rules) {\r
217             int headRelationId = relationsToIds.get(rule.headRelation);\r
218             TIntHashSet refsSet = refsSets[headRelationId];\r
219             rule.body.collectRelationRefs(relationsToIds, refsSet);\r
220             for(Expression parameter : rule.headParameters)\r
221                 parameter.collectRelationRefs(relationsToIds, refsSet);\r
222         }\r
223         \r
224         // Convert refsSets to an array\r
225         final int[][] refs = new int[relations.length][];\r
226         for(int i=0;i<relations.length;++i)\r
227             refs[i] = refsSets[i].toArray();\r
228         \r
229         // Find strongly connected components of the function refs\r
230         final ArrayList<int[]> components = new ArrayList<int[]>();\r
231         \r
232         new StronglyConnectedComponents(relations.length) {\r
233             @Override\r
234             protected void reportComponent(int[] component) {\r
235                 components.add(component);\r
236             }\r
237             \r
238             @Override\r
239             protected int[] findDependencies(int u) {\r
240                 return refs[u];\r
241             }\r
242         }.findComponents();\r
243         \r
244         // If there is just one component, compile it\r
245         if(components.size() == 1) {\r
246             return compileStratified(context);\r
247         }\r
248         \r
249         // Inverse of components array \r
250         int[] strataPerRelation = new int[relations.length];\r
251         for(int i=0;i<components.size();++i)\r
252             for(int k : components.get(i))\r
253                 strataPerRelation[k] = i;\r
254         \r
255         // Collects rules belonging to each strata\r
256         @SuppressWarnings("unchecked")\r
257         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];\r
258         for(int i=0;i<components.size();++i)\r
259             rulesPerStrata[i] = new ArrayList<DatalogRule>();\r
260         for(DatalogRule rule : rules) {\r
261             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];\r
262             rulesPerStrata[stratum].add(rule);\r
263         }\r
264         \r
265         // Create stratified system\r
266         Expression cur = this.in;\r
267         for(int stratum=components.size()-1;stratum >= 0;--stratum) {\r
268             int[] cs = components.get(stratum);\r
269             LocalRelation[] curRelations = new LocalRelation[cs.length];\r
270             for(int i=0;i<cs.length;++i)\r
271                 curRelations[i] = relations[cs[i]];\r
272             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];\r
273             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);\r
274         }\r
275         return cur;\r
276     }\r
277     \r
278     private Expression compileStratified(TypingContext context) {\r
279         Expression continuation = Expressions.tuple();\r
280         \r
281         // Create stacks\r
282         Variable[] stacks = new Variable[relations.length];\r
283         for(int i=0;i<relations.length;++i) {\r
284             LocalRelation relation = relations[i];\r
285             Type[] parameterTypes = relation.getParameterTypes();\r
286             stacks[i] = newVar("stack" + relation.getName(),\r
287                     Types.apply(MList, Types.tuple(parameterTypes))\r
288                     );\r
289         }\r
290 \r
291         // Simplify subexpressions and collect derivatives\r
292         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);\r
293         for(int i=0;i<relations.length;++i) {\r
294             LocalRelation relation = relations[i];\r
295             Type[] parameterTypes = relation.getParameterTypes();\r
296             Variable[] parameters = new Variable[parameterTypes.length];\r
297             for(int j=0;j<parameterTypes.length;++j)\r
298                 parameters[j] = new Variable("p" + j, parameterTypes[j]);\r
299             diffables.put(relations[i], new Diffable(i, relation, parameters));\r
300         }\r
301         @SuppressWarnings("unchecked")\r
302         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];\r
303         for(int i=0;i<relations.length;++i)\r
304             updateExpressions[i] = new ArrayList<Expression>(2);\r
305         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); \r
306         for(DatalogRule rule : rules) {\r
307             int id = diffables.get(rule.headRelation).id;\r
308             Expression appendExp = apply(context, Types.PROC, MList_add, Types.tuple(rule.headRelation.getParameterTypes()),\r
309                     var(stacks[id]),\r
310                     tuple(rule.headParameters)\r
311                     );\r
312             Diff[] diffs;\r
313             try {\r
314                 diffs = rule.body.derivate(diffables);\r
315             } catch(DerivateException e) {\r
316                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");\r
317                 return new EError();\r
318             }\r
319             for(Diff diff : diffs)\r
320                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));\r
321             if(diffs.length == 0)\r
322                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));\r
323             else {\r
324                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());\r
325                 if(query != Query.EMPTY_QUERY)\r
326                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));\r
327             }\r
328         }\r
329         \r
330         // Iterative solving of relations\r
331 \r
332         Variable[] loops = new Variable[relations.length];\r
333         for(int i=0;i<loops.length;++i)\r
334             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));\r
335         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);\r
336         \r
337         Expression[] loopDefs = new Expression[relations.length];\r
338         for(int i=0;i<relations.length;++i) {\r
339             LocalRelation relation = relations[i];\r
340             Type[] parameterTypes = relation.getParameterTypes();\r
341             Variable[] parameters = diffables.get(relation).parameters;\r
342             \r
343             Variable counter = newVar("counter", Types.INTEGER);\r
344             \r
345             Type rowType = Types.tuple(parameterTypes);\r
346             Variable row = newVar("row", rowType);\r
347             \r
348             Expression handleRow = tuple();\r
349             for(Expression updateExpression : updateExpressions[i])\r
350                 handleRow = seq(updateExpression, handleRow);\r
351             handleRow = if_(\r
352                     apply(context, Types.PROC, MSet_add, rowType,\r
353                             var(relation.table), var(row)),\r
354                     handleRow,\r
355                     tuple()\r
356                     );\r
357             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));\r
358             Expression failure =\r
359                     if_(isZeroInteger(var(counter)),\r
360                         tuple(),\r
361                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))\r
362                        );\r
363             Expression body = matchWithDefault(\r
364                     apply(context, Types.PROC, MList_removeLast, rowType, var(stacks[i])),\r
365                     Just(as(row, tuple(vars(parameters)))), handleRow,\r
366                     failure);\r
367             \r
368             loopDefs[i] = lambda(Types.PROC, counter, body); \r
369         }\r
370         continuation = letRec(loops, loopDefs, continuation);\r
371         \r
372         // Seed relations\r
373         for(Expression seedExpression : seedExpressions)\r
374             continuation = seq(seedExpression, continuation);\r
375         \r
376         // Create stacks\r
377         for(int i=0;i<stacks.length;++i)\r
378             continuation = let(stacks[i],\r
379                     apply(context, Types.PROC, MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),\r
380                     continuation);\r
381         \r
382         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);\r
383         \r
384         // Create relations\r
385         for(LocalRelation relation : relations)\r
386             continuation = let(relation.table,\r
387                     apply(context, Types.PROC, MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),\r
388                     continuation);\r
389         \r
390         return seq(continuation, in);\r
391     }\r
392 \r
393     @Override\r
394     protected void updateType() throws MatchException {\r
395         setType(in.getType());\r
396     }\r
397     \r
398     @Override\r
399     public void setLocationDeep(long loc) {\r
400         if(location == Locations.NO_LOCATION) {\r
401             location = loc;\r
402             for(DatalogRule rule : rules)\r
403                 rule.setLocationDeep(loc);\r
404         }\r
405     }\r
406     \r
407     @Override\r
408     public void accept(ExpressionVisitor visitor) {\r
409         visitor.visit(this);\r
410     }\r
411 \r
412     public DatalogRule[] getRules() {\r
413         return rules;\r
414     }\r
415     \r
416     public Expression getIn() {\r
417         return in;\r
418     }\r
419 \r
420     @Override\r
421     public void forVariables(VariableProcedure procedure) {\r
422         for(DatalogRule rule : rules)\r
423             rule.forVariables(procedure);\r
424         in.forVariables(procedure);\r
425     }\r
426     \r
427     @Override\r
428     public Expression accept(ExpressionTransformer transformer) {\r
429         return transformer.transform(this);\r
430     }\r
431 \r
432 }\r