gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.scl.compiler/src/org/simantics/scl/compiler/elaboration/expressions/ERuleset.java
Fixed multiple issues causing dangling references to discarded queries
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;
2
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;
19
20 import java.util.ArrayList;
21 import java.util.Set;
22
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
24 import org.simantics.scl.compiler.common.names.Names;
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
28 import org.simantics.scl.compiler.elaboration.query.Query;
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
34 import org.simantics.scl.compiler.errors.Locations;
35 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
36 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
37 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
38 import org.simantics.scl.compiler.types.Type;
39 import org.simantics.scl.compiler.types.Types;
40 import org.simantics.scl.compiler.types.exceptions.MatchException;
41 import org.simantics.scl.compiler.types.kinds.Kinds;
42
43 import gnu.trove.impl.Constants;
44 import gnu.trove.map.hash.THashMap;
45 import gnu.trove.map.hash.TObjectIntHashMap;
46 import gnu.trove.set.hash.TIntHashSet;
47
48 public class ERuleset extends SimplifiableExpression {
49     LocalRelation[] relations;
50     public DatalogRule[] rules;
51     public Expression in;
52     
53     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
54         this.relations = relations;
55         this.rules = rules;
56         this.in = in;
57     }
58
59     public static class DatalogRule {
60         public long location;
61         public LocalRelation headRelation;
62         public Expression[] headParameters;
63         public Query body;
64         public Variable[] variables;
65         
66         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
67                 Query body) {
68             this.headRelation = headRelation;
69             this.headParameters = headParameters;
70             this.body = body;
71         }
72         
73         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
74                 Query body, Variable[] variables) {
75             this.location = location;
76             this.headRelation = headRelation;
77             this.headParameters = headParameters;
78             this.body = body;
79             this.variables = variables;
80         }
81
82         public void setLocationDeep(long loc) {
83             this.location = loc;
84             for(Expression parameter : headParameters)
85                 parameter.setLocationDeep(loc);
86             body.setLocationDeep(loc);
87         }
88         
89         @Override
90         public String toString() {
91             StringBuilder b = new StringBuilder();
92             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
93             visitor.visit(this);
94             return b.toString();
95         }
96     }
97     
98     private void checkRuleTypes(TypingContext context) {
99         // Create relation variables
100         for(DatalogRule rule : rules) {
101             Type[] parameterTypes =  rule.headRelation.getParameterTypes();
102             Expression[] parameters = rule.headParameters;
103             for(Variable variable : rule.variables)
104                 variable.setType(Types.metaVar(Kinds.STAR));
105             for(int i=0;i<parameters.length;++i)
106                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
107             rule.body.checkType(context);
108         }
109     }
110     
111     @Override
112     public Expression checkBasicType(TypingContext context, Type requiredType) {
113         checkRuleTypes(context);
114         in = in.checkBasicType(context, requiredType);
115         return compile(context);
116     }
117     
118     @Override
119     public Expression inferType(TypingContext context) {
120         checkRuleTypes(context);
121         in = in.inferType(context);
122         return compile(context);
123     }
124     
125     @Override
126     public Expression checkIgnoredType(TypingContext context) {
127         checkRuleTypes(context);
128         in = in.checkIgnoredType(context);
129         return compile(context);
130     }
131     
132     @Override
133     public Expression resolve(TranslationContext context) {
134         throw new InternalCompilerError();
135     }
136     
137     static class LocalRelationAux {
138         Variable handleFunc;
139     }
140     
141     public Expression compile(TypingContext context) {
142         // Create a map from relations to their ids
143         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
144                 Constants.DEFAULT_LOAD_FACTOR, -1);
145         for(int i=0;i<relations.length;++i)
146             relationsToIds.put(relations[i], i);
147         
148         // Create a table from relations to the other relations they depend on
149         TIntHashSet[] refsSets = new TIntHashSet[relations.length];
150         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
151         for(int i=0;i<relations.length;++i)
152             refsSets[i] = new TIntHashSet(setCapacity);
153         
154         for(DatalogRule rule : rules) {
155             int headRelationId = relationsToIds.get(rule.headRelation);
156             TIntHashSet refsSet = refsSets[headRelationId];
157             rule.body.collectRelationRefs(relationsToIds, refsSet);
158             for(Expression parameter : rule.headParameters)
159                 parameter.collectRelationRefs(relationsToIds, refsSet);
160         }
161         
162         // Convert refsSets to an array
163         final int[][] refs = new int[relations.length][];
164         for(int i=0;i<relations.length;++i)
165             refs[i] = refsSets[i].toArray();
166         
167         // Find strongly connected components of the function refs
168         final ArrayList<int[]> components = new ArrayList<int[]>();
169         
170         new StronglyConnectedComponents(relations.length) {
171             @Override
172             protected void reportComponent(int[] component) {
173                 components.add(component);
174             }
175             
176             @Override
177             protected int[] findDependencies(int u) {
178                 return refs[u];
179             }
180         }.findComponents();
181         
182         // If there is just one component, compile it
183         if(components.size() == 1) {
184             return compileStratified(context);
185         }
186         
187         // Inverse of components array 
188         int[] strataPerRelation = new int[relations.length];
189         for(int i=0;i<components.size();++i)
190             for(int k : components.get(i))
191                 strataPerRelation[k] = i;
192         
193         // Collects rules belonging to each strata
194         @SuppressWarnings("unchecked")
195         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
196         for(int i=0;i<components.size();++i)
197             rulesPerStrata[i] = new ArrayList<DatalogRule>();
198         for(DatalogRule rule : rules) {
199             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
200             rulesPerStrata[stratum].add(rule);
201         }
202         
203         // Create stratified system
204         Expression cur = this.in;
205         for(int stratum=components.size()-1;stratum >= 0;--stratum) {
206             int[] cs = components.get(stratum);
207             LocalRelation[] curRelations = new LocalRelation[cs.length];
208             for(int i=0;i<cs.length;++i)
209                 curRelations[i] = relations[cs[i]];
210             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
211             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
212         }
213         return cur;
214     }
215     
216     private Expression compileStratified(TypingContext context) {
217         Expression continuation = Expressions.tuple();
218         
219         // Create stacks
220         Variable[] stacks = new Variable[relations.length];
221         for(int i=0;i<relations.length;++i) {
222             LocalRelation relation = relations[i];
223             Type[] parameterTypes = relation.getParameterTypes();
224             stacks[i] = newVar("stack" + relation.getName(),
225                     Types.apply(Names.MList_T, Types.tuple(parameterTypes))
226                     );
227         }
228
229         // Simplify subexpressions and collect derivatives
230         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
231         for(int i=0;i<relations.length;++i) {
232             LocalRelation relation = relations[i];
233             Type[] parameterTypes = relation.getParameterTypes();
234             Variable[] parameters = new Variable[parameterTypes.length];
235             for(int j=0;j<parameterTypes.length;++j)
236                 parameters[j] = new Variable("p" + j, parameterTypes[j]);
237             diffables.put(relations[i], new Diffable(i, relation, parameters));
238         }
239         @SuppressWarnings("unchecked")
240         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
241         for(int i=0;i<relations.length;++i)
242             updateExpressions[i] = new ArrayList<Expression>(2);
243         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); 
244         for(DatalogRule rule : rules) {
245             int id = diffables.get(rule.headRelation).id;
246             Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
247                     var(stacks[id]),
248                     tuple(rule.headParameters)
249                     );
250             Diff[] diffs;
251             try {
252                 diffs = rule.body.derivate(diffables);
253             } catch(DerivateException e) {
254                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");
255                 return new EError();
256             }
257             for(Diff diff : diffs)
258                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));
259             if(diffs.length == 0)
260                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));
261             else {
262                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());
263                 if(query != Query.EMPTY_QUERY)
264                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));
265             }
266         }
267         
268         // Iterative solving of relations
269
270         Variable[] loops = new Variable[relations.length];
271         for(int i=0;i<loops.length;++i)
272             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
273         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
274         
275         Expression[] loopDefs = new Expression[relations.length];
276         for(int i=0;i<relations.length;++i) {
277             LocalRelation relation = relations[i];
278             Type[] parameterTypes = relation.getParameterTypes();
279             Variable[] parameters = diffables.get(relation).parameters;
280             
281             Variable counter = newVar("counter", Types.INTEGER);
282             
283             Type rowType = Types.tuple(parameterTypes);
284             Variable row = newVar("row", rowType);
285             
286             Expression handleRow = tuple();
287             for(Expression updateExpression : updateExpressions[i])
288                 handleRow = seq(updateExpression, handleRow);
289             handleRow = if_(
290                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,
291                             var(relation.table), var(row)),
292                     handleRow,
293                     tuple()
294                     );
295             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
296             Expression failure =
297                     if_(isZeroInteger(var(counter)),
298                         tuple(),
299                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
300                        );
301             Expression body = matchWithDefault(
302                     apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),
303                     Just(as(row, tuple(vars(parameters)))), handleRow,
304                     failure);
305             
306             loopDefs[i] = lambda(Types.PROC, counter, body); 
307         }
308         continuation = letRec(loops, loopDefs, continuation);
309         
310         // Seed relations
311         for(Expression seedExpression : seedExpressions)
312             continuation = seq(seedExpression, continuation);
313         
314         // Create stacks
315         for(int i=0;i<stacks.length;++i)
316             continuation = let(stacks[i],
317                     apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
318                     continuation);
319         
320         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
321         
322         // Create relations
323         for(LocalRelation relation : relations)
324             continuation = let(relation.table,
325                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
326                     continuation);
327         
328         return seq(continuation, in);
329     }
330
331     @Override
332     protected void updateType() throws MatchException {
333         setType(in.getType());
334     }
335     
336     @Override
337     public void setLocationDeep(long loc) {
338         if(location == Locations.NO_LOCATION) {
339             location = loc;
340             for(DatalogRule rule : rules)
341                 rule.setLocationDeep(loc);
342         }
343     }
344     
345     @Override
346     public void accept(ExpressionVisitor visitor) {
347         visitor.visit(this);
348     }
349
350     public DatalogRule[] getRules() {
351         return rules;
352     }
353     
354     public Expression getIn() {
355         return in;
356     }
357     
358     @Override
359     public Expression accept(ExpressionTransformer transformer) {
360         return transformer.transform(this);
361     }
362
363 }