]> gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/expressions/ERuleset.java
Merge "Remove unused import in DeleteHandler"
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;\r
2 \r
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;\r
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;\r
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;\r
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;\r
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;\r
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;\r
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;\r
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;\r
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;\r
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;\r
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;\r
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;\r
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;\r
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;\r
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;\r
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;\r
19 \r
20 import java.util.ArrayList;\r
21 import java.util.Set;\r
22 \r
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;\r
24 import org.simantics.scl.compiler.common.names.Names;\r
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;\r
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;\r
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;\r
28 import org.simantics.scl.compiler.elaboration.query.Query;\r
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;\r
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;\r
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;\r
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;\r
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;\r
34 import org.simantics.scl.compiler.errors.Locations;\r
35 import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;\r
36 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;\r
37 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;\r
38 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;\r
39 import org.simantics.scl.compiler.types.Type;\r
40 import org.simantics.scl.compiler.types.Types;\r
41 import org.simantics.scl.compiler.types.exceptions.MatchException;\r
42 import org.simantics.scl.compiler.types.kinds.Kinds;\r
43 \r
44 import gnu.trove.impl.Constants;\r
45 import gnu.trove.map.hash.THashMap;\r
46 import gnu.trove.map.hash.TObjectIntHashMap;\r
47 import gnu.trove.set.hash.THashSet;\r
48 import gnu.trove.set.hash.TIntHashSet;\r
49 \r
50 public class ERuleset extends SimplifiableExpression {\r
51     LocalRelation[] relations;\r
52     DatalogRule[] rules;\r
53     Expression in;\r
54     \r
55     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {\r
56         this.relations = relations;\r
57         this.rules = rules;\r
58         this.in = in;\r
59     }\r
60 \r
61     public static class DatalogRule {\r
62         public long location;\r
63         public LocalRelation headRelation;\r
64         public Expression[] headParameters;\r
65         public Query body;\r
66         public Variable[] variables;\r
67         \r
68         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,\r
69                 Query body) {\r
70             this.headRelation = headRelation;\r
71             this.headParameters = headParameters;\r
72             this.body = body;\r
73         }\r
74         \r
75         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,\r
76                 Query body, Variable[] variables) {\r
77             this.location = location;\r
78             this.headRelation = headRelation;\r
79             this.headParameters = headParameters;\r
80             this.body = body;\r
81             this.variables = variables;\r
82         }\r
83 \r
84         public void setLocationDeep(long loc) {\r
85             this.location = loc;\r
86             for(Expression parameter : headParameters)\r
87                 parameter.setLocationDeep(loc);\r
88             body.setLocationDeep(loc);\r
89         }\r
90         \r
91         @Override\r
92         public String toString() {\r
93             StringBuilder b = new StringBuilder();\r
94             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);\r
95             visitor.visit(this);\r
96             return b.toString();\r
97         }\r
98 \r
99         public void forVariables(VariableProcedure procedure) {\r
100             for(Expression headParameter : headParameters)\r
101                 headParameter.forVariables(procedure);\r
102             body.forVariables(procedure);\r
103         }\r
104     }\r
105     \r
106     private void checkRuleTypes(TypingContext context) {\r
107         // Create relation variables\r
108         for(DatalogRule rule : rules) {\r
109             Type[] parameterTypes =  rule.headRelation.getParameterTypes();\r
110             Expression[] parameters = rule.headParameters;\r
111             for(Variable variable : rule.variables)\r
112                 variable.setType(Types.metaVar(Kinds.STAR));\r
113             for(int i=0;i<parameters.length;++i)\r
114                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);\r
115             rule.body.checkType(context);\r
116         }\r
117     }\r
118     \r
119     @Override\r
120     public Expression checkBasicType(TypingContext context, Type requiredType) {\r
121         checkRuleTypes(context);\r
122         in = in.checkBasicType(context, requiredType);\r
123         return compile(context);\r
124     }\r
125     \r
126     @Override\r
127     public Expression inferType(TypingContext context) {\r
128         checkRuleTypes(context);\r
129         in = in.inferType(context);\r
130         return compile(context);\r
131     }\r
132     \r
133     @Override\r
134     public Expression checkIgnoredType(TypingContext context) {\r
135         checkRuleTypes(context);\r
136         in = in.checkIgnoredType(context);\r
137         return compile(context);\r
138     }\r
139     \r
140     @Override\r
141     public void collectFreeVariables(THashSet<Variable> vars) {\r
142         for(DatalogRule rule : rules) {\r
143             for(Expression parameter : rule.headParameters)\r
144                 parameter.collectFreeVariables(vars);\r
145             rule.body.collectFreeVariables(vars);\r
146             for(Variable var : rule.variables)\r
147                 vars.remove(var);\r
148         }\r
149         in.collectFreeVariables(vars);\r
150     }\r
151     \r
152     @Override\r
153     public void collectRefs(TObjectIntHashMap<Object> allRefs,\r
154             TIntHashSet refs) {\r
155         for(DatalogRule rule : rules) {\r
156             for(Expression parameter : rule.headParameters)\r
157                 parameter.collectRefs(allRefs, refs);\r
158             rule.body.collectRefs(allRefs, refs);\r
159         }\r
160         in.collectRefs(allRefs, refs);\r
161     }\r
162     \r
163     @Override\r
164     public void collectVars(TObjectIntHashMap<Variable> allVars,\r
165             TIntHashSet vars) {\r
166         for(DatalogRule rule : rules) {\r
167             for(Expression parameter : rule.headParameters)\r
168                 parameter.collectVars(allVars, vars);\r
169             rule.body.collectVars(allVars, vars);\r
170         }\r
171         in.collectVars(allVars, vars);\r
172     }\r
173     \r
174     @Override\r
175     public void collectEffects(THashSet<Type> effects) {\r
176         throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");\r
177     }\r
178     \r
179     @Override\r
180     public Expression decorate(ExpressionDecorator decorator) {\r
181         return decorator.decorate(this);\r
182     }\r
183     \r
184     @Override\r
185     public Expression resolve(TranslationContext context) {\r
186         throw new InternalCompilerError();\r
187     }\r
188     \r
189     static class LocalRelationAux {\r
190         Variable handleFunc;\r
191     }\r
192     \r
193     public Expression compile(TypingContext context) {\r
194         // Create a map from relations to their ids\r
195         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,\r
196                 Constants.DEFAULT_LOAD_FACTOR, -1);\r
197         for(int i=0;i<relations.length;++i)\r
198             relationsToIds.put(relations[i], i);\r
199         \r
200         // Create a table from relations to the other relations they depend on\r
201         TIntHashSet[] refsSets = new TIntHashSet[relations.length];\r
202         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);\r
203         for(int i=0;i<relations.length;++i)\r
204             refsSets[i] = new TIntHashSet(setCapacity);\r
205         \r
206         for(DatalogRule rule : rules) {\r
207             int headRelationId = relationsToIds.get(rule.headRelation);\r
208             TIntHashSet refsSet = refsSets[headRelationId];\r
209             rule.body.collectRelationRefs(relationsToIds, refsSet);\r
210             for(Expression parameter : rule.headParameters)\r
211                 parameter.collectRelationRefs(relationsToIds, refsSet);\r
212         }\r
213         \r
214         // Convert refsSets to an array\r
215         final int[][] refs = new int[relations.length][];\r
216         for(int i=0;i<relations.length;++i)\r
217             refs[i] = refsSets[i].toArray();\r
218         \r
219         // Find strongly connected components of the function refs\r
220         final ArrayList<int[]> components = new ArrayList<int[]>();\r
221         \r
222         new StronglyConnectedComponents(relations.length) {\r
223             @Override\r
224             protected void reportComponent(int[] component) {\r
225                 components.add(component);\r
226             }\r
227             \r
228             @Override\r
229             protected int[] findDependencies(int u) {\r
230                 return refs[u];\r
231             }\r
232         }.findComponents();\r
233         \r
234         // If there is just one component, compile it\r
235         if(components.size() == 1) {\r
236             return compileStratified(context);\r
237         }\r
238         \r
239         // Inverse of components array \r
240         int[] strataPerRelation = new int[relations.length];\r
241         for(int i=0;i<components.size();++i)\r
242             for(int k : components.get(i))\r
243                 strataPerRelation[k] = i;\r
244         \r
245         // Collects rules belonging to each strata\r
246         @SuppressWarnings("unchecked")\r
247         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];\r
248         for(int i=0;i<components.size();++i)\r
249             rulesPerStrata[i] = new ArrayList<DatalogRule>();\r
250         for(DatalogRule rule : rules) {\r
251             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];\r
252             rulesPerStrata[stratum].add(rule);\r
253         }\r
254         \r
255         // Create stratified system\r
256         Expression cur = this.in;\r
257         for(int stratum=components.size()-1;stratum >= 0;--stratum) {\r
258             int[] cs = components.get(stratum);\r
259             LocalRelation[] curRelations = new LocalRelation[cs.length];\r
260             for(int i=0;i<cs.length;++i)\r
261                 curRelations[i] = relations[cs[i]];\r
262             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];\r
263             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);\r
264         }\r
265         return cur;\r
266     }\r
267     \r
268     private Expression compileStratified(TypingContext context) {\r
269         Expression continuation = Expressions.tuple();\r
270         \r
271         // Create stacks\r
272         Variable[] stacks = new Variable[relations.length];\r
273         for(int i=0;i<relations.length;++i) {\r
274             LocalRelation relation = relations[i];\r
275             Type[] parameterTypes = relation.getParameterTypes();\r
276             stacks[i] = newVar("stack" + relation.getName(),\r
277                     Types.apply(Names.MList_T, Types.tuple(parameterTypes))\r
278                     );\r
279         }\r
280 \r
281         // Simplify subexpressions and collect derivatives\r
282         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);\r
283         for(int i=0;i<relations.length;++i) {\r
284             LocalRelation relation = relations[i];\r
285             Type[] parameterTypes = relation.getParameterTypes();\r
286             Variable[] parameters = new Variable[parameterTypes.length];\r
287             for(int j=0;j<parameterTypes.length;++j)\r
288                 parameters[j] = new Variable("p" + j, parameterTypes[j]);\r
289             diffables.put(relations[i], new Diffable(i, relation, parameters));\r
290         }\r
291         @SuppressWarnings("unchecked")\r
292         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];\r
293         for(int i=0;i<relations.length;++i)\r
294             updateExpressions[i] = new ArrayList<Expression>(2);\r
295         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); \r
296         for(DatalogRule rule : rules) {\r
297             int id = diffables.get(rule.headRelation).id;\r
298             Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),\r
299                     var(stacks[id]),\r
300                     tuple(rule.headParameters)\r
301                     );\r
302             Diff[] diffs;\r
303             try {\r
304                 diffs = rule.body.derivate(diffables);\r
305             } catch(DerivateException e) {\r
306                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");\r
307                 return new EError();\r
308             }\r
309             for(Diff diff : diffs)\r
310                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));\r
311             if(diffs.length == 0)\r
312                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));\r
313             else {\r
314                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());\r
315                 if(query != Query.EMPTY_QUERY)\r
316                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));\r
317             }\r
318         }\r
319         \r
320         // Iterative solving of relations\r
321 \r
322         Variable[] loops = new Variable[relations.length];\r
323         for(int i=0;i<loops.length;++i)\r
324             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));\r
325         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);\r
326         \r
327         Expression[] loopDefs = new Expression[relations.length];\r
328         for(int i=0;i<relations.length;++i) {\r
329             LocalRelation relation = relations[i];\r
330             Type[] parameterTypes = relation.getParameterTypes();\r
331             Variable[] parameters = diffables.get(relation).parameters;\r
332             \r
333             Variable counter = newVar("counter", Types.INTEGER);\r
334             \r
335             Type rowType = Types.tuple(parameterTypes);\r
336             Variable row = newVar("row", rowType);\r
337             \r
338             Expression handleRow = tuple();\r
339             for(Expression updateExpression : updateExpressions[i])\r
340                 handleRow = seq(updateExpression, handleRow);\r
341             handleRow = if_(\r
342                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,\r
343                             var(relation.table), var(row)),\r
344                     handleRow,\r
345                     tuple()\r
346                     );\r
347             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));\r
348             Expression failure =\r
349                     if_(isZeroInteger(var(counter)),\r
350                         tuple(),\r
351                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))\r
352                        );\r
353             Expression body = matchWithDefault(\r
354                     apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),\r
355                     Just(as(row, tuple(vars(parameters)))), handleRow,\r
356                     failure);\r
357             \r
358             loopDefs[i] = lambda(Types.PROC, counter, body); \r
359         }\r
360         continuation = letRec(loops, loopDefs, continuation);\r
361         \r
362         // Seed relations\r
363         for(Expression seedExpression : seedExpressions)\r
364             continuation = seq(seedExpression, continuation);\r
365         \r
366         // Create stacks\r
367         for(int i=0;i<stacks.length;++i)\r
368             continuation = let(stacks[i],\r
369                     apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),\r
370                     continuation);\r
371         \r
372         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);\r
373         \r
374         // Create relations\r
375         for(LocalRelation relation : relations)\r
376             continuation = let(relation.table,\r
377                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),\r
378                     continuation);\r
379         \r
380         return seq(continuation, in);\r
381     }\r
382 \r
383     @Override\r
384     protected void updateType() throws MatchException {\r
385         setType(in.getType());\r
386     }\r
387     \r
388     @Override\r
389     public void setLocationDeep(long loc) {\r
390         if(location == Locations.NO_LOCATION) {\r
391             location = loc;\r
392             for(DatalogRule rule : rules)\r
393                 rule.setLocationDeep(loc);\r
394         }\r
395     }\r
396     \r
397     @Override\r
398     public void accept(ExpressionVisitor visitor) {\r
399         visitor.visit(this);\r
400     }\r
401 \r
402     public DatalogRule[] getRules() {\r
403         return rules;\r
404     }\r
405     \r
406     public Expression getIn() {\r
407         return in;\r
408     }\r
409 \r
410     @Override\r
411     public void forVariables(VariableProcedure procedure) {\r
412         for(DatalogRule rule : rules)\r
413             rule.forVariables(procedure);\r
414         in.forVariables(procedure);\r
415     }\r
416     \r
417     @Override\r
418     public Expression accept(ExpressionTransformer transformer) {\r
419         return transformer.transform(this);\r
420     }\r
421 \r
422 }\r