1 package org.simantics.scl.compiler.elaboration.expressions;
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;
20 import java.util.ArrayList;
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
24 import org.simantics.scl.compiler.common.names.Names;
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
28 import org.simantics.scl.compiler.elaboration.query.Query;
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
34 import org.simantics.scl.compiler.errors.Locations;
35 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
36 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
37 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
38 import org.simantics.scl.compiler.types.Type;
39 import org.simantics.scl.compiler.types.Types;
40 import org.simantics.scl.compiler.types.exceptions.MatchException;
41 import org.simantics.scl.compiler.types.kinds.Kinds;
43 import gnu.trove.impl.Constants;
44 import gnu.trove.map.hash.THashMap;
45 import gnu.trove.map.hash.TObjectIntHashMap;
46 import gnu.trove.set.hash.THashSet;
47 import gnu.trove.set.hash.TIntHashSet;
/**
 * Expression node for a block of local Datalog-style rules: the local
 * relations in {@code relations}, the {@link DatalogRule}s that derive facts
 * into them, and the expression {@code in} that is evaluated in their scope.
 *
 * <p>The node does not survive type checking: checkBasicType, inferType and
 * checkIgnoredType all finish by calling {@link #compile(TypingContext)},
 * which lowers the ruleset into ordinary let / letRec / seq expressions.
 * {@link #resolve(TranslationContext)} unconditionally throws.
 */
49 public class ERuleset extends SimplifiableExpression {
// The local relations defined by this ruleset. The index of a relation in
// this array is used as its integer id during compilation.
50 LocalRelation[] relations;
/**
 * Creates a ruleset.
 *
 * @param relations the local relations defined by the ruleset
 * @param rules     the rules deriving facts into those relations
 * @param in        the expression evaluated in the scope of the relations
 */
54 public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
55 this.relations = relations;
// NOTE(review): only the assignment of relations is visible in this view;
// rules and in are presumably assigned on the elided following lines.
/**
 * A single rule of a ruleset, conceptually
 * {@code headRelation(headParameters) <- body}.
 *
 * <p>compileStratified turns each rule into EWhen expressions over its body
 * (or derivatives of its body) that append the tuple formed from
 * {@code headParameters}, with {@code variables} as the bound variables.
 */
60 public static class DatalogRule {
// Relation that this rule derives facts into.
62 public LocalRelation headRelation;
// Expressions forming the derived tuple. checkRuleTypes checks them against
// the head relation's parameter types and replaces them with the checked
// expressions. compileStratified packs them with tuple(...).
63 public Expression[] headParameters;
// Variables bound by the rule. checkRuleTypes gives each one a fresh type
// metavariable.
65 public Variable[] variables;
// NOTE(review): the rest of this constructor's parameter list and part of
// its body are not visible in this view.
67 public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
69 this.headRelation = headRelation;
70 this.headParameters = headParameters;
/**
 * Creates a rule with an explicit source location.
 *
 * @param location       source location of the rule
 * @param headRelation   relation receiving the derived tuples
 * @param headParameters expressions forming the derived tuple
 * @param body           query the rule is conditioned on
 * @param variables      variables bound by the rule
 */
74 public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
75 Query body, Variable[] variables) {
76 this.location = location;
77 this.headRelation = headRelation;
78 this.headParameters = headParameters;
80 this.variables = variables;
/**
 * Propagates {@code loc} to every head parameter and to the body.
 */
83 public void setLocationDeep(long loc) {
85 for(Expression parameter : headParameters)
86 parameter.setLocationDeep(loc);
87 body.setLocationDeep(loc);
/**
 * Builds a string representation of the rule with an
 * {@link ExpressionToStringVisitor}. The remaining lines of this method are
 * not visible in this view.
 */
91 public String toString() {
92 StringBuilder b = new StringBuilder();
93 ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
/**
 * Forwards {@code procedure} to forVariables of every head parameter and
 * then of the body.
 */
98 public void forVariables(VariableProcedure procedure) {
99 for(Expression headParameter : headParameters)
100 headParameter.forVariables(procedure);
101 body.forVariables(procedure);
/**
 * Type-checks all rules in place.
 *
 * <p>For every rule:
 * <ul>
 * <li>each rule variable gets a fresh type metavariable of kind STAR;</li>
 * <li>each head parameter is checked against the corresponding parameter
 *     type of the head relation, and the checked expression is written back
 *     into rule.headParameters;</li>
 * <li>the body query is checked.</li>
 * </ul>
 *
 * <p>NOTE(review): nothing checks that the number of head parameters equals
 * the head relation's arity. Extra parameters would throw
 * ArrayIndexOutOfBoundsException on parameterTypes[i]; surplus parameter
 * types are silently ignored.
 *
 * @param context typing context used for all checks
 */
105 private void checkRuleTypes(TypingContext context) {
106 // For each rule: fresh type metavariables for its variables, then check head parameters and body
107 for(DatalogRule rule : rules) {
108 Type[] parameterTypes = rule.headRelation.getParameterTypes();
// Alias of rule.headParameters: the assignments below update the rule itself.
109 Expression[] parameters = rule.headParameters;
110 for(Variable variable : rule.variables)
111 variable.setType(Types.metaVar(Kinds.STAR));
112 for(int i=0;i<parameters.length;++i)
113 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
114 rule.body.checkType(context);
/**
 * Checks the rules, checks {@code in} against {@code requiredType}, and then
 * lowers the whole ruleset with {@link #compile(TypingContext)}.
 *
 * @param context      typing context
 * @param requiredType type required for the value of {@code in}
 * @return the result of compile(context), not this node
 */
119 public Expression checkBasicType(TypingContext context, Type requiredType) {
120 checkRuleTypes(context);
121 in = in.checkBasicType(context, requiredType);
122 return compile(context);
/**
 * Checks the rules, infers the type of {@code in}, and then lowers the whole
 * ruleset with {@link #compile(TypingContext)}.
 *
 * @param context typing context
 * @return the result of compile(context), not this node
 */
126 public Expression inferType(TypingContext context) {
127 checkRuleTypes(context);
128 in = in.inferType(context);
129 return compile(context);
/**
 * Checks the rules, checks {@code in} with checkIgnoredType (its value is
 * not used), and then lowers the whole ruleset with
 * {@link #compile(TypingContext)}.
 *
 * @param context typing context
 * @return the result of compile(context), not this node
 */
133 public Expression checkIgnoredType(TypingContext context) {
134 checkRuleTypes(context);
135 in = in.checkIgnoredType(context);
136 return compile(context);
/**
 * Collects into {@code vars} the free variables of every rule (head
 * parameters and body) and of {@code in}.
 *
 * @param vars accumulator set of free variables
 */
140 public void collectFreeVariables(THashSet<Variable> vars) {
141 for(DatalogRule rule : rules) {
142 for(Expression parameter : rule.headParameters)
143 parameter.collectFreeVariables(vars);
144 rule.body.collectFreeVariables(vars);
// NOTE(review): the body of this loop is not visible in this view; presumably
// it removes the rule-bound variables from vars, since they are not free
// outside the rule. Confirm.
145 for(Variable var : rule.variables)
148 in.collectFreeVariables(vars);
/**
 * Forwards collectRefs to every head parameter and rule body, and finally
 * to {@code in}. The second line of the parameter list (declaring refs) is
 * not visible in this view.
 */
152 public void collectRefs(TObjectIntHashMap<Object> allRefs,
154 for(DatalogRule rule : rules) {
155 for(Expression parameter : rule.headParameters)
156 parameter.collectRefs(allRefs, refs);
157 rule.body.collectRefs(allRefs, refs);
159 in.collectRefs(allRefs, refs);
/**
 * Forwards collectVars to every head parameter and rule body, and finally
 * to {@code in}. The second line of the parameter list (declaring vars) is
 * not visible in this view.
 */
163 public void collectVars(TObjectIntHashMap<Variable> allVars,
165 for(DatalogRule rule : rules) {
166 for(Expression parameter : rule.headParameters)
167 parameter.collectVars(allVars, vars);
168 rule.body.collectVars(allVars, vars);
170 in.collectVars(allVars, vars);
/**
 * Not supported for this node: always throws an InternalCompilerError that
 * names the class. Presumably effects are only collected after the ruleset
 * has been compiled away by type checking. Confirm.
 */
174 public void collectEffects(THashSet<Type> effects) {
175 throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");
/**
 * Always throws InternalCompilerError (without location or message).
 *
 * <p>NOTE(review): presumably ERuleset nodes are only created after name
 * resolution, so this is never reached. Confirm. A message and location
 * would make a failure here easier to diagnose.
 */
179 public Expression resolve(TranslationContext context) {
180 throw new InternalCompilerError();
// NOTE(review): no members of this class are visible in this view, and no
// visible code references it. Check whether it is still used.
183 static class LocalRelationAux {
/**
 * Lowers this ruleset into an ordinary expression, stratifying it first.
 *
 * <p>Builds a dependency graph between the local relations. A relation
 * depends on the local relations collected by collectRelationRefs from the
 * bodies and head parameters of the rules deriving it. The graph is split
 * into strongly connected components, and each component is compiled with
 * {@link #compileStratified(TypingContext)}. When there is exactly one
 * component, the whole ruleset is compiled at once.
 *
 * @param context typing context passed on to compileStratified
 * @return the compiled expression
 */
187 public Expression compile(TypingContext context) {
188 // Create a map from relations to their ids (their index in relations).
// -1 is passed as the no-entry value, so get() on a relation that is not
// defined by this ruleset returns -1.
189 TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
190 Constants.DEFAULT_LOAD_FACTOR, -1);
191 for(int i=0;i<relations.length;++i)
192 relationsToIds.put(relations[i], i);
194 // Create a table from relations to the other relations they depend on:
// refsSets[h] collects the relation ids referenced by rules whose head is h.
195 TIntHashSet[] refsSets = new TIntHashSet[relations.length];
196 int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
197 for(int i=0;i<relations.length;++i)
198 refsSets[i] = new TIntHashSet(setCapacity);
200 for(DatalogRule rule : rules) {
// NOTE(review): if rule.headRelation is not one of relations, get() returns
// -1 and refsSets[-1] throws ArrayIndexOutOfBoundsException. Presumably
// rulesets are built so this cannot happen. Confirm.
201 int headRelationId = relationsToIds.get(rule.headRelation);
202 TIntHashSet refsSet = refsSets[headRelationId];
203 rule.body.collectRelationRefs(relationsToIds, refsSet);
204 for(Expression parameter : rule.headParameters)
205 parameter.collectRelationRefs(relationsToIds, refsSet);
208 // Convert refsSets to an array (adjacency lists indexed by relation id)
209 final int[][] refs = new int[relations.length][];
210 for(int i=0;i<relations.length;++i)
211 refs[i] = refsSets[i].toArray();
213 // Find strongly connected components of the relation dependency graph refs
214 final ArrayList<int[]> components = new ArrayList<int[]>();
216 new StronglyConnectedComponents(relations.length) {
// Records each reported component, in the order they are reported.
218 protected void reportComponent(int[] component) {
219 components.add(component);
// Body not visible in this view; presumably returns refs[u].
223 protected int[] findDependencies(int u) {
228 // If there is just one component, compile it directly: no stratification is needed
229 if(components.size() == 1) {
230 return compileStratified(context);
233 // Inverse of components array: strataPerRelation[r] is the index of the component containing relation r
234 int[] strataPerRelation = new int[relations.length];
235 for(int i=0;i<components.size();++i)
236 for(int k : components.get(i))
237 strataPerRelation[k] = i;
239 // Collect the rules belonging to each stratum (the stratum of their head relation)
240 @SuppressWarnings("unchecked")
241 ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
242 for(int i=0;i<components.size();++i)
243 rulesPerStrata[i] = new ArrayList<DatalogRule>();
244 for(DatalogRule rule : rules) {
245 int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
246 rulesPerStrata[stratum].add(rule);
249 // Create stratified system: walk the components from last to first, compiling
// each stratum with the previously built expression as its in. Component 0
// therefore ends up outermost, and its fixpoint code runs before the others.
// NOTE(review): this is correct only if StronglyConnectedComponents reports
// components with dependencies before dependents (as Tarjan-style SCC
// algorithms do). Confirm.
250 Expression cur = this.in;
251 for(int stratum=components.size()-1;stratum >= 0;--stratum) {
252 int[] cs = components.get(stratum);
253 LocalRelation[] curRelations = new LocalRelation[cs.length];
254 for(int i=0;i<cs.length;++i)
255 curRelations[i] = relations[cs[i]];
256 ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
257 cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
/**
 * Compiles all relations and rules of this ruleset as one stratum into an
 * iterative fixpoint computation, sequenced before {@code in}.
 *
 * <p>Structure of the generated code, outermost first:
 * <ol>
 * <li>{@code let} one MSet ({@code MSet_create}) per relation, bound to
 *     {@code relation.table};</li>
 * <li>a closure, forced according to
 *     {@code SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD};</li>
 * <li>{@code let} one MList work list ({@code MList_create}) per relation,
 *     bound to {@code stacks[i]};</li>
 * <li>the seed expressions;</li>
 * <li>{@code letRec} one loop function per relation, started with
 *     {@code loops[0](relations.length - 1)}.</li>
 * </ol>
 * The result is {@code seq(continuation, in)}.
 *
 * <p>Each rule builds an {@code appendExp} that calls {@code MList_add} with
 * the tuple of its head parameters. Its body is derivated with respect to
 * the relations of this stratum:
 * <ul>
 * <li>every resulting Diff becomes an update expression of relation
 *     {@code diff.id};</li>
 * <li>a body with no diffs becomes a seed expression;</li>
 * <li>the part of the body left after removeRelations is also added as a
 *     seed when it is not {@code Query.EMPTY_QUERY}.</li>
 * </ul>
 *
 * <p>{@code loops[i](counter)} removes the last row of {@code stacks[i]}.
 * <ul>
 * <li>If a row is obtained, it is bound to the relation's parameter
 *     variables, combined with an MSet_add into relation.table and the
 *     relation's update expressions (the exact wiring is on lines not visible
 *     here), and the loop restarts with
 *     {@code loops[i](relations.length - 1)}.</li>
 * <li>If the stack is empty and counter is non-zero, it continues with
 *     {@code loops[(i+1) % relations.length](counter - 1)}. The zero-counter
 *     branch is not visible here; presumably it ends the iteration once every
 *     stack has been found empty in turn.</li>
 * </ul>
 *
 * <p>Requires {@code relations.length > 0} (it uses {@code loops[0]}). compile()
 * only calls it for the single component or for individual components.
 *
 * @param context typing context, also used to log derivation errors
 * @return the compiled expression
 */
262 private Expression compileStratified(TypingContext context) {
263 Expression continuation = Expressions.tuple();
// One work-list variable per relation, of type MList over the relation's row tuple.
266 Variable[] stacks = new Variable[relations.length];
267 for(int i=0;i<relations.length;++i) {
268 LocalRelation relation = relations[i];
269 Type[] parameterTypes = relation.getParameterTypes();
270 stacks[i] = newVar("stack" + relation.getName(),
271 Types.apply(Names.MList_T, Types.tuple(parameterTypes))
275 // Create one Diffable per relation (id = index, fresh parameter variables p0, p1, ...) used to derivate rule bodies
276 THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
277 for(int i=0;i<relations.length;++i) {
278 LocalRelation relation = relations[i];
279 Type[] parameterTypes = relation.getParameterTypes();
280 Variable[] parameters = new Variable[parameterTypes.length];
281 for(int j=0;j<parameterTypes.length;++j)
282 parameters[j] = new Variable("p" + j, parameterTypes[j]);
283 diffables.put(relations[i], new Diffable(i, relation, parameters));
// updateExpressions[i]: code run for each new row of relation i.
// seedExpressions: code run once before the iteration starts.
285 @SuppressWarnings("unchecked")
286 ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
287 for(int i=0;i<relations.length;++i)
288 updateExpressions[i] = new ArrayList<Expression>(2);
289 ArrayList<Expression> seedExpressions = new ArrayList<Expression>();
290 for(DatalogRule rule : rules) {
291 int id = diffables.get(rule.headRelation).id;
292 Expression appendExp = apply(context.getCompilationContext(), Types.PROC, Names.MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
294 tuple(rule.headParameters)
298 diffs = rule.body.derivate(diffables);
299 } catch(DerivateException e) {
// NOTE(review): "aggragates" in this message is a typo for "aggregates".
300 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");
303 for(Diff diff : diffs)
304 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));
305 if(diffs.length == 0)
306 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));
// NOTE(review): the double cast through the raw Set is an unchecked view of
// Set<LocalRelation> as Set<SCLRelation>. If removeRelations only reads the set,
// Collections.unmodifiableSet(diffables.keySet()) would type-check without it.
// NOTE(review): the EWhen below uses this ruleset's location, while the two
// EWhen constructions above use rule.location. Probably rule.location was
// intended. Confirm.
308 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());
309 if(query != Query.EMPTY_QUERY)
310 seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));
314 // Iterative solving of relations: loops[i] takes a counter and processes
// stacks[i]. The whole iteration is started with loops[0](relations.length-1).
316 Variable[] loops = new Variable[relations.length];
317 for(int i=0;i<loops.length;++i)
318 loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
319 continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
321 Expression[] loopDefs = new Expression[relations.length];
322 for(int i=0;i<relations.length;++i) {
323 LocalRelation relation = relations[i];
324 Type[] parameterTypes = relation.getParameterTypes();
325 Variable[] parameters = diffables.get(relation).parameters;
327 Variable counter = newVar("counter", Types.INTEGER);
329 Type rowType = Types.tuple(parameterTypes);
330 Variable row = newVar("row", rowType);
// handleRow: the update expressions of relation i (each new one prepended),
// combined with the MSet_add into relation.table.
332 Expression handleRow = tuple();
333 for(Expression updateExpression : updateExpressions[i])
334 handleRow = seq(updateExpression, handleRow);
336 apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, rowType,
337 var(relation.table), var(row)),
// After a row has been handled, restart this loop with a full counter.
341 handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
// Empty stack: if the counter is not yet zero, try the next relation's loop with counter-1.
343 if_(isZeroInteger(var(counter)),
345 apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
// Pop the last row of stacks[i]; on Just, bind it to row and to the relation's parameter variables.
347 Expression body = matchWithDefault(
348 apply(context.getCompilationContext(), Types.PROC, Names.MList_removeLast, rowType, var(stacks[i])),
349 Just(as(row, tuple(vars(parameters)))), handleRow,
352 loopDefs[i] = lambda(Types.PROC, counter, body);
354 continuation = letRec(loops, loopDefs, continuation);
// Seeds are prepended one by one, so the last seed in the list runs first.
357 for(Expression seedExpression : seedExpressions)
358 continuation = seq(seedExpression, continuation);
// Bind each stacks[i] to a fresh, empty MList.
361 for(int i=0;i<stacks.length;++i)
362 continuation = let(stacks[i],
363 apply(context.getCompilationContext(), Types.PROC, Names.MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
366 continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
// Bind each relation.table to a fresh, empty MSet.
369 for(LocalRelation relation : relations)
370 continuation = let(relation.table,
371 apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
374 return seq(continuation, in);
/**
 * Sets the type of this node to the type of {@code in}.
 */
378 protected void updateType() throws MatchException {
379 setType(in.getType());
/**
 * Propagates {@code loc}, but only when this node has no location yet
 * ({@code Locations.NO_LOCATION}). In that case every rule receives
 * setLocationDeep(loc). Assigning this node's own location and handling
 * {@code in} are on lines not visible in this view.
 */
383 public void setLocationDeep(long loc) {
384 if(location == Locations.NO_LOCATION) {
386 for(DatalogRule rule : rules)
387 rule.setLocationDeep(loc);
/**
 * Entry point for an {@link ExpressionVisitor}. The body is not visible in
 * this view.
 */
392 public void accept(ExpressionVisitor visitor) {
/**
 * Accessor for the rules of this ruleset. The body is not visible in this
 * view.
 */
396 public DatalogRule[] getRules() {
/**
 * Accessor for the expression evaluated in the scope of the relations. The
 * body is not visible in this view.
 */
400 public Expression getIn() {
/**
 * Forwards {@code procedure} to every rule (DatalogRule.forVariables) and
 * then to {@code in}.
 */
405 public void forVariables(VariableProcedure procedure) {
406 for(DatalogRule rule : rules)
407 rule.forVariables(procedure);
408 in.forVariables(procedure);
/**
 * Dispatches to {@code transformer.transform(this)} and returns its result.
 */
412 public Expression accept(ExpressionTransformer transformer) {
413 return transformer.transform(this);