]> gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/expressions/ERuleset.java
e512460298d5bb40e0316d5b4020d0c21f8184a3
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;\r
2 \r
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;\r
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;\r
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;\r
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;\r
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;\r
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;\r
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;\r
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;\r
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;\r
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;\r
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;\r
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;\r
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;\r
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;\r
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;\r
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;\r
19 \r
20 import java.util.ArrayList;\r
21 import java.util.Set;\r
22 \r
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;\r
24 import org.simantics.scl.compiler.common.names.Name;\r
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;\r
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;\r
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;\r
28 import org.simantics.scl.compiler.elaboration.query.Query;\r
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;\r
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;\r
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;\r
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;\r
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;\r
34 import org.simantics.scl.compiler.errors.Locations;\r
35 import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;\r
36 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;\r
37 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;\r
38 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;\r
39 import org.simantics.scl.compiler.types.TCon;\r
40 import org.simantics.scl.compiler.types.Type;\r
41 import org.simantics.scl.compiler.types.Types;\r
42 import org.simantics.scl.compiler.types.exceptions.MatchException;\r
43 import org.simantics.scl.compiler.types.kinds.Kinds;\r
44 \r
45 import gnu.trove.impl.Constants;\r
46 import gnu.trove.map.hash.THashMap;\r
47 import gnu.trove.map.hash.TObjectIntHashMap;\r
48 import gnu.trove.set.hash.THashSet;\r
49 import gnu.trove.set.hash.TIntHashSet;\r
50 \r
51 public class ERuleset extends SimplifiableExpression {\r
52     LocalRelation[] relations;\r
53     DatalogRule[] rules;\r
54     Expression in;\r
55     \r
56     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {\r
57         this.relations = relations;\r
58         this.rules = rules;\r
59         this.in = in;\r
60     }\r
61 \r
62     public static class DatalogRule {\r
63         public long location;\r
64         public LocalRelation headRelation;\r
65         public Expression[] headParameters;\r
66         public Query body;\r
67         public Variable[] variables;\r
68         \r
69         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,\r
70                 Query body) {\r
71             this.headRelation = headRelation;\r
72             this.headParameters = headParameters;\r
73             this.body = body;\r
74         }\r
75         \r
76         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,\r
77                 Query body, Variable[] variables) {\r
78             this.location = location;\r
79             this.headRelation = headRelation;\r
80             this.headParameters = headParameters;\r
81             this.body = body;\r
82             this.variables = variables;\r
83         }\r
84 \r
85         public void setLocationDeep(long loc) {\r
86             this.location = loc;\r
87             for(Expression parameter : headParameters)\r
88                 parameter.setLocationDeep(loc);\r
89             body.setLocationDeep(loc);\r
90         }\r
91         \r
92         @Override\r
93         public String toString() {\r
94             StringBuilder b = new StringBuilder();\r
95             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);\r
96             visitor.visit(this);\r
97             return b.toString();\r
98         }\r
99 \r
100         public void forVariables(VariableProcedure procedure) {\r
101             for(Expression headParameter : headParameters)\r
102                 headParameter.forVariables(procedure);\r
103             body.forVariables(procedure);\r
104         }\r
105     }\r
106     \r
107     private void checkRuleTypes(TypingContext context) {\r
108         // Create relation variables\r
109         for(DatalogRule rule : rules) {\r
110             Type[] parameterTypes =  rule.headRelation.getParameterTypes();\r
111             Expression[] parameters = rule.headParameters;\r
112             for(Variable variable : rule.variables)\r
113                 variable.setType(Types.metaVar(Kinds.STAR));\r
114             for(int i=0;i<parameters.length;++i)\r
115                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);\r
116             rule.body.checkType(context);\r
117         }\r
118     }\r
119     \r
120     @Override\r
121     public Expression checkBasicType(TypingContext context, Type requiredType) {\r
122         checkRuleTypes(context);\r
123         in = in.checkBasicType(context, requiredType);\r
124         return compile(context);\r
125     }\r
126     \r
127     @Override\r
128     public Expression inferType(TypingContext context) {\r
129         checkRuleTypes(context);\r
130         in = in.inferType(context);\r
131         return compile(context);\r
132     }\r
133     \r
134     @Override\r
135     public void collectFreeVariables(THashSet<Variable> vars) {\r
136         for(DatalogRule rule : rules) {\r
137             for(Expression parameter : rule.headParameters)\r
138                 parameter.collectFreeVariables(vars);\r
139             rule.body.collectFreeVariables(vars);\r
140             for(Variable var : rule.variables)\r
141                 vars.remove(var);\r
142         }\r
143         in.collectFreeVariables(vars);\r
144     }\r
145     \r
146     @Override\r
147     public void collectRefs(TObjectIntHashMap<Object> allRefs,\r
148             TIntHashSet refs) {\r
149         for(DatalogRule rule : rules) {\r
150             for(Expression parameter : rule.headParameters)\r
151                 parameter.collectRefs(allRefs, refs);\r
152             rule.body.collectRefs(allRefs, refs);\r
153         }\r
154         in.collectRefs(allRefs, refs);\r
155     }\r
156     \r
157     @Override\r
158     public void collectVars(TObjectIntHashMap<Variable> allVars,\r
159             TIntHashSet vars) {\r
160         for(DatalogRule rule : rules) {\r
161             for(Expression parameter : rule.headParameters)\r
162                 parameter.collectVars(allVars, vars);\r
163             rule.body.collectVars(allVars, vars);\r
164         }\r
165         in.collectVars(allVars, vars);\r
166     }\r
167     \r
168     @Override\r
169     public void collectEffects(THashSet<Type> effects) {\r
170         throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");\r
171     }\r
172     \r
173     @Override\r
174     public Expression decorate(ExpressionDecorator decorator) {\r
175         return decorator.decorate(this);\r
176     }\r
177     \r
178     @Override\r
179     public Expression resolve(TranslationContext context) {\r
180         throw new InternalCompilerError();\r
181     }\r
182     \r
183     static class LocalRelationAux {\r
184         Variable handleFunc;\r
185     }\r
186     \r
187     public static final TCon MSet = Types.con("MSet", "T");\r
188     private static final Name MSet_add = Name.create("MSet", "add");\r
189     private static final Name MSet_create = Name.create("MSet", "create");\r
190 \r
191     private static final TCon MList = Types.con("MList", "T");\r
192     private static final Name MList_add = Name.create("MList", "add");\r
193     private static final Name MList_create = Name.create("MList", "create");\r
194     private static final Name MList_removeLast = Name.create("MList", "removeLast");\r
195     \r
196     public Expression compile(TypingContext context) {\r
197         // Create a map from relations to their ids\r
198         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,\r
199                 Constants.DEFAULT_LOAD_FACTOR, -1);\r
200         for(int i=0;i<relations.length;++i)\r
201             relationsToIds.put(relations[i], i);\r
202         \r
203         // Create a table from relations to the other relations they depend on\r
204         TIntHashSet[] refsSets = new TIntHashSet[relations.length];\r
205         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);\r
206         for(int i=0;i<relations.length;++i)\r
207             refsSets[i] = new TIntHashSet(setCapacity);\r
208         \r
209         for(DatalogRule rule : rules) {\r
210             int headRelationId = relationsToIds.get(rule.headRelation);\r
211             TIntHashSet refsSet = refsSets[headRelationId];\r
212             rule.body.collectRelationRefs(relationsToIds, refsSet);\r
213             for(Expression parameter : rule.headParameters)\r
214                 parameter.collectRelationRefs(relationsToIds, refsSet);\r
215         }\r
216         \r
217         // Convert refsSets to an array\r
218         final int[][] refs = new int[relations.length][];\r
219         for(int i=0;i<relations.length;++i)\r
220             refs[i] = refsSets[i].toArray();\r
221         \r
222         // Find strongly connected components of the function refs\r
223         final ArrayList<int[]> components = new ArrayList<int[]>();\r
224         \r
225         new StronglyConnectedComponents(relations.length) {\r
226             @Override\r
227             protected void reportComponent(int[] component) {\r
228                 components.add(component);\r
229             }\r
230             \r
231             @Override\r
232             protected int[] findDependencies(int u) {\r
233                 return refs[u];\r
234             }\r
235         }.findComponents();\r
236         \r
237         // If there is just one component, compile it\r
238         if(components.size() == 1) {\r
239             return compileStratified(context);\r
240         }\r
241         \r
242         // Inverse of components array \r
243         int[] strataPerRelation = new int[relations.length];\r
244         for(int i=0;i<components.size();++i)\r
245             for(int k : components.get(i))\r
246                 strataPerRelation[k] = i;\r
247         \r
248         // Collects rules belonging to each strata\r
249         @SuppressWarnings("unchecked")\r
250         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];\r
251         for(int i=0;i<components.size();++i)\r
252             rulesPerStrata[i] = new ArrayList<DatalogRule>();\r
253         for(DatalogRule rule : rules) {\r
254             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];\r
255             rulesPerStrata[stratum].add(rule);\r
256         }\r
257         \r
258         // Create stratified system\r
259         Expression cur = this.in;\r
260         for(int stratum=components.size()-1;stratum >= 0;--stratum) {\r
261             int[] cs = components.get(stratum);\r
262             LocalRelation[] curRelations = new LocalRelation[cs.length];\r
263             for(int i=0;i<cs.length;++i)\r
264                 curRelations[i] = relations[cs[i]];\r
265             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];\r
266             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);\r
267         }\r
268         return cur;\r
269     }\r
270     \r
271     private Expression compileStratified(TypingContext context) {\r
272         Expression continuation = Expressions.tuple();\r
273         \r
274         // Create stacks\r
275         Variable[] stacks = new Variable[relations.length];\r
276         for(int i=0;i<relations.length;++i) {\r
277             LocalRelation relation = relations[i];\r
278             Type[] parameterTypes = relation.getParameterTypes();\r
279             stacks[i] = newVar("stack" + relation.getName(),\r
280                     Types.apply(MList, Types.tuple(parameterTypes))\r
281                     );\r
282         }\r
283 \r
284         // Simplify subexpressions and collect derivatives\r
285         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);\r
286         for(int i=0;i<relations.length;++i) {\r
287             LocalRelation relation = relations[i];\r
288             Type[] parameterTypes = relation.getParameterTypes();\r
289             Variable[] parameters = new Variable[parameterTypes.length];\r
290             for(int j=0;j<parameterTypes.length;++j)\r
291                 parameters[j] = new Variable("p" + j, parameterTypes[j]);\r
292             diffables.put(relations[i], new Diffable(i, relation, parameters));\r
293         }\r
294         @SuppressWarnings("unchecked")\r
295         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];\r
296         for(int i=0;i<relations.length;++i)\r
297             updateExpressions[i] = new ArrayList<Expression>(2);\r
298         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); \r
299         for(DatalogRule rule : rules) {\r
300             int id = diffables.get(rule.headRelation).id;\r
301             Expression appendExp = apply(context, Types.PROC, MList_add, Types.tuple(rule.headRelation.getParameterTypes()),\r
302                     var(stacks[id]),\r
303                     tuple(rule.headParameters)\r
304                     );\r
305             Diff[] diffs;\r
306             try {\r
307                 diffs = rule.body.derivate(diffables);\r
308             } catch(DerivateException e) {\r
309                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");\r
310                 return new EError();\r
311             }\r
312             for(Diff diff : diffs)\r
313                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));\r
314             if(diffs.length == 0)\r
315                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));\r
316             else {\r
317                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());\r
318                 if(query != Query.EMPTY_QUERY)\r
319                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));\r
320             }\r
321         }\r
322         \r
323         // Iterative solving of relations\r
324 \r
325         Variable[] loops = new Variable[relations.length];\r
326         for(int i=0;i<loops.length;++i)\r
327             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));\r
328         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);\r
329         \r
330         Expression[] loopDefs = new Expression[relations.length];\r
331         for(int i=0;i<relations.length;++i) {\r
332             LocalRelation relation = relations[i];\r
333             Type[] parameterTypes = relation.getParameterTypes();\r
334             Variable[] parameters = diffables.get(relation).parameters;\r
335             \r
336             Variable counter = newVar("counter", Types.INTEGER);\r
337             \r
338             Type rowType = Types.tuple(parameterTypes);\r
339             Variable row = newVar("row", rowType);\r
340             \r
341             Expression handleRow = tuple();\r
342             for(Expression updateExpression : updateExpressions[i])\r
343                 handleRow = seq(updateExpression, handleRow);\r
344             handleRow = if_(\r
345                     apply(context, Types.PROC, MSet_add, rowType,\r
346                             var(relation.table), var(row)),\r
347                     handleRow,\r
348                     tuple()\r
349                     );\r
350             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));\r
351             Expression failure =\r
352                     if_(isZeroInteger(var(counter)),\r
353                         tuple(),\r
354                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))\r
355                        );\r
356             Expression body = matchWithDefault(\r
357                     apply(context, Types.PROC, MList_removeLast, rowType, var(stacks[i])),\r
358                     Just(as(row, tuple(vars(parameters)))), handleRow,\r
359                     failure);\r
360             \r
361             loopDefs[i] = lambda(Types.PROC, counter, body); \r
362         }\r
363         continuation = letRec(loops, loopDefs, continuation);\r
364         \r
365         // Seed relations\r
366         for(Expression seedExpression : seedExpressions)\r
367             continuation = seq(seedExpression, continuation);\r
368         \r
369         // Create stacks\r
370         for(int i=0;i<stacks.length;++i)\r
371             continuation = let(stacks[i],\r
372                     apply(context, Types.PROC, MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),\r
373                     continuation);\r
374         \r
375         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);\r
376         \r
377         // Create relations\r
378         for(LocalRelation relation : relations)\r
379             continuation = let(relation.table,\r
380                     apply(context, Types.PROC, MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),\r
381                     continuation);\r
382         \r
383         return seq(continuation, in);\r
384     }\r
385 \r
386     @Override\r
387     protected void updateType() throws MatchException {\r
388         setType(in.getType());\r
389     }\r
390     \r
391     @Override\r
392     public void setLocationDeep(long loc) {\r
393         if(location == Locations.NO_LOCATION) {\r
394             location = loc;\r
395             for(DatalogRule rule : rules)\r
396                 rule.setLocationDeep(loc);\r
397         }\r
398     }\r
399     \r
400     @Override\r
401     public void accept(ExpressionVisitor visitor) {\r
402         visitor.visit(this);\r
403     }\r
404 \r
405     public DatalogRule[] getRules() {\r
406         return rules;\r
407     }\r
408     \r
409     public Expression getIn() {\r
410         return in;\r
411     }\r
412 \r
413     @Override\r
414     public void forVariables(VariableProcedure procedure) {\r
415         for(DatalogRule rule : rules)\r
416             rule.forVariables(procedure);\r
417         in.forVariables(procedure);\r
418     }\r
419     \r
420     @Override\r
421     public Expression accept(ExpressionTransformer transformer) {\r
422         return transformer.transform(this);\r
423     }\r
424 \r
425 }\r