gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/expressions/ERuleset.java
(refs #7375) Replaced ExpressionDecorator by ExpressionTransformer
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;
2
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;
19
20 import java.util.ArrayList;
21 import java.util.Set;
22
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
24 import org.simantics.scl.compiler.common.names.Names;
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
28 import org.simantics.scl.compiler.elaboration.query.Query;
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
34 import org.simantics.scl.compiler.errors.Locations;
35 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
36 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
37 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
38 import org.simantics.scl.compiler.types.Type;
39 import org.simantics.scl.compiler.types.Types;
40 import org.simantics.scl.compiler.types.exceptions.MatchException;
41 import org.simantics.scl.compiler.types.kinds.Kinds;
42
43 import gnu.trove.impl.Constants;
44 import gnu.trove.map.hash.THashMap;
45 import gnu.trove.map.hash.TObjectIntHashMap;
46 import gnu.trove.set.hash.THashSet;
47 import gnu.trove.set.hash.TIntHashSet;
48
49 public class ERuleset extends SimplifiableExpression {
50     LocalRelation[] relations;
51     DatalogRule[] rules;
52     Expression in;
53     
54     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
55         this.relations = relations;
56         this.rules = rules;
57         this.in = in;
58     }
59
60     public static class DatalogRule {
61         public long location;
62         public LocalRelation headRelation;
63         public Expression[] headParameters;
64         public Query body;
65         public Variable[] variables;
66         
67         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
68                 Query body) {
69             this.headRelation = headRelation;
70             this.headParameters = headParameters;
71             this.body = body;
72         }
73         
74         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
75                 Query body, Variable[] variables) {
76             this.location = location;
77             this.headRelation = headRelation;
78             this.headParameters = headParameters;
79             this.body = body;
80             this.variables = variables;
81         }
82
83         public void setLocationDeep(long loc) {
84             this.location = loc;
85             for(Expression parameter : headParameters)
86                 parameter.setLocationDeep(loc);
87             body.setLocationDeep(loc);
88         }
89         
90         @Override
91         public String toString() {
92             StringBuilder b = new StringBuilder();
93             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
94             visitor.visit(this);
95             return b.toString();
96         }
97
98         public void forVariables(VariableProcedure procedure) {
99             for(Expression headParameter : headParameters)
100                 headParameter.forVariables(procedure);
101             body.forVariables(procedure);
102         }
103     }
104     
105     private void checkRuleTypes(TypingContext context) {
106         // Create relation variables
107         for(DatalogRule rule : rules) {
108             Type[] parameterTypes =  rule.headRelation.getParameterTypes();
109             Expression[] parameters = rule.headParameters;
110             for(Variable variable : rule.variables)
111                 variable.setType(Types.metaVar(Kinds.STAR));
112             for(int i=0;i<parameters.length;++i)
113                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
114             rule.body.checkType(context);
115         }
116     }
117     
118     @Override
119     public Expression checkBasicType(TypingContext context, Type requiredType) {
120         checkRuleTypes(context);
121         in = in.checkBasicType(context, requiredType);
122         return compile(context);
123     }
124     
125     @Override
126     public Expression inferType(TypingContext context) {
127         checkRuleTypes(context);
128         in = in.inferType(context);
129         return compile(context);
130     }
131     
132     @Override
133     public Expression checkIgnoredType(TypingContext context) {
134         checkRuleTypes(context);
135         in = in.checkIgnoredType(context);
136         return compile(context);
137     }
138     
139     @Override
140     public void collectFreeVariables(THashSet<Variable> vars) {
141         for(DatalogRule rule : rules) {
142             for(Expression parameter : rule.headParameters)
143                 parameter.collectFreeVariables(vars);
144             rule.body.collectFreeVariables(vars);
145             for(Variable var : rule.variables)
146                 vars.remove(var);
147         }
148         in.collectFreeVariables(vars);
149     }
150     
151     @Override
152     public void collectRefs(TObjectIntHashMap<Object> allRefs,
153             TIntHashSet refs) {
154         for(DatalogRule rule : rules) {
155             for(Expression parameter : rule.headParameters)
156                 parameter.collectRefs(allRefs, refs);
157             rule.body.collectRefs(allRefs, refs);
158         }
159         in.collectRefs(allRefs, refs);
160     }
161     
162     @Override
163     public void collectVars(TObjectIntHashMap<Variable> allVars,
164             TIntHashSet vars) {
165         for(DatalogRule rule : rules) {
166             for(Expression parameter : rule.headParameters)
167                 parameter.collectVars(allVars, vars);
168             rule.body.collectVars(allVars, vars);
169         }
170         in.collectVars(allVars, vars);
171     }
172     
173     @Override
174     public void collectEffects(THashSet<Type> effects) {
175         throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");
176     }
177     
178     @Override
179     public Expression resolve(TranslationContext context) {
180         throw new InternalCompilerError();
181     }
182     
183     static class LocalRelationAux {
184         Variable handleFunc;
185     }
186     
187     public Expression compile(TypingContext context) {
188         // Create a map from relations to their ids
189         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
190                 Constants.DEFAULT_LOAD_FACTOR, -1);
191         for(int i=0;i<relations.length;++i)
192             relationsToIds.put(relations[i], i);
193         
194         // Create a table from relations to the other relations they depend on
195         TIntHashSet[] refsSets = new TIntHashSet[relations.length];
196         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
197         for(int i=0;i<relations.length;++i)
198             refsSets[i] = new TIntHashSet(setCapacity);
199         
200         for(DatalogRule rule : rules) {
201             int headRelationId = relationsToIds.get(rule.headRelation);
202             TIntHashSet refsSet = refsSets[headRelationId];
203             rule.body.collectRelationRefs(relationsToIds, refsSet);
204             for(Expression parameter : rule.headParameters)
205                 parameter.collectRelationRefs(relationsToIds, refsSet);
206         }
207         
208         // Convert refsSets to an array
209         final int[][] refs = new int[relations.length][];
210         for(int i=0;i<relations.length;++i)
211             refs[i] = refsSets[i].toArray();
212         
213         // Find strongly connected components of the function refs
214         final ArrayList<int[]> components = new ArrayList<int[]>();
215         
216         new StronglyConnectedComponents(relations.length) {
217             @Override
218             protected void reportComponent(int[] component) {
219                 components.add(component);
220             }
221             
222             @Override
223             protected int[] findDependencies(int u) {
224                 return refs[u];
225             }
226         }.findComponents();
227         
228         // If there is just one component, compile it
229         if(components.size() == 1) {
230             return compileStratified(context);
231         }
232         
233         // Inverse of components array 
234         int[] strataPerRelation = new int[relations.length];
235         for(int i=0;i<components.size();++i)
236             for(int k : components.get(i))
237                 strataPerRelation[k] = i;
238         
239         // Collects rules belonging to each strata
240         @SuppressWarnings("unchecked")
241         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
242         for(int i=0;i<components.size();++i)
243             rulesPerStrata[i] = new ArrayList<DatalogRule>();
244         for(DatalogRule rule : rules) {
245             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
246             rulesPerStrata[stratum].add(rule);
247         }
248         
249         // Create stratified system
250         Expression cur = this.in;
251         for(int stratum=components.size()-1;stratum >= 0;--stratum) {
252             int[] cs = components.get(stratum);
253             LocalRelation[] curRelations = new LocalRelation[cs.length];
254             for(int i=0;i<cs.length;++i)
255                 curRelations[i] = relations[cs[i]];
256             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
257             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
258         }
259         return cur;
260     }
261     
262     private Expression compileStratified(TypingContext context) {
263         Expression continuation = Expressions.tuple();
264         
265         // Create stacks
266         Variable[] stacks = new Variable[relations.length];
267         for(int i=0;i<relations.length;++i) {
268             LocalRelation relation = relations[i];
269             Type[] parameterTypes = relation.getParameterTypes();
270             stacks[i] = newVar("stack" + relation.getName(),
271                     Types.apply(Names.MList_T, Types.tuple(parameterTypes))
272                     );
273         }
274
275         // Simplify subexpressions and collect derivatives
276         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
277         for(int i=0;i<relations.length;++i) {
278             LocalRelation relation = relations[i];
279             Type[] parameterTypes = relation.getParameterTypes();
280             Variable[] parameters = new Variable[parameterTypes.length];
281             for(int j=0;j<parameterTypes.length;++j)
282                 parameters[j] = new Variable("p" + j, parameterTypes[j]);
283             diffables.put(relations[i], new Diffable(i, relation, parameters));
284         }
285         @SuppressWarnings("unchecked")
286         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
287         for(int i=0;i<relations.length;++i)
288             updateExpressions[i] = new ArrayList<Expression>(2);
289         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); 
290         for(DatalogRule rule : rules) {
291             int id = diffables.get(rule.headRelation).id;
292             Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
293                     var(stacks[id]),
294                     tuple(rule.headParameters)
295                     );
296             Diff[] diffs;
297             try {
298                 diffs = rule.body.derivate(diffables);
299             } catch(DerivateException e) {
300                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");
301                 return new EError();
302             }
303             for(Diff diff : diffs)
304                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));
305             if(diffs.length == 0)
306                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));
307             else {
308                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());
309                 if(query != Query.EMPTY_QUERY)
310                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));
311             }
312         }
313         
314         // Iterative solving of relations
315
316         Variable[] loops = new Variable[relations.length];
317         for(int i=0;i<loops.length;++i)
318             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
319         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
320         
321         Expression[] loopDefs = new Expression[relations.length];
322         for(int i=0;i<relations.length;++i) {
323             LocalRelation relation = relations[i];
324             Type[] parameterTypes = relation.getParameterTypes();
325             Variable[] parameters = diffables.get(relation).parameters;
326             
327             Variable counter = newVar("counter", Types.INTEGER);
328             
329             Type rowType = Types.tuple(parameterTypes);
330             Variable row = newVar("row", rowType);
331             
332             Expression handleRow = tuple();
333             for(Expression updateExpression : updateExpressions[i])
334                 handleRow = seq(updateExpression, handleRow);
335             handleRow = if_(
336                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,
337                             var(relation.table), var(row)),
338                     handleRow,
339                     tuple()
340                     );
341             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
342             Expression failure =
343                     if_(isZeroInteger(var(counter)),
344                         tuple(),
345                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
346                        );
347             Expression body = matchWithDefault(
348                     apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),
349                     Just(as(row, tuple(vars(parameters)))), handleRow,
350                     failure);
351             
352             loopDefs[i] = lambda(Types.PROC, counter, body); 
353         }
354         continuation = letRec(loops, loopDefs, continuation);
355         
356         // Seed relations
357         for(Expression seedExpression : seedExpressions)
358             continuation = seq(seedExpression, continuation);
359         
360         // Create stacks
361         for(int i=0;i<stacks.length;++i)
362             continuation = let(stacks[i],
363                     apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
364                     continuation);
365         
366         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
367         
368         // Create relations
369         for(LocalRelation relation : relations)
370             continuation = let(relation.table,
371                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
372                     continuation);
373         
374         return seq(continuation, in);
375     }
376
377     @Override
378     protected void updateType() throws MatchException {
379         setType(in.getType());
380     }
381     
382     @Override
383     public void setLocationDeep(long loc) {
384         if(location == Locations.NO_LOCATION) {
385             location = loc;
386             for(DatalogRule rule : rules)
387                 rule.setLocationDeep(loc);
388         }
389     }
390     
391     @Override
392     public void accept(ExpressionVisitor visitor) {
393         visitor.visit(this);
394     }
395
396     public DatalogRule[] getRules() {
397         return rules;
398     }
399     
400     public Expression getIn() {
401         return in;
402     }
403
404     @Override
405     public void forVariables(VariableProcedure procedure) {
406         for(DatalogRule rule : rules)
407             rule.forVariables(procedure);
408         in.forVariables(procedure);
409     }
410     
411     @Override
412     public Expression accept(ExpressionTransformer transformer) {
413         return transformer.transform(this);
414     }
415
416 }