1 package org.simantics.scl.compiler.elaboration.expressions;
\r
3 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
\r
4 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.addInteger;
\r
5 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
\r
6 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.as;
\r
7 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
\r
8 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.integer;
\r
9 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.isZeroInteger;
\r
10 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
\r
11 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
\r
12 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
\r
13 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.matchWithDefault;
\r
14 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.newVar;
\r
15 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.seq;
\r
16 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
\r
17 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;
\r
18 import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;
\r
20 import java.util.ArrayList;
\r
21 import java.util.Set;
\r
23 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
\r
24 import org.simantics.scl.compiler.common.names.Name;
\r
25 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
\r
26 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
\r
27 import org.simantics.scl.compiler.elaboration.expressions.printing.ExpressionToStringVisitor;
\r
28 import org.simantics.scl.compiler.elaboration.query.Query;
\r
29 import org.simantics.scl.compiler.elaboration.query.Query.Diff;
\r
30 import org.simantics.scl.compiler.elaboration.query.Query.Diffable;
\r
31 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
\r
32 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
\r
33 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
\r
34 import org.simantics.scl.compiler.errors.Locations;
\r
35 import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;
\r
36 import org.simantics.scl.compiler.internal.elaboration.utils.ForcedClosure;
\r
37 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
\r
38 import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
\r
39 import org.simantics.scl.compiler.types.TCon;
\r
40 import org.simantics.scl.compiler.types.Type;
\r
41 import org.simantics.scl.compiler.types.Types;
\r
42 import org.simantics.scl.compiler.types.exceptions.MatchException;
\r
43 import org.simantics.scl.compiler.types.kinds.Kinds;
\r
45 import gnu.trove.impl.Constants;
\r
46 import gnu.trove.map.hash.THashMap;
\r
47 import gnu.trove.map.hash.TObjectIntHashMap;
\r
48 import gnu.trove.set.hash.THashSet;
\r
49 import gnu.trove.set.hash.TIntHashSet;
\r
51 public class ERuleset extends SimplifiableExpression {
\r
52 LocalRelation[] relations;
\r
53 DatalogRule[] rules;
\r
56 public ERuleset(LocalRelation[] relations, DatalogRule[] rules, Expression in) {
\r
57 this.relations = relations;
\r
62 public static class DatalogRule {
\r
63 public long location;
\r
64 public LocalRelation headRelation;
\r
65 public Expression[] headParameters;
\r
67 public Variable[] variables;
\r
69 public DatalogRule(LocalRelation headRelation, Expression[] headParameters,
\r
71 this.headRelation = headRelation;
\r
72 this.headParameters = headParameters;
\r
76 public DatalogRule(long location, LocalRelation headRelation, Expression[] headParameters,
\r
77 Query body, Variable[] variables) {
\r
78 this.location = location;
\r
79 this.headRelation = headRelation;
\r
80 this.headParameters = headParameters;
\r
82 this.variables = variables;
\r
85 public void setLocationDeep(long loc) {
\r
86 this.location = loc;
\r
87 for(Expression parameter : headParameters)
\r
88 parameter.setLocationDeep(loc);
\r
89 body.setLocationDeep(loc);
\r
93 public String toString() {
\r
94 StringBuilder b = new StringBuilder();
\r
95 ExpressionToStringVisitor visitor = new ExpressionToStringVisitor(b);
\r
96 visitor.visit(this);
\r
97 return b.toString();
\r
100 public void forVariables(VariableProcedure procedure) {
\r
101 for(Expression headParameter : headParameters)
\r
102 headParameter.forVariables(procedure);
\r
103 body.forVariables(procedure);
\r
107 private void checkRuleTypes(TypingContext context) {
\r
108 // Create relation variables
\r
109 for(DatalogRule rule : rules) {
\r
110 Type[] parameterTypes = rule.headRelation.getParameterTypes();
\r
111 Expression[] parameters = rule.headParameters;
\r
112 for(Variable variable : rule.variables)
\r
113 variable.setType(Types.metaVar(Kinds.STAR));
\r
114 for(int i=0;i<parameters.length;++i)
\r
115 parameters[i] = parameters[i].checkType(context, parameterTypes[i]);
\r
116 rule.body.checkType(context);
\r
121 public Expression checkBasicType(TypingContext context, Type requiredType) {
\r
122 checkRuleTypes(context);
\r
123 in = in.checkBasicType(context, requiredType);
\r
124 return compile(context);
\r
128 public Expression inferType(TypingContext context) {
\r
129 checkRuleTypes(context);
\r
130 in = in.inferType(context);
\r
131 return compile(context);
\r
135 public Expression checkIgnoredType(TypingContext context) {
\r
136 checkRuleTypes(context);
\r
137 in = in.checkIgnoredType(context);
\r
138 return compile(context);
\r
142 public void collectFreeVariables(THashSet<Variable> vars) {
\r
143 for(DatalogRule rule : rules) {
\r
144 for(Expression parameter : rule.headParameters)
\r
145 parameter.collectFreeVariables(vars);
\r
146 rule.body.collectFreeVariables(vars);
\r
147 for(Variable var : rule.variables)
\r
150 in.collectFreeVariables(vars);
\r
154 public void collectRefs(TObjectIntHashMap<Object> allRefs,
\r
155 TIntHashSet refs) {
\r
156 for(DatalogRule rule : rules) {
\r
157 for(Expression parameter : rule.headParameters)
\r
158 parameter.collectRefs(allRefs, refs);
\r
159 rule.body.collectRefs(allRefs, refs);
\r
161 in.collectRefs(allRefs, refs);
\r
165 public void collectVars(TObjectIntHashMap<Variable> allVars,
\r
166 TIntHashSet vars) {
\r
167 for(DatalogRule rule : rules) {
\r
168 for(Expression parameter : rule.headParameters)
\r
169 parameter.collectVars(allVars, vars);
\r
170 rule.body.collectVars(allVars, vars);
\r
172 in.collectVars(allVars, vars);
\r
176 public void collectEffects(THashSet<Type> effects) {
\r
177 throw new InternalCompilerError(location, getClass().getSimpleName() + " does not support collectEffects.");
\r
181 public Expression decorate(ExpressionDecorator decorator) {
\r
182 return decorator.decorate(this);
\r
186 public Expression resolve(TranslationContext context) {
\r
187 throw new InternalCompilerError();
\r
190 static class LocalRelationAux {
\r
191 Variable handleFunc;
\r
194 public static final TCon MSet = Types.con("MSet", "T");
\r
195 private static final Name MSet_add = Name.create("MSet", "add");
\r
196 private static final Name MSet_create = Name.create("MSet", "create");
\r
198 private static final TCon MList = Types.con("MList", "T");
\r
199 private static final Name MList_add = Name.create("MList", "add");
\r
200 private static final Name MList_create = Name.create("MList", "create");
\r
201 private static final Name MList_removeLast = Name.create("MList", "removeLast");
\r
203 public Expression compile(TypingContext context) {
\r
204 // Create a map from relations to their ids
\r
205 TObjectIntHashMap<SCLRelation> relationsToIds = new TObjectIntHashMap<SCLRelation>(relations.length,
\r
206 Constants.DEFAULT_LOAD_FACTOR, -1);
\r
207 for(int i=0;i<relations.length;++i)
\r
208 relationsToIds.put(relations[i], i);
\r
210 // Create a table from relations to the other relations they depend on
\r
211 TIntHashSet[] refsSets = new TIntHashSet[relations.length];
\r
212 int setCapacity = Math.min(Constants.DEFAULT_CAPACITY, relations.length);
\r
213 for(int i=0;i<relations.length;++i)
\r
214 refsSets[i] = new TIntHashSet(setCapacity);
\r
216 for(DatalogRule rule : rules) {
\r
217 int headRelationId = relationsToIds.get(rule.headRelation);
\r
218 TIntHashSet refsSet = refsSets[headRelationId];
\r
219 rule.body.collectRelationRefs(relationsToIds, refsSet);
\r
220 for(Expression parameter : rule.headParameters)
\r
221 parameter.collectRelationRefs(relationsToIds, refsSet);
\r
224 // Convert refsSets to an array
\r
225 final int[][] refs = new int[relations.length][];
\r
226 for(int i=0;i<relations.length;++i)
\r
227 refs[i] = refsSets[i].toArray();
\r
229 // Find strongly connected components of the function refs
\r
230 final ArrayList<int[]> components = new ArrayList<int[]>();
\r
232 new StronglyConnectedComponents(relations.length) {
\r
234 protected void reportComponent(int[] component) {
\r
235 components.add(component);
\r
239 protected int[] findDependencies(int u) {
\r
242 }.findComponents();
\r
244 // If there is just one component, compile it
\r
245 if(components.size() == 1) {
\r
246 return compileStratified(context);
\r
249 // Inverse of components array
\r
250 int[] strataPerRelation = new int[relations.length];
\r
251 for(int i=0;i<components.size();++i)
\r
252 for(int k : components.get(i))
\r
253 strataPerRelation[k] = i;
\r
255 // Collects rules belonging to each strata
\r
256 @SuppressWarnings("unchecked")
\r
257 ArrayList<DatalogRule>[] rulesPerStrata = new ArrayList[components.size()];
\r
258 for(int i=0;i<components.size();++i)
\r
259 rulesPerStrata[i] = new ArrayList<DatalogRule>();
\r
260 for(DatalogRule rule : rules) {
\r
261 int stratum = strataPerRelation[relationsToIds.get(rule.headRelation)];
\r
262 rulesPerStrata[stratum].add(rule);
\r
265 // Create stratified system
\r
266 Expression cur = this.in;
\r
267 for(int stratum=components.size()-1;stratum >= 0;--stratum) {
\r
268 int[] cs = components.get(stratum);
\r
269 LocalRelation[] curRelations = new LocalRelation[cs.length];
\r
270 for(int i=0;i<cs.length;++i)
\r
271 curRelations[i] = relations[cs[i]];
\r
272 ArrayList<DatalogRule> curRules = rulesPerStrata[stratum];
\r
273 cur = new ERuleset(curRelations, curRules.toArray(new DatalogRule[curRules.size()]), cur).compileStratified(context);
\r
278 private Expression compileStratified(TypingContext context) {
\r
279 Expression continuation = Expressions.tuple();
\r
282 Variable[] stacks = new Variable[relations.length];
\r
283 for(int i=0;i<relations.length;++i) {
\r
284 LocalRelation relation = relations[i];
\r
285 Type[] parameterTypes = relation.getParameterTypes();
\r
286 stacks[i] = newVar("stack" + relation.getName(),
\r
287 Types.apply(MList, Types.tuple(parameterTypes))
\r
291 // Simplify subexpressions and collect derivatives
\r
292 THashMap<LocalRelation, Diffable> diffables = new THashMap<LocalRelation, Diffable>(relations.length);
\r
293 for(int i=0;i<relations.length;++i) {
\r
294 LocalRelation relation = relations[i];
\r
295 Type[] parameterTypes = relation.getParameterTypes();
\r
296 Variable[] parameters = new Variable[parameterTypes.length];
\r
297 for(int j=0;j<parameterTypes.length;++j)
\r
298 parameters[j] = new Variable("p" + j, parameterTypes[j]);
\r
299 diffables.put(relations[i], new Diffable(i, relation, parameters));
\r
301 @SuppressWarnings("unchecked")
\r
302 ArrayList<Expression>[] updateExpressions = (ArrayList<Expression>[])new ArrayList[relations.length];
\r
303 for(int i=0;i<relations.length;++i)
\r
304 updateExpressions[i] = new ArrayList<Expression>(2);
\r
305 ArrayList<Expression> seedExpressions = new ArrayList<Expression>();
\r
306 for(DatalogRule rule : rules) {
\r
307 int id = diffables.get(rule.headRelation).id;
\r
308 Expression appendExp = apply(context, Types.PROC, MList_add, Types.tuple(rule.headRelation.getParameterTypes()),
\r
310 tuple(rule.headParameters)
\r
314 diffs = rule.body.derivate(diffables);
\r
315 } catch(DerivateException e) {
\r
316 context.getErrorLog().log(e.location, "Recursion must not contain negations or aggragates.");
\r
317 return new EError();
\r
319 for(Diff diff : diffs)
\r
320 updateExpressions[diff.id].add(((EWhen)new EWhen(rule.location, diff.query, appendExp, rule.variables).copy(context)).compile(context));
\r
321 if(diffs.length == 0)
\r
322 seedExpressions.add(((EWhen)new EWhen(rule.location, rule.body, appendExp, rule.variables).copy(context)).compile(context));
\r
324 Query query = rule.body.removeRelations((Set<SCLRelation>)(Set)diffables.keySet());
\r
325 if(query != Query.EMPTY_QUERY)
\r
326 seedExpressions.add(((EWhen)new EWhen(location, query, appendExp, rule.variables).copy(context)).compile(context));
\r
330 // Iterative solving of relations
\r
332 Variable[] loops = new Variable[relations.length];
\r
333 for(int i=0;i<loops.length;++i)
\r
334 loops[i] = newVar("loop" + relations[i].getName(), Types.functionE(Types.INTEGER, Types.PROC, Types.UNIT));
\r
335 continuation = seq(apply(Types.PROC, var(loops[0]), integer(relations.length-1)), continuation);
\r
337 Expression[] loopDefs = new Expression[relations.length];
\r
338 for(int i=0;i<relations.length;++i) {
\r
339 LocalRelation relation = relations[i];
\r
340 Type[] parameterTypes = relation.getParameterTypes();
\r
341 Variable[] parameters = diffables.get(relation).parameters;
\r
343 Variable counter = newVar("counter", Types.INTEGER);
\r
345 Type rowType = Types.tuple(parameterTypes);
\r
346 Variable row = newVar("row", rowType);
\r
348 Expression handleRow = tuple();
\r
349 for(Expression updateExpression : updateExpressions[i])
\r
350 handleRow = seq(updateExpression, handleRow);
\r
352 apply(context, Types.PROC, MSet_add, rowType,
\r
353 var(relation.table), var(row)),
\r
357 handleRow = seq(handleRow, apply(Types.PROC, var(loops[i]), integer(relations.length-1)));
\r
358 Expression failure =
\r
359 if_(isZeroInteger(var(counter)),
\r
361 apply(Types.PROC, var(loops[(i+1)%relations.length]), addInteger(var(counter), integer(-1)))
\r
363 Expression body = matchWithDefault(
\r
364 apply(context, Types.PROC, MList_removeLast, rowType, var(stacks[i])),
\r
365 Just(as(row, tuple(vars(parameters)))), handleRow,
\r
368 loopDefs[i] = lambda(Types.PROC, counter, body);
\r
370 continuation = letRec(loops, loopDefs, continuation);
\r
373 for(Expression seedExpression : seedExpressions)
\r
374 continuation = seq(seedExpression, continuation);
\r
377 for(int i=0;i<stacks.length;++i)
\r
378 continuation = let(stacks[i],
\r
379 apply(context, Types.PROC, MList_create, Types.tuple(relations[i].getParameterTypes()), tuple()),
\r
382 continuation = ForcedClosure.forceClosure(continuation, SCLCompilerConfiguration.EVERY_DATALOG_STRATUM_IN_SEPARATE_METHOD);
\r
384 // Create relations
\r
385 for(LocalRelation relation : relations)
\r
386 continuation = let(relation.table,
\r
387 apply(context, Types.PROC, MSet_create, Types.tuple(relation.getParameterTypes()), tuple()),
\r
390 return seq(continuation, in);
\r
394 protected void updateType() throws MatchException {
\r
395 setType(in.getType());
\r
399 public void setLocationDeep(long loc) {
\r
400 if(location == Locations.NO_LOCATION) {
\r
402 for(DatalogRule rule : rules)
\r
403 rule.setLocationDeep(loc);
\r
408 public void accept(ExpressionVisitor visitor) {
\r
409 visitor.visit(this);
\r
412 public DatalogRule[] getRules() {
\r
416 public Expression getIn() {
\r
421 public void forVariables(VariableProcedure procedure) {
\r
422 for(DatalogRule rule : rules)
\r
423 rule.forVariables(procedure);
\r
424 in.forVariables(procedure);
\r
428 public Expression accept(ExpressionTransformer transformer) {
\r
429 return transformer.transform(this);
\r