--- /dev/null
+package org.simantics.scl.compiler.tests;
+
+import java.util.Arrays;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.simantics.scl.compiler.elaboration.expressions.EVariable;
+import org.simantics.scl.compiler.elaboration.expressions.Expression;
+import org.simantics.scl.compiler.elaboration.expressions.Variable;
+import org.simantics.scl.compiler.environment.AbstractLocalEnvironment;
+import org.simantics.scl.compiler.environment.Environment;
+import org.simantics.scl.compiler.environment.LocalEnvironment;
+import org.simantics.scl.compiler.environment.specification.EnvironmentSpecification;
+import org.simantics.scl.compiler.errors.CompilationErrorFormatter;
+import org.simantics.scl.compiler.module.repository.ImportFailure;
+import org.simantics.scl.compiler.module.repository.ImportFailureException;
+import org.simantics.scl.compiler.module.repository.ModuleRepository;
+import org.simantics.scl.compiler.runtime.RuntimeEnvironment;
+import org.simantics.scl.compiler.source.repository.CompositeModuleSourceRepository;
+import org.simantics.scl.compiler.source.repository.SourceRepositories;
+import org.simantics.scl.compiler.top.ExpressionEvaluator;
+import org.simantics.scl.compiler.top.SCLExpressionCompilationException;
+import org.simantics.scl.compiler.types.Type;
+import org.simantics.scl.compiler.types.Types;
+import org.simantics.scl.runtime.function.Function;
+import org.simantics.scl.runtime.tuple.Tuple0;
+
+import junit.framework.Assert;
+
+public class TestExpressionEvaluator {
+
+ public static final boolean TIMING = false;
+ public static final int COUNT = 10000;
+
+ ModuleRepository moduleRepository;
+
+ RuntimeEnvironment runtimeEnvironment;
+
+ @Before
+ public void initialize() throws Exception {
+ moduleRepository = InitialRepository.getInitialRepository();
+
+ // Environment for compiling expressions
+ EnvironmentSpecification environmentSpecification = new EnvironmentSpecification();
+ environmentSpecification.importModule("Builtin", "");
+ environmentSpecification.importModule("Prelude", "");
+
+ try {
+ runtimeEnvironment = moduleRepository.createRuntimeEnvironment(environmentSpecification,
+ getClass().getClassLoader());
+ } catch(ImportFailureException e) {
+ for(ImportFailure failure : e.failures)
+ System.err.println("Failed to import " + failure.moduleName);
+ throw e;
+ }
+ }
+
+ private void testExpression0(String expressionText,
+ Object expectedValue,
+ Type expectedType) throws Exception {
+ // Compiling and running expression
+ try {
+ Object result = new ExpressionEvaluator(runtimeEnvironment, expressionText)
+ .expectedType(expectedType)
+ .eval();
+ if(expectedValue != null)
+ Assert.assertEquals(expectedValue, result);
+ } catch(SCLExpressionCompilationException e) {
+ System.out.println(CompilationErrorFormatter.toString(expressionText, e.getErrors()));
+ throw e;
+ }
+ }
+
+ private void testExpression(String expressionText,
+ Object expectedValue,
+ Type expectedType) throws Exception {
+ if(TIMING) {
+ System.out.println(expressionText);
+ long beginTime = System.nanoTime();
+ for(int i=0;i<COUNT;++i)
+ testExpression0(expressionText, expectedValue, expectedType);
+ long endTime = System.nanoTime();
+ System.out.println( " " + (endTime-beginTime)*1e-6/COUNT + " ms");
+ }
+ else
+ testExpression0(expressionText, expectedValue, expectedType);
+ }
+
+ @Test
+ public void testExpressionCompiler() throws Exception {
+ testExpression("1",
+ Integer.valueOf(1),
+ Types.INTEGER);
+ testExpression("1+2",
+ Integer.valueOf(3),
+ Types.INTEGER);
+ testExpression("map (\\(_,x) -> x) [(1,2),(2,3)]",
+ Arrays.asList(2.0, 3.0),
+ Types.list(Types.DOUBLE));
+ testExpression("map (\\x -> snd x) [(1,2),(2,3)]",
+ Arrays.asList(2.0, 3.0),
+ Types.list(Types.DOUBLE));
+ testExpression("let f x = x+1 in (f . f . f) 3",
+ Double.valueOf(6.0),
+ Types.DOUBLE);
+ if(!TIMING)
+ testExpression("print \"Hello world!\"",
+ Tuple0.INSTANCE,
+ Types.UNIT);
+ testExpression("[1,2+3,4+5]",
+ Arrays.asList(1,5,9),
+ Types.list(Types.INTEGER));
+ testExpression("let a = 5.3 in let f x = x+a in f 3",
+ Double.valueOf(8.3),
+ Types.DOUBLE);
+ testExpression("let mm x y = if x < y then x else y in mm 2 (mm 1 3)",
+ Double.valueOf(1.0),
+ Types.DOUBLE);
+ }
+
+ @Test
+ public void testLocalEnvironment() throws Exception {
+ String expressionText = "a + b";
+ LocalEnvironment localEnvironment = new AbstractLocalEnvironment() {
+ Variable[] localParameters = new Variable[] {
+ new Variable("a", Types.DOUBLE),
+ new Variable("b", Types.DOUBLE),
+ };
+
+ @Override
+ public Expression resolve(Environment environment, String localName) {
+ if(localName.equals("a"))
+ return new EVariable(localParameters[0]);
+ else if(localName.equals("b"))
+ return new EVariable(localParameters[1]);
+ else
+ return null;
+ }
+
+ @Override
+ protected Variable[] getContextVariables() {
+ return localParameters;
+ }
+ };
+ try {
+ Object result = new ExpressionEvaluator(runtimeEnvironment, expressionText)
+ .localEnvironment(localEnvironment)
+ .expectedType(Types.DOUBLE)
+ .eval();
+ Assert.assertEquals(
+ Double.valueOf(15.0),
+ ((Function)result).apply(7.0, 8.0));
+ } catch(SCLExpressionCompilationException e) {
+ System.out.println(CompilationErrorFormatter.toString(expressionText, e.getErrors()));
+ throw e;
+ }
+ }
+
+ @Test
+ public void testArities() throws Exception {
+ for(int arity=1;arity<50;++arity) {
+ // Build expressions
+ StringBuilder b = new StringBuilder();
+ b.append('\\');
+ for(int i=0;i<arity;++i)
+ b.append("v" + i + " ");
+ b.append("-> ");
+ for(int i=0;i<arity;++i) {
+ if(i > 0)
+ b.append(" + ");
+ b.append("v" + i);
+ }
+ //System.out.println(b.toString());
+
+ // Compile
+ Type expectedType = Types.INTEGER;
+ for(int i=0;i<arity;++i)
+ expectedType = Types.function(Types.INTEGER, expectedType);
+
+ Function function = (Function)new ExpressionEvaluator(runtimeEnvironment, b.toString())
+ .expectedType(expectedType)
+ .interpretIfPossible(false)
+ .eval();
+
+ // Evaluate
+ Object[] parameters = new Object[arity];
+ int sum = 0;
+ for(int i=0;i<arity;++i) {
+ int value = i+1;
+ parameters[i] = value;
+ sum += value;
+ }
+ Object result = function.applyArray(parameters);
+ Assert.assertEquals(sum, result);
+ }
+ }
+}