(refs #7375) Replaced collectFreeVariables method by a visitor
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ERuleset.java
1 package org.simantics.scl.compiler.elaboration.expressions;
2
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;
19
20 import java.util.ArrayList;
21 import java.util.Set;
22
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
24 import org.simantics.scl.compiler.common.names.Names;
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
28 import org.simantics.scl.compiler.elaboration.query.Query;
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
34 import org.simantics.scl.compiler.errors.Locations;
35 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
36 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
37 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
38 import org.simantics.scl.compiler.types.Type;
39 import org.simantics.scl.compiler.types.Types;
40 import org.simantics.scl.compiler.types.exceptions.MatchException;
41 import org.simantics.scl.compiler.types.kinds.Kinds;
42
43 import gnu.trove.impl.Constants;
44 import gnu.trove.map.hash.THashMap;
45 import gnu.trove.map.hash.TObjectIntHashMap;
46 import gnu.trove.set.hash.TIntHashSet;
47
48 public class ERuleset extends SimplifiableExpression {
49     LocalRelation[] relations;
50     public DatalogRule[] rules;
51     public Expression in;
52     
53     public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
54         this.relations = relations;
55         this.rules = rules;
56         this.in = in;
57     }
58
59     public static class DatalogRule {
60         public long location;
61         public LocalRelation headRelation;
62         public Expression[] headParameters;
63         public Query body;
64         public Variable[] variables;
65         
66         public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
67                 Query body) {
68             this.headRelation = headRelation;
69             this.headParameters = headParameters;
70             this.body = body;
71         }
72         
73         public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
74                 Query body, Variable[] variables) {
75             this.location = location;
76             this.headRelation = headRelation;
77             this.headParameters = headParameters;
78             this.body = body;
79             this.variables = variables;
80         }
81
82         public void setLocationDeep(long loc) {
83             this.location = loc;
84             for(Expression parameter : headParameters)
85                 parameter.setLocationDeep(loc);
86             body.setLocationDeep(loc);
87         }
88         
89         @Override
90         public String toString() {
91             StringBuilder b = new StringBuilder();
92             ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
93             visitor.visit(this);
94             return b.toString();
95         }
96     }
97     
98     private void checkRuleTypes(TypingContext context) {
99         // Create relation variables
100         for(DatalogRule rule : rules) {
101             Type[] parameterTypes =  rule.headRelation.getParameterTypes();
102             Expression[] parameters = rule.headParameters;
103             for(Variable variable : rule.variables)
104                 variable.setType(Types.metaVar(Kinds.STAR));
105             for(int i=0;i<parameters.length;++i)
106                 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
107             rule.body.checkType(context);
108         }
109     }
110     
111     @Override
112     public Expression checkBasicType(TypingContext context, Type requiredType) {
113         checkRuleTypes(context);
114         in = in.checkBasicType(context, requiredType);
115         return compile(context);
116     }
117     
118     @Override
119     public Expression inferType(TypingContext context) {
120         checkRuleTypes(context);
121         in = in.inferType(context);
122         return compile(context);
123     }
124     
125     @Override
126     public Expression checkIgnoredType(TypingContext context) {
127         checkRuleTypes(context);
128         in = in.checkIgnoredType(context);
129         return compile(context);
130     }
131
132     @Override
133     public void collectVars(TObjectIntHashMap<Variable> allVars,
134             TIntHashSet vars) {
135         for(DatalogRule rule : rules) {
136             for(Expression parameter : rule.headParameters)
137                 parameter.collectVars(allVars, vars);
138             rule.body.collectVars(allVars, vars);
139         }
140         in.collectVars(allVars, vars);
141     }
142     
143     @Override
144     public Expression resolve(TranslationContext context) {
145         throw new InternalCompilerError();
146     }
147     
148     static class LocalRelationAux {
149         Variable handleFunc;
150     }
151     
152     public Expression compile(TypingContext context) {
153         // Create a map from relations to their ids
154         TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
155                 Constants.DEFAULT_LOAD_FACTOR, -1);
156         for(int i=0;i<relations.length;++i)
157             relationsToIds.put(relations[i], i);
158         
159         // Create a table from relations to the other relations they depend on
160         TIntHashSet[] refsSets = new TIntHashSet[relations.length];
161         int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
162         for(int i=0;i<relations.length;++i)
163             refsSets[i] = new TIntHashSet(setCapacity);
164         
165         for(DatalogRule rule : rules) {
166             int headRelationId = relationsToIds.get(rule.headRelation);
167             TIntHashSet refsSet = refsSets[headRelationId];
168             rule.body.collectRelationRefs(relationsToIds, refsSet);
169             for(Expression parameter : rule.headParameters)
170                 parameter.collectRelationRefs(relationsToIds, refsSet);
171         }
172         
173         // Convert refsSets to an array
174         final int[][] refs = new int[relations.length][];
175         for(int i=0;i<relations.length;++i)
176             refs[i] = refsSets[i].toArray();
177         
178         // Find strongly connected components of the function refs
179         final ArrayList<int[]> components = new ArrayList<int[]>();
180         
181         new StronglyConnectedComponents(relations.length) {
182             @Override
183             protected void reportComponent(int[] component) {
184                 components.add(component);
185             }
186             
187             @Override
188             protected int[] findDependencies(int u) {
189                 return refs[u];
190             }
191         }.findComponents();
192         
193         // If there is just one component, compile it
194         if(components.size() == 1) {
195             return compileStratified(context);
196         }
197         
198         // Inverse of components array 
199         int[] strataPerRelation = new int[relations.length];
200         for(int i=0;i<components.size();++i)
201             for(int k : components.get(i))
202                 strataPerRelation[k] = i;
203         
204         // Collects rules belonging to each strata
205         @SuppressWarnings("unchecked")
206         ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
207         for(int i=0;i<components.size();++i)
208             rulesPerStrata[i] = new ArrayList<DatalogRule>();
209         for(DatalogRule rule : rules) {
210             int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
211             rulesPerStrata[stratum].add(rule);
212         }
213         
214         // Create stratified system
215         Expression cur = this.in;
216         for(int stratum=components.size()-1;stratum >= 0;--stratum) {
217             int[] cs = components.get(stratum);
218             LocalRelation[] curRelations = new LocalRelation[cs.length];
219             for(int i=0;i<cs.length;++i)
220                 curRelations[i] = relations[cs[i]];
221             ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
222             cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
223         }
224         return cur;
225     }
226     
227     private Expression compileStratified(TypingContext context) {
228         Expression continuation = Expressions.tuple();
229         
230         // Create stacks
231         Variable[] stacks = new Variable[relations.length];
232         for(int i=0;i<relations.length;++i) {
233             LocalRelation relation = relations[i];
234             Type[] parameterTypes = relation.getParameterTypes();
235             stacks[i] = newVar("stack" + relation.getName(),
236                     Types.apply(Names.MList_T, Types.tuple(parameterTypes))
237                     );
238         }
239
240         // Simplify subexpressions and collect derivatives
241         THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
242         for(int i=0;i<relations.length;++i) {
243             LocalRelation relation = relations[i];
244             Type[] parameterTypes = relation.getParameterTypes();
245             Variable[] parameters = new Variable[parameterTypes.length];
246             for(int j=0;j<parameterTypes.length;++j)
247                 parameters[j] = new Variable("p" + j, parameterTypes[j]);
248             diffables.put(relations[i], new Diffable(i, relation, parameters));
249         }
250         @SuppressWarnings("unchecked")
251         ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
252         for(int i=0;i<relations.length;++i)
253             updateExpressions[i] = new ArrayList<Expression>(2);
254         ArrayList<Expression> seedExpressions = new ArrayList<Expression>(); 
255         for(DatalogRule rule : rules) {
256             int id = diffables.get(rule.headRelation).id;
257             Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
258                     var(stacks[id]),
259                     tuple(rule.headParameters)
260                     );
261             Diff[] diffs;
262             try {
263                 diffs = rule.body.derivate(diffables);
264             } catch(DerivateException e) {
265                 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");
266                 return new EError();
267             }
268             for(Diff diff : diffs)
269                 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));
270             if(diffs.length == 0)
271                 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));
272             else {
273                 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());
274                 if(query != Query.EMPTY_QUERY)
275                     seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));
276             }
277         }
278         
279         // Iterative solving of relations
280
281         Variable[] loops = new Variable[relations.length];
282         for(int i=0;i<loops.length;++i)
283             loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
284         continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
285         
286         Expression[] loopDefs = new Expression[relations.length];
287         for(int i=0;i<relations.length;++i) {
288             LocalRelation relation = relations[i];
289             Type[] parameterTypes = relation.getParameterTypes();
290             Variable[] parameters = diffables.get(relation).parameters;
291             
292             Variable counter = newVar("counter", Types.INTEGER);
293             
294             Type rowType = Types.tuple(parameterTypes);
295             Variable row = newVar("row", rowType);
296             
297             Expression handleRow = tuple();
298             for(Expression updateExpression : updateExpressions[i])
299                 handleRow = seq(updateExpression, handleRow);
300             handleRow = if_(
301                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,
302                             var(relation.table), var(row)),
303                     handleRow,
304                     tuple()
305                     );
306             handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
307             Expression failure =
308                     if_(isZeroInteger(var(counter)),
309                         tuple(),
310                         apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
311                        );
312             Expression body = matchWithDefault(
313                     apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),
314                     Just(as(row, tuple(vars(parameters)))), handleRow,
315                     failure);
316             
317             loopDefs[i] = lambda(Types.PROC, counter, body); 
318         }
319         continuation = letRec(loops, loopDefs, continuation);
320         
321         // Seed relations
322         for(Expression seedExpression : seedExpressions)
323             continuation = seq(seedExpression, continuation);
324         
325         // Create stacks
326         for(int i=0;i<stacks.length;++i)
327             continuation = let(stacks[i],
328                     apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
329                     continuation);
330         
331         continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
332         
333         // Create relations
334         for(LocalRelation relation : relations)
335             continuation = let(relation.table,
336                     apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
337                     continuation);
338         
339         return seq(continuation, in);
340     }
341
342     @Override
343     protected void updateType() throws MatchException {
344         setType(in.getType());
345     }
346     
347     @Override
348     public void setLocationDeep(long loc) {
349         if(location == Locations.NO_LOCATION) {
350             location = loc;
351             for(DatalogRule rule : rules)
352                 rule.setLocationDeep(loc);
353         }
354     }
355     
356     @Override
357     public void accept(ExpressionVisitor visitor) {
358         visitor.visit(this);
359     }
360
361     public DatalogRule[] getRules() {
362         return rules;
363     }
364     
365     public Expression getIn() {
366         return in;
367     }
368     
369     @Override
370     public Expression accept(ExpressionTransformer transformer) {
371         return transformer.transform(this);
372     }
373
374 }