]> gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/expressions/ERuleset.java
e784813d69fec37c9793ccae98a5ccd38fcc2756
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;
2
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;
19
20 import java.util.ArrayList;
21 import java.util.Set;
22
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
24 import org.simantics.scl.compiler.common.names.Names;
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
28 import org.simantics.scl.compiler.elaboration.query.Query;
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
34 import org.simantics.scl.compiler.errors.Locations;
35 import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;
36 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
37 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
38 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
39 import org.simantics.scl.compiler.types.Type;
40 import org.simantics.scl.compiler.types.Types;
41 import org.simantics.scl.compiler.types.exceptions.MatchException;
42 import org.simantics.scl.compiler.types.kinds.Kinds;
43
44 import gnu.trove.impl.Constants;
45 import gnu.trove.map.hash.THashMap;
46 import gnu.trove.map.hash.TObjectIntHashMap;
47 import gnu.trove.set.hash.THashSet;
48 import gnu.trove.set.hash.TIntHashSet;
49
50 public class ERuleset extends SimplifiableExpression {
51     LocalRelation[] relations;
52     DatalogRule[] rules;
53     Expression in;
54     
55     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
56         this.relations = relations;
57         this.rules = rules;
58         this.in = in;
59     }
60
61     public static class DatalogRule {
62         public long location;
63         public LocalRelation headRelation;
64         public Expression[] headParameters;
65         public Query body;
66         public Variable[] variables;
67         
68         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
69                 Query body) {
70             this.headRelation = headRelation;
71             this.headParameters = headParameters;
72             this.body = body;
73         }
74         
75         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
76                 Query body, Variable[] variables) {
77             this.location = location;
78             this.headRelation = headRelation;
79             this.headParameters = headParameters;
80             this.body = body;
81             this.variables = variables;
82         }
83
84         public void setLocationDeep(long loc) {
85             this.location = loc;
86             for(Expression parameter : headParameters)
87                 parameter.setLocationDeep(loc);
88             body.setLocationDeep(loc);
89         }
90         
91         @Override
92         public String toString() {
93             StringBuilder b = new StringBuilder();
94             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
95             visitor.visit(this);
96             return b.toString();
97         }
98
99         public void forVariables(VariableProcedure procedure) {
100             for(Expression headParameter : headParameters)
101                 headParameter.forVariables(procedure);
102             body.forVariables(procedure);
103         }
104     }
105     
106     private void checkRuleTypes(TypingContext context) {
107         // Create relation variables
108         for(DatalogRule rule : rules) {
109             Type[] parameterTypes =  rule.headRelation.getParameterTypes();
110             Expression[] parameters = rule.headParameters;
111             for(Variable variable : rule.variables)
112                 variable.setType(Types.metaVar(Kinds.STAR));
113             for(int i=0;i<parameters.length;++i)
114                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
115             rule.body.checkType(context);
116         }
117     }
118     
119     @Override
120     public Expression checkBasicType(TypingContext context, Type requiredType) {
121         checkRuleTypes(context);
122         in = in.checkBasicType(context, requiredType);
123         return compile(context);
124     }
125     
126     @Override
127     public Expression inferType(TypingContext context) {
128         checkRuleTypes(context);
129         in = in.inferType(context);
130         return compile(context);
131     }
132     
133     @Override
134     public Expression checkIgnoredType(TypingContext context) {
135         checkRuleTypes(context);
136         in = in.checkIgnoredType(context);
137         return compile(context);
138     }
139     
140     @Override
141     public void collectFreeVariables(THashSet<Variable> vars) {
142         for(DatalogRule rule : rules) {
143             for(Expression parameter : rule.headParameters)
144                 parameter.collectFreeVariables(vars);
145             rule.body.collectFreeVariables(vars);
146             for(Variable var : rule.variables)
147                 vars.remove(var);
148         }
149         in.collectFreeVariables(vars);
150     }
151     
152     @Override
153     public void collectRefs(TObjectIntHashMap<Object> allRefs,
154             TIntHashSet refs) {
155         for(DatalogRule rule : rules) {
156             for(Expression parameter : rule.headParameters)
157                 parameter.collectRefs(allRefs, refs);
158             rule.body.collectRefs(allRefs, refs);
159         }
160         in.collectRefs(allRefs, refs);
161     }
162     
163     @Override
164     public void collectVars(TObjectIntHashMap<Variable> allVars,
165             TIntHashSet vars) {
166         for(DatalogRule rule : rules) {
167             for(Expression parameter : rule.headParameters)
168                 parameter.collectVars(allVars, vars);
169             rule.body.collectVars(allVars, vars);
170         }
171         in.collectVars(allVars, vars);
172     }
173     
174     @Override
175     public void collectEffects(THashSet<Type> effects) {
176         throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");
177     }
178     
179     @Override
180     public Expression decorate(ExpressionDecorator decorator) {
181         return decorator.decorate(this);
182     }
183     
184     @Override
185     public Expression resolve(TranslationContext context) {
186         throw new InternalCompilerError();
187     }
188     
189     static class LocalRelationAux {
190         Variable handleFunc;
191     }
192     
193     public Expression compile(TypingContext context) {
194         // Create a map from relations to their ids
195         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
196                 Constants.DEFAULT_LOAD_FACTOR, -1);
197         for(int i=0;i<relations.length;++i)
198             relationsToIds.put(relations[i], i);
199         
200         // Create a table from relations to the other relations they depend on
201         TIntHashSet[] refsSets = new TIntHashSet[relations.length];
202         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
203         for(int i=0;i<relations.length;++i)
204             refsSets[i] = new TIntHashSet(setCapacity);
205         
206         for(DatalogRule rule : rules) {
207             int headRelationId = relationsToIds.get(rule.headRelation);
208             TIntHashSet refsSet = refsSets[headRelationId];
209             rule.body.collectRelationRefs(relationsToIds, refsSet);
210             for(Expression parameter : rule.headParameters)
211                 parameter.collectRelationRefs(relationsToIds, refsSet);
212         }
213         
214         // Convert refsSets to an array
215         final int[][] refs = new int[relations.length][];
216         for(int i=0;i<relations.length;++i)
217             refs[i] = refsSets[i].toArray();
218         
219         // Find strongly connected components of the function refs
220         final ArrayList<int[]> components = new ArrayList<int[]>();
221         
222         new StronglyConnectedComponents(relations.length) {
223             @Override
224             protected void reportComponent(int[] component) {
225                 components.add(component);
226             }
227             
228             @Override
229             protected int[] findDependencies(int u) {
230                 return refs[u];
231             }
232         }.findComponents();
233         
234         // If there is just one component, compile it
235         if(components.size() == 1) {
236             return compileStratified(context);
237         }
238         
239         // Inverse of components array 
240         int[] strataPerRelation = new int[relations.length];
241         for(int i=0;i<components.size();++i)
242             for(int k : components.get(i))
243                 strataPerRelation[k] = i;
244         
245         // Collects rules belonging to each strata
246         @SuppressWarnings("unchecked")
247         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
248         for(int i=0;i<components.size();++i)
249             rulesPerStrata[i] = new ArrayList<DatalogRule>();
250         for(DatalogRule rule : rules) {
251             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
252             rulesPerStrata[stratum].add(rule);
253         }
254         
255         // Create stratified system
256         Expression cur = this.in;
257         for(int stratum=components.size()-1;stratum >= 0;--stratum) {
258             int[] cs = components.get(stratum);
259             LocalRelation[] curRelations = new LocalRelation[cs.length];
260             for(int i=0;i<cs.length;++i)
261                 curRelations[i] = relations[cs[i]];
262             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
263             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
264         }
265         return cur;
266     }
267     
268     private Expression compileStratified(TypingContext context) {
269         Expression continuation = Expressions.tuple();
270         
271         // Create stacks
272         Variable[] stacks = new Variable[relations.length];
273         for(int i=0;i<relations.length;++i) {
274             LocalRelation relation = relations[i];
275             Type[] parameterTypes = relation.getParameterTypes();
276             stacks[i] = newVar("stack" + relation.getName(),
277                     Types.apply(Names.MList_T, Types.tuple(parameterTypes))
278                     );
279         }
280
281         // Simplify subexpressions and collect derivatives
282         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
283         for(int i=0;i<relations.length;++i) {
284             LocalRelation relation = relations[i];
285             Type[] parameterTypes = relation.getParameterTypes();
286             Variable[] parameters = new Variable[parameterTypes.length];
287             for(int j=0;j<parameterTypes.length;++j)
288                 parameters[j] = new Variable("p" + j, parameterTypes[j]);
289             diffables.put(relations[i], new Diffable(i, relation, parameters));
290         }
291         @SuppressWarnings("unchecked")
292         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
293         for(int i=0;i<relations.length;++i)
294             updateExpressions[i] = new ArrayList<Expression>(2);
295         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); 
296         for(DatalogRule rule : rules) {
297             int id = diffables.get(rule.headRelation).id;
298             Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
299                     var(stacks[id]),
300                     tuple(rule.headParameters)
301                     );
302             Diff[] diffs;
303             try {
304                 diffs = rule.body.derivate(diffables);
305             } catch(DerivateException e) {
306                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");
307                 return new EError();
308             }
309             for(Diff diff : diffs)
310                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));
311             if(diffs.length == 0)
312                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));
313             else {
314                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());
315                 if(query != Query.EMPTY_QUERY)
316                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));
317             }
318         }
319         
320         // Iterative solving of relations
321
322         Variable[] loops = new Variable[relations.length];
323         for(int i=0;i<loops.length;++i)
324             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
325         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
326         
327         Expression[] loopDefs = new Expression[relations.length];
328         for(int i=0;i<relations.length;++i) {
329             LocalRelation relation = relations[i];
330             Type[] parameterTypes = relation.getParameterTypes();
331             Variable[] parameters = diffables.get(relation).parameters;
332             
333             Variable counter = newVar("counter", Types.INTEGER);
334             
335             Type rowType = Types.tuple(parameterTypes);
336             Variable row = newVar("row", rowType);
337             
338             Expression handleRow = tuple();
339             for(Expression updateExpression : updateExpressions[i])
340                 handleRow = seq(updateExpression, handleRow);
341             handleRow = if_(
342                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,
343                             var(relation.table), var(row)),
344                     handleRow,
345                     tuple()
346                     );
347             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
348             Expression failure =
349                     if_(isZeroInteger(var(counter)),
350                         tuple(),
351                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
352                        );
353             Expression body = matchWithDefault(
354                     apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),
355                     Just(as(row, tuple(vars(parameters)))), handleRow,
356                     failure);
357             
358             loopDefs[i] = lambda(Types.PROC, counter, body); 
359         }
360         continuation = letRec(loops, loopDefs, continuation);
361         
362         // Seed relations
363         for(Expression seedExpression : seedExpressions)
364             continuation = seq(seedExpression, continuation);
365         
366         // Create stacks
367         for(int i=0;i<stacks.length;++i)
368             continuation = let(stacks[i],
369                     apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
370                     continuation);
371         
372         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
373         
374         // Create relations
375         for(LocalRelation relation : relations)
376             continuation = let(relation.table,
377                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
378                     continuation);
379         
380         return seq(continuation, in);
381     }
382
383     @Override
384     protected void updateType() throws MatchException {
385         setType(in.getType());
386     }
387     
388     @Override
389     public void setLocationDeep(long loc) {
390         if(location == Locations.NO_LOCATION) {
391             location = loc;
392             for(DatalogRule rule : rules)
393                 rule.setLocationDeep(loc);
394         }
395     }
396     
397     @Override
398     public void accept(ExpressionVisitor visitor) {
399         visitor.visit(this);
400     }
401
402     public DatalogRule[] getRules() {
403         return rules;
404     }
405     
406     public Expression getIn() {
407         return in;
408     }
409
410     @Override
411     public void forVariables(VariableProcedure procedure) {
412         for(DatalogRule rule : rules)
413             rule.forVariables(procedure);
414         in.forVariables(procedure);
415     }
416     
417     @Override
418     public Expression accept(ExpressionTransformer transformer) {
419         return transformer.transform(this);
420     }
421
422 }