(refs #7375) Replaced collectFreeVariables method by a visitor
[simantics/platform.git] / bundles / org.simantics.scl.compiler / src / org / simantics / scl / compiler / elaboration / expressions / ELet.java
1 package org.simantics.scl.compiler.elaboration.expressions;
2
3 import java.util.ArrayList;
4
5 import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
6 import org.simantics.scl.compiler.compilation.CompilationContext;
7 import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
8 import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
9 import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
10 import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
11 import org.simantics.scl.compiler.errors.Locations;
12 import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
13 import org.simantics.scl.compiler.internal.codegen.references.IVal;
14 import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
15 import org.simantics.scl.compiler.internal.codegen.writer.RecursiveDefinitionWriter;
16 import org.simantics.scl.compiler.internal.elaboration.decomposed.DecomposedExpression;
17 import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
18 import org.simantics.scl.compiler.types.Type;
19 import org.simantics.scl.compiler.types.Types;
20 import org.simantics.scl.compiler.types.exceptions.MatchException;
21 import org.simantics.scl.compiler.types.kinds.Kinds;
22
23 import gnu.trove.map.hash.TObjectIntHashMap;
24 import gnu.trove.set.hash.TIntHashSet;
25
26 /**
27  * Generated maily from EPreLet
28  */
29 public class ELet extends Expression {
30     public Assignment[] assignments;
31     public Expression in;
32     
33     public ELet(long loc, Assignment[] assignments, Expression in) {
34         super(loc);
35         this.assignments = assignments;
36         this.in = in;
37     }
38     
39     @Override
40     public void collectVars(TObjectIntHashMap<Variable> allVars,
41             TIntHashSet vars) {
42         for(Assignment assign : assignments)
43             assign.value.collectVars(allVars, vars);
44         in.collectVars(allVars, vars);
45     }
46     
47     @Override
48     protected void updateType() throws MatchException {
49         setType(in.getType());
50     }
51    
52     /**
53      * Splits let 
54      */
55     @Override
56     public Expression simplify(SimplificationContext context) {
57         
58         // Simplify assignments
59         for(Assignment assignment : assignments) {
60             assignment.value = assignment.value.simplify(context);
61         }
62         
63         // Find strongly connected components
64         final TObjectIntHashMap<Variable> allVars = new TObjectIntHashMap<Variable>(
65                 2*assignments.length, 0.5f, -1);
66
67         for(int i=0;i<assignments.length;++i)
68             for(Variable var : assignments[i].pattern.getFreeVariables())
69                 allVars.put(var, i);
70         final boolean isRecursive[] = new boolean[assignments.length];
71         final ArrayList<int[]> components = new ArrayList<int[]>(Math.max(10, assignments.length)); 
72         new StronglyConnectedComponents(assignments.length) {
73             @Override
74             protected int[] findDependencies(int u) {
75                 TIntHashSet vars = new TIntHashSet();
76                 assignments[u].value.collectVars(allVars, vars);
77                 if(vars.contains(u))
78                     isRecursive[u] = true;
79                 return vars.toArray();
80             }
81
82             @Override
83             protected void reportComponent(int[] component) {
84                 components.add(component);
85             }
86
87         }.findComponents();
88
89         // Simplify in
90         Expression result = in.simplify(context);
91         
92         // Handle each component
93         for(int j=components.size()-1;j>=0;--j) {
94             int[] component = components.get(j);
95             boolean recursive = component.length > 1 || isRecursive[component[0]];
96             if(recursive) {
97                 Assignment[] cAssignments = new Assignment[component.length];
98                 for(int i=0;i<component.length;++i)
99                     cAssignments[i] = assignments[component[i]];
100                 result = new ELet(location, cAssignments, result);
101             }
102             else {
103                 Assignment assignment = assignments[component[0]];
104                 Expression pattern = assignment.pattern;
105                 
106                 if(pattern instanceof EVariable) {
107                     EVariable pvar = (EVariable)pattern;
108                     result = new ESimpleLet(location, pvar.variable, assignment.value, result);
109                 }
110                 else {
111                     result = new EMatch(location, new Expression[] {assignment.value},
112                                     new Case(new Expression[] {pattern}, result));
113                 }
114             }
115         }
116         
117         return result;
118     }
119
120     @Override
121     public Expression resolve(TranslationContext context) {
122         throw new InternalCompilerError("ELet should be already resolved.");
123     }
124     
125     @Override
126     public Expression replace(ReplaceContext context) {
127         Assignment[] newAssignments = new Assignment[assignments.length];
128         for(int i=0;i<assignments.length;++i)
129             newAssignments[i] = assignments[i].replace(context);            
130         Expression newIn = in.replace(context);
131         return new ELet(getLocation(), newAssignments, newIn);
132     }
133     
134     @Override
135     public IVal toVal(CompilationContext context, CodeWriter w) {
136         // Create bound variables
137         BoundVar[] vars = new BoundVar[assignments.length];
138         for(int i=0;i<assignments.length;++i) {
139             Expression pattern = assignments[i].pattern;
140             if(!(pattern instanceof EVariable))
141                 throw new InternalCompilerError("Cannot handle pattern targets in recursive assignments.");
142             vars[i] = new BoundVar(pattern.getType());
143             ((EVariable)pattern).getVariable().setVal(vars[i]);
144         }
145         
146         // Create values
147         RecursiveDefinitionWriter rdw = w.createRecursiveDefinition();
148         long range = Locations.NO_LOCATION;
149         for(Assignment assign2 : assignments) {
150             range = Locations.combine(range, assign2.pattern.location);
151             range = Locations.combine(range, assign2.value.location);
152         }
153         rdw.setLocation(range);
154         for(int i=0;i<assignments.length;++i) {
155             DecomposedExpression decomposed = 
156                     DecomposedExpression.decompose(context.errorLog, assignments[i].value);
157             CodeWriter newW = rdw.createFunction(vars[i], 
158                     decomposed.typeParameters,
159                     decomposed.effect,
160                     decomposed.returnType, 
161                     decomposed.parameterTypes);
162             IVal[] parameters = newW.getParameters();
163             for(int j=0;j<parameters.length;++j)
164                 decomposed.parameters[j].setVal(parameters[j]);
165             newW.return_(decomposed.body.toVal(context, newW));
166         }
167         return in.toVal(context, w);
168     }
169         
170     private void checkAssignments(TypingContext context) {
171         for(Assignment assign : assignments)
172             assign.pattern = assign.pattern.checkTypeAsPattern(context, Types.metaVar(Kinds.STAR));
173         for(Assignment assign : assignments)
174             assign.value = assign.value.checkType(context, assign.pattern.getType());
175     }
176     
177     @Override
178     public Expression inferType(TypingContext context) {
179         checkAssignments(context);
180         in = in.inferType(context);
181         return this;
182     }
183     
184     @Override
185     public Expression checkBasicType(TypingContext context, Type requiredType) {
186         checkAssignments(context);
187         in = in.checkType(context, requiredType);
188         return this;
189     }
190     
191     @Override
192     public Expression checkIgnoredType(TypingContext context) {
193         checkAssignments(context);
194         in = in.checkIgnoredType(context);
195         return this;
196     }
197     
198     @Override
199     public void setLocationDeep(long loc) {
200         if(location == Locations.NO_LOCATION) {
201             location = loc;
202             for(Assignment assignment : assignments)
203                 assignment.setLocationDeep(loc);
204             in.setLocationDeep(loc);
205         }
206     }
207     
208     @Override
209     public void accept(ExpressionVisitor visitor) {
210         visitor.visit(this);
211     }
212     
213     @Override
214     public Expression accept(ExpressionTransformer transformer) {
215         return transformer.transform(this);
216     }
217     
218     @Override
219     public int getSyntacticFunctionArity() {
220         return in.getSyntacticFunctionArity();
221     }
222
223 }