(refs #7375) Replaced collectFreeVariables method by a visitor
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / query / QAtom.java
1 package org.simantics.scl.compiler.elaboration.query;
2
3 import java.util.ArrayList;
4 import java.util.Set;
5
6 import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
7 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
8 import org.simantics.scl.compiler.elaboration.expressions.ESimpleLet;
9 import org.simantics.scl.compiler.elaboration.expressions.EVariable;
10 import org.simantics.scl.compiler.elaboration.expressions.Expression;
11 import org.simantics.scl.compiler.elaboration.expressions.QueryTransformer;
12 import org.simantics.scl.compiler.elaboration.expressions.Variable;
13 import org.simantics.scl.compiler.elaboration.expressions.VariableProcedure;
14 import org.simantics.scl.compiler.elaboration.java.EqRelation;
15 import org.simantics.scl.compiler.elaboration.query.compilation.ConstraintCollectionContext;
16 import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
17 import org.simantics.scl.compiler.elaboration.query.compilation.EnforcingContext;
18 import org.simantics.scl.compiler.elaboration.query.compilation.ExpressionConstraint;
19 import org.simantics.scl.compiler.elaboration.query.compilation.RelationConstraint;
20 import org.simantics.scl.compiler.elaboration.relations.CompositeRelation;
21 import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
22 import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
23 import org.simantics.scl.compiler.errors.Locations;
24 import org.simantics.scl.compiler.types.TVar;
25 import org.simantics.scl.compiler.types.Type;
26 import org.simantics.scl.compiler.types.Types;
27
28 import gnu.trove.map.hash.THashMap;
29 import gnu.trove.map.hash.TIntObjectHashMap;
30 import gnu.trove.map.hash.TObjectIntHashMap;
31 import gnu.trove.set.hash.TIntHashSet;
32
33 public class QAtom extends Query {
34     public SCLRelation relation;
35     public Type[] typeParameters;
36     public Expression[] parameters;
37
38     public QAtom(SCLRelation relation, Expression ... parameters) {
39         this.relation = relation;
40         this.parameters = parameters;
41     }
42
43     public QAtom(SCLRelation relation, Type[] typeParameters, Expression ... parameters) {
44         this.relation = relation;
45         this.typeParameters = typeParameters;
46         this.parameters = parameters;
47     }
48
49     @Override
50     public void checkType(TypingContext context) {
51         // Type parameters
52         TVar[] typeVariables = relation.getTypeVariables();
53         typeParameters = new Type[typeVariables.length];
54         for(int i=0;i<typeVariables.length;++i)
55             typeParameters[i] = Types.metaVar(typeVariables[i].getKind());
56
57         // Check parameter types
58         Type[] parameterTypes = relation.getParameterTypes();
59         if(parameterTypes.length != parameters.length)
60             context.getErrorLog().log(location, "Relation is applied with wrong number of parameters.");
61         else
62             for(int i=0;i<parameters.length;++i)
63                 parameters[i] = parameters[i]
64                         .checkType(context, parameterTypes[i].replace(typeVariables, typeParameters));
65     }
66
67     public Expression generateEnforce(EnforcingContext context) {
68         Variable[] variables = new Variable[parameters.length];
69         for(int i=0;i<variables.length;++i)
70             if(parameters[i] instanceof EVariable)
71                 variables[i] = ((EVariable)parameters[i]).getVariable();
72             else
73                 variables[i] = new Variable("p" + i, parameters[i].getType());
74         Expression result = relation.generateEnforce(location, context, typeParameters, variables);
75         for(int i=variables.length-1;i>=0;--i)
76             if(!(parameters[i] instanceof EVariable))
77                 result = new ESimpleLet(
78                         variables[i],
79                         parameters[i],
80                         result
81                         );
82         return result;
83     }
84
85     private static class VariableMaskProcedure implements VariableProcedure {
86         ConstraintCollectionContext context;
87         long requiredVariablesMask = 0L;
88
89         public VariableMaskProcedure(ConstraintCollectionContext context) {
90             this.context = context;
91         }
92
93         @Override
94         public void execute(long location, Variable variable) {
95             int id = context.getVariableMap().get(variable);
96             if(id >= 0)
97                 requiredVariablesMask |= 1L << id;
98         }
99     }
100
101     @Override
102     public void collectConstraints(ConstraintCollectionContext context) {
103         try {
104             // Analyze parameters and find required and optional variables
105             VariableMaskProcedure procedure = new VariableMaskProcedure(context);
106             int[] optionalVariableByParameter = new int[parameters.length];
107             Variable[] varParameters = new Variable[parameters.length];
108             for(int i=0;i<parameters.length;++i) {
109                 Expression parameter = parameters[i];
110                 if(parameter instanceof EVariable) {
111                     Variable variable = ((EVariable)parameter).getVariable();
112                     optionalVariableByParameter[i] = context.getVariableMap().get(variable);
113                     varParameters[i] = variable;
114                 }
115                 else {
116                     Variable temp = new Variable("temp", parameter.getType());
117                     varParameters[i] = temp;
118                     if(parameter.isPattern(0)) {
119                         int tempId = context.addVariable(temp);
120                         context.addConstraint(new ExpressionConstraint(context, temp, parameter, true));
121                         optionalVariableByParameter[i] = tempId;
122                     }
123                     else {
124                         optionalVariableByParameter[i] = -1;
125                         parameter.forVariableUses(procedure);
126                     }
127                 }
128             }
129
130             // Combine required and optional variables
131             TIntHashSet allVariablesSet = new TIntHashSet();
132             for(int v : optionalVariableByParameter)
133                 if(v >= 0)
134                     allVariablesSet.add(v);
135
136             context.addConstraint(new RelationConstraint(allVariablesSet.toArray(), varParameters, this,
137                     optionalVariableByParameter, procedure.requiredVariablesMask));
138         } catch(Exception e) {
139             context.getQueryCompilationContext().getTypingContext().getErrorLog().log(location, e);
140         }
141     }
142
143     @Override
144     public void collectVars(TObjectIntHashMap<Variable> allVars,
145             TIntHashSet vars) {
146         for(Expression parameter : parameters)
147             parameter.collectVars(allVars, vars);
148     }
149
150     @Override
151     public Query replace(ReplaceContext context) {
152         Type[] newTypeParameters;
153         if(typeParameters == null)
154             newTypeParameters = null;
155         else {
156             newTypeParameters = new Type[typeParameters.length];
157             for(int i=0;i<typeParameters.length;++i)
158                 newTypeParameters[i] = typeParameters[i].replace(context.tvarMap);
159         }
160         return new QAtom(relation,
161                 newTypeParameters,
162                 Expression.replace(context, parameters));
163     }
164
165     @SuppressWarnings("unchecked")
166     @Override
167     public Diff[] derivate(THashMap<LocalRelation, Diffable> diffables) throws DerivateException {
168         Diffable diffable = diffables.get(relation);
169         if(diffable == null) {
170             if(relation instanceof CompositeRelation && 
171                     containsReferenceTo((CompositeRelation)relation,
172                             (THashMap<SCLRelation, Diffable>)(THashMap)diffables))
173                 throw new DerivateException(location);
174             return NO_DIFF;
175         }
176         else {
177             Query[] eqs = new Query[parameters.length];
178             for(int i=0;i<parameters.length;++i) {
179                 QAtom eq = new QAtom(EqRelation.INSTANCE, new Expression[] {
180                         new EVariable(diffable.parameters[i]),
181                         parameters[i]
182                 });
183                 eq.setLocationDeep(location);
184                 eq.typeParameters = new Type[] {parameters[i].getType()};
185                 eqs[i] = eq;
186             }
187             return new Diff[] { new Diff(diffable.id, new QConjunction(eqs)) };
188         }
189     }
190
191     private static boolean containsReferenceTo(
192             CompositeRelation relation,
193             THashMap<SCLRelation, Diffable> diffables) {
194         for(SCLRelation r : relation.getSubrelations())
195             if(diffables.containsKey(r))
196                 return true;
197             else if(r instanceof CompositeRelation &&
198                     containsReferenceTo((CompositeRelation)r, diffables))
199                 return true;
200         return false;
201     }
202
203     @Override
204     public Query removeRelations(Set<SCLRelation> relations) {
205         if(relations.contains(relation))
206             return EMPTY_QUERY;
207         else
208             return this;
209     }
210
211     @Override
212     public void setLocationDeep(long loc) {
213         if(location == Locations.NO_LOCATION) {
214             location = loc;
215             for(Expression parameter : parameters)
216                 parameter.setLocationDeep(loc);
217         }
218     }
219
220     @Override
221     public void accept(QueryVisitor visitor) {
222         visitor.visit(this);
223     }
224
225     @Override
226     public void splitToPhases(TIntObjectHashMap<ArrayList<Query>> result) {
227         int phase = relation.getPhase();
228         ArrayList<Query> list = result.get(phase);
229         if(list == null) {
230             list = new ArrayList<Query>();
231             result.put(phase, list);
232         }
233         list.add(this);
234     }
235
236     @Override
237     public Query accept(QueryTransformer transformer) {
238         return transformer.transform(this);
239     }
240
241 }