package org.simantics.scl.compiler.elaboration.expressions; import gnu.trove.map.hash.TObjectIntHashMap; import gnu.trove.set.hash.THashSet; import gnu.trove.set.hash.TIntHashSet; import org.simantics.scl.compiler.common.exceptions.InternalCompilerError; import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext; import org.simantics.scl.compiler.elaboration.contexts.TranslationContext; import org.simantics.scl.compiler.elaboration.contexts.TypingContext; import org.simantics.scl.compiler.elaboration.modules.SCLValue; import org.simantics.scl.compiler.environment.Environment; import org.simantics.scl.compiler.errors.Locations; import org.simantics.scl.compiler.internal.codegen.references.IVal; import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter; import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator; import org.simantics.scl.compiler.types.Type; import org.simantics.scl.compiler.types.Types; import org.simantics.scl.compiler.types.exceptions.MatchException; import org.simantics.scl.compiler.types.exceptions.UnificationException; import org.simantics.scl.compiler.types.kinds.Kinds; public class EBind extends SimplifiableExpression { public Expression pattern; public Expression value; public Expression in; private EVariable monadEvidence; SCLValue bindFunction; Type monadType; Type valueContentType; Type inContentType; public EBind(long loc, Expression pattern, Expression value, Expression in) { super(loc); this.pattern = pattern; this.value = value; this.in = in; } public EBind(long loc, Expression pattern, Expression value, Expression in, SCLValue bindFunction) { super(loc); this.pattern = pattern; this.value = value; this.in = in; } @Override public void collectRefs(final TObjectIntHashMap allRefs, final TIntHashSet refs) { value.collectRefs(allRefs, refs); in.collectRefs(allRefs, refs); } @Override public void collectVars(TObjectIntHashMap allVars, TIntHashSet vars) { value.collectVars(allVars, vars); in.collectVars(allVars, vars); } @Override protected void updateType() throws MatchException { setType(in.getType()); } @Override public Expression checkBasicType(TypingContext context, Type requiredType) { monadType = Types.metaVar(Kinds.STAR_TO_STAR); inContentType = Types.metaVar(Kinds.STAR); Type monadContent = Types.apply(monadType, inContentType); try { Types.unify(requiredType, monadContent); } catch (UnificationException e) { context.typeError(location, requiredType, monadContent); return this; } Variable variable = new Variable("monadEvidence"); variable.setType(Types.pred(Types.MONAD, monadType)); monadEvidence = new EVariable(getLocation(), variable); monadEvidence.setType(variable.getType()); context.addConstraintDemand(monadEvidence); pattern = pattern.checkTypeAsPattern(context, Types.metaVar(Kinds.STAR)); valueContentType = pattern.getType(); value = value.checkType(context, Types.apply(monadType, valueContentType)); in = in.checkType(context, requiredType); Type inType = in.getType(); setType(inType); return this; } @Override public IVal toVal(Environment env, CodeWriter w) { throw new InternalCompilerError("EBind should be eliminated."); } /** * Splits let */ @Override public Expression simplify(SimplificationContext context) { value = value.simplify(context); in = in.simplify(context); pattern = pattern.simplify(context); long loc = getLocation(); Expression simplified = new EApply(loc, new EConstant(loc, bindFunction, Types.canonical(monadType), Types.canonical(valueContentType), Types.canonical(inContentType)), monadEvidence, value, new ELambda(loc, new Case[] { new Case(new Expression[] { pattern }, in) })); simplified.setType(getType()); return simplified.simplify(context); } @Override public void collectFreeVariables(THashSet vars) { in.collectFreeVariables(vars); value.collectFreeVariables(vars); pattern.removeFreeVariables(vars); } @Override public Expression resolve(TranslationContext context) { value = value.resolve(context); context.pushFrame(); pattern = pattern.resolveAsPattern(context); in = in.resolve(context); context.popFrame(); bindFunction = context.getBindFunction(); return this; } @Override public Expression decorate(ExpressionDecorator decorator) { pattern = pattern.decorate(decorator); value = value.decorate(decorator); in = in.decorate(decorator); return decorator.decorate(this); } @Override public void collectEffects(THashSet effects) { pattern.collectEffects(effects); value.collectEffects(effects); in.collectEffects(effects); } @Override public void setLocationDeep(long loc) { if(location == Locations.NO_LOCATION) { location = loc; pattern.setLocationDeep(loc); value.setLocationDeep(loc); in.setLocationDeep(loc); } } @Override public void accept(ExpressionVisitor visitor) { visitor.visit(this); } @Override public void forVariables(VariableProcedure procedure) { pattern.forVariables(procedure); value.forVariables(procedure); if(monadEvidence != null) monadEvidence.forVariables(procedure); } @Override public Expression accept(ExpressionTransformer transformer) { return transformer.transform(this); } }