package org.simantics.scl.compiler.elaboration.query.compilation; import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just; import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Nothing; import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply; import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda; import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var; import org.simantics.scl.compiler.common.exceptions.InternalCompilerError; import org.simantics.scl.compiler.common.names.Name; import org.simantics.scl.compiler.common.names.Names; import org.simantics.scl.compiler.compilation.CompilationContext; import org.simantics.scl.compiler.constants.BooleanConstant; import org.simantics.scl.compiler.elaboration.contexts.TypingContext; import org.simantics.scl.compiler.elaboration.expressions.Case; import org.simantics.scl.compiler.elaboration.expressions.EApply; import org.simantics.scl.compiler.elaboration.expressions.EConstant; import org.simantics.scl.compiler.elaboration.expressions.EIf; import org.simantics.scl.compiler.elaboration.expressions.ELiteral; import org.simantics.scl.compiler.elaboration.expressions.EMatch; import org.simantics.scl.compiler.elaboration.expressions.ESimpleLambda; import org.simantics.scl.compiler.elaboration.expressions.ESimpleLet; import org.simantics.scl.compiler.elaboration.expressions.EVariable; import org.simantics.scl.compiler.elaboration.expressions.Expression; import org.simantics.scl.compiler.elaboration.expressions.Variable; import org.simantics.scl.compiler.elaboration.java.Builtins; import org.simantics.scl.compiler.errors.Locations; import org.simantics.scl.compiler.types.TPred; import org.simantics.scl.compiler.types.Type; import org.simantics.scl.compiler.types.Types; import org.simantics.scl.compiler.types.exceptions.MatchException; public class QueryCompilationContext { TypingContext context; QueryCompilationMode mode; Type resultType; Expression continuation; double branching = 1.0; double cost = 0.0; public QueryCompilationContext( TypingContext context, QueryCompilationMode mode, Type resultType, Expression continuation) { this.context = context; this.mode = mode; this.resultType = resultType; this.continuation = continuation; } public Expression failure() { switch(mode) { case ITERATE: return new EConstant(Builtins.TUPLE_CONSTRUCTORS[0]); case GET_FIRST: return new EConstant(Builtins.Nothing, resultType); case GET_ALL: return new EConstant(Builtins.LIST_CONSTRUCTORS[0], resultType); case CHECK: return new ELiteral(new BooleanConstant(false)); default: throw new InternalCompilerError(); } } public Expression disjunction(Expression a, Expression b) { switch(mode) { case ITERATE: return new ESimpleLet(new Variable("_", Types.UNIT), a, b); case GET_FIRST: { Variable var = new Variable("temp", a.getType()); return new EMatch(a, new Case(new EConstant(Builtins.Nothing), b), new Case(new EVariable(var), new EVariable(var))); } case GET_ALL: { try { return new EApply(context.getCompilationContext().getConstant(Names.Prelude_appendList, Types.matchApply(Types.LIST, a.getType())), a, b); } catch (MatchException e) { throw new InternalCompilerError(); } } case CHECK: return new EIf(a, new ELiteral(new BooleanConstant(true)), b); default: throw new InternalCompilerError(); } } public Expression condition(Expression condition, Expression continuation) { return new EIf(condition, continuation, failure()); } public void condition(Expression condition) { continuation = condition(condition, continuation); } public void equalityCondition(long location, Expression a, Expression b) { Type type = a.getType(); condition(new EApply( location, Types.PROC, context.getCompilationContext().getConstant(Names.Builtin_equals, type), new Expression[] { a, b } )); } public void let(Variable variable, Expression value) { continuation = new ESimpleLet(variable, value, continuation); } public void iterateMaybe(Variable variable, Expression value) { continuation = new EMatch(value, new Case(Nothing(variable.getType()), failure()), new Case(Just(var(variable)), continuation)); } public void match(Expression pattern, Expression value, boolean mayFail) { if(mayFail) continuation = new EMatch(value, new Case(pattern, continuation), new Case(new EVariable(new Variable("_", pattern.getType())), failure())); else continuation = new EMatch(value, new Case(pattern, continuation)); } public void iterateList(Variable variable, Expression list) { try { switch(mode) { case ITERATE: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Prelude_iterList, variable.getType(), Types.PROC, Types.tupleConstructor(0)), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), list } ); break; case CHECK: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Prelude_any, variable.getType(), Types.PROC), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), list } ); break; case GET_ALL: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Prelude_concatMap, variable.getType(), Types.PROC, Types.matchApply(Types.LIST, continuation.getType())), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), list } ); break; case GET_FIRST: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Prelude_mapFirst, variable.getType(), Types.PROC, Types.matchApply(Types.MAYBE, continuation.getType())), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), list } ); break; default: throw new InternalCompilerError("iterateList could not handle mode " + mode); } } catch(MatchException e) { throw new InternalCompilerError(e); } } public void iterateVector(Variable variable, Expression vector) { try { switch(mode) { case ITERATE: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Vector_iterVector, variable.getType(), Types.PROC, continuation.getType()), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), vector } ); break; case CHECK: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Vector_anyVector, variable.getType(), Types.PROC), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), vector } ); break; case GET_ALL: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Vector_concatMapVector, variable.getType(), Types.PROC, Types.matchApply(Types.LIST, continuation.getType())), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), vector } ); break; case GET_FIRST: continuation = new EApply( Locations.NO_LOCATION, Types.PROC, context.getCompilationContext().getConstant(Names.Vector_mapFirstVector, variable.getType(), Types.PROC, Types.matchApply(Types.MAYBE, continuation.getType())), new Expression[] { new ESimpleLambda(Types.PROC, variable, continuation), vector } ); break; default: throw new InternalCompilerError("iterateVector could not handle mode " + mode); } } catch(MatchException e) { throw new InternalCompilerError(e); } } public void iterateMSet(Variable variable, Expression set) { try { switch(mode) { case ITERATE: continuation = apply(context.getCompilationContext(), Types.PROC, Names.MSet_iter, variable.getType(), Types.PROC, continuation.getType(), lambda(Types.PROC, variable, continuation), set ); break; case GET_FIRST: continuation = apply(context.getCompilationContext(), Types.PROC, Names.MSet_mapFirst, variable.getType(), Types.PROC, Types.matchApply(Types.MAYBE, continuation.getType()), lambda(Types.PROC, variable, continuation), set ); break; default: throw new InternalCompilerError("iterateMSet could not handle mode " + mode); } } catch(MatchException e) { throw new InternalCompilerError(e); } } public void updateCost(double localBranching, double localCost) { branching *= localBranching; cost *= localBranching; cost += localCost; } public Expression getConstant(Name name, Type[] typeParameters) { return context.getCompilationContext().getConstant(name, typeParameters); } public QueryCompilationContext createCheckContext() { return new QueryCompilationContext(context, QueryCompilationMode.CHECK, null, new ELiteral(new BooleanConstant(true))); } public double getBranching() { return branching; } public double getCost() { return cost; } public QueryCompilationContext createSubcontext(Expression innerExpression) { return new QueryCompilationContext(context, mode, resultType, innerExpression); } public void setContinuation(Expression continuation) { this.continuation = continuation; } public Expression getContinuation() { return continuation; } public Expression disjunction(Expression[] disjuncts) { Expression result = failure(); for(int i=disjuncts.length-1;i>=0;--i) result = disjunction(disjuncts[i], result); return result; } public TypingContext getTypingContext() { return context; } public EVariable getEvidence(long location, TPred pred) { EVariable evidence = new EVariable(location, null); evidence.setType(pred); context.addConstraintDemand(evidence); return evidence; } public CompilationContext getCompilationContext() { return context.getCompilationContext(); } }