--- /dev/null
+package org.simantics.scl.compiler.types;
+
+import org.simantics.scl.compiler.environment.Environment;
+import org.simantics.scl.compiler.internal.types.HashCodeUtils;
+import org.simantics.scl.compiler.types.exceptions.KindUnificationException;
+import org.simantics.scl.compiler.types.exceptions.UnificationException;
+import org.simantics.scl.compiler.types.kinds.Kinds;
+
+import gnu.trove.map.hash.THashMap;
+
+public class Skeletons {
+
+ public static Type canonicalSkeleton(Type type) {
+ while(type instanceof TMetaVar) {
+ TMetaVar metaVar = (TMetaVar)type;
+ if(metaVar.ref != null)
+ type = metaVar.ref;
+ else if(metaVar.skeletonRef != null)
+ type = metaVar.skeletonRef;
+ else
+ return metaVar;
+ }
+ return type;
+ }
+
+ public static Type canonicalSkeleton(THashMap<TMetaVar,Type> unifications, Type type) {
+ while(type instanceof TMetaVar) {
+ TMetaVar metaVar = (TMetaVar)type;
+ if(metaVar.ref != null)
+ type = metaVar.ref;
+ else if(metaVar.skeletonRef != null)
+ type = metaVar.skeletonRef;
+ else {
+ Type temp = unifications.get(metaVar);
+ if(temp == null)
+ return metaVar;
+ else
+ type = temp;
+ }
+ }
+ return type;
+ }
+
+ public static boolean doesSkeletonContain(THashMap<TMetaVar,Type> unifications, Type type, TMetaVar metaVar) {
+ type = canonicalSkeleton(unifications, type);
+ if(type == metaVar)
+ return true;
+ if(type instanceof TFun) {
+ TFun fun = (TFun)type;
+ return doesSkeletonContain(unifications, fun.domain, metaVar)
+ || doesSkeletonContain(unifications, fun.range, metaVar);
+ }
+ if(type instanceof TApply) {
+ TApply apply = (TApply)type;
+ return doesSkeletonContain(unifications, apply.function, metaVar)
+ || doesSkeletonContain(unifications, apply.parameter, metaVar);
+ }
+ if(type instanceof TForAll) {
+ TForAll forAll = (TForAll)type;
+ return doesSkeletonContain(unifications, forAll.type, metaVar);
+ }
+ if(type instanceof TPred) {
+ TPred pred = (TPred)type;
+ for(Type param : pred.parameters)
+ if(doesSkeletonContain(unifications, param, metaVar))
+ return true;
+ return false;
+ }
+ else
+ return false;
+ }
+
+ /**
+ * Returns true, if unification of the skeletons of the types would succeed.
+ */
+ public static boolean areSkeletonsCompatible(THashMap<TMetaVar,Type> unifications, Type a, Type b) {
+ a = canonicalSkeleton(unifications, a);
+ b = canonicalSkeleton(unifications, b);
+ if(a == b)
+ return true;
+ Class<?> ca = a.getClass();
+ Class<?> cb = b.getClass();
+
+ if(ca == TMetaVar.class) {
+ TMetaVar ma = (TMetaVar)a;
+ if(doesSkeletonContain(unifications, b, ma))
+ return false;
+ unifications.put(ma, b);
+ return true;
+ }
+ if(cb == TMetaVar.class) {
+ TMetaVar mb = (TMetaVar)b;
+ if(doesSkeletonContain(unifications, a, mb))
+ return false;
+ unifications.put(mb, a);
+ return true;
+ }
+ if(ca != cb)
+ return false;
+ if(ca == TFun.class) {
+ TFun funA = (TFun)a;
+ TFun funB = (TFun)b;
+ return areSkeletonsCompatible(unifications, funA.domain, funB.domain)
+ && areSkeletonsCompatible(unifications, funA.range, funB.range);
+ }
+ if(ca == TApply.class) {
+ TApply applyA = (TApply)a;
+ TApply applyB = (TApply)b;
+ return areSkeletonsCompatible(unifications, applyA.function, applyB.function)
+ && areSkeletonsCompatible(unifications, applyA.parameter, applyB.parameter);
+ }
+ if(ca == TPred.class) {
+ TPred predA = (TPred)a;
+ TPred predB = (TPred)b;
+ if(predA.typeClass != predB.typeClass)
+ return false;
+ for(int i=0;i<predA.parameters.length;++i)
+ if(!areSkeletonsCompatible(unifications, predA.parameters[i], predB.parameters[i]))
+ return false;
+ return true;
+ }
+ if(ca == TForAll.class) {
+ TForAll forAllA = (TForAll)a;
+ TForAll forAllB = (TForAll)b;
+ TVar temp = Types.var(forAllA.var.getKind());
+ return areSkeletonsCompatible(unifications,
+ forAllA.type.replace(forAllA.var, temp),
+ forAllB.type.replace(forAllB.var, temp));
+ }
+ return false;
+ }
+
+ public static void unifySkeletons(Type a, Type b) throws UnificationException {
+ a = canonicalSkeleton(a);
+ b = canonicalSkeleton(b);
+
+ if(a == b)
+ return;
+ if(a instanceof TMetaVar) {
+ ((TMetaVar) a).setSkeletonRef(b);
+ return;
+ }
+ if(b instanceof TMetaVar) {
+ ((TMetaVar) b).setSkeletonRef(a);
+ return;
+ }
+
+ Class<?> ca = a.getClass();
+ Class<?> cb = b.getClass();
+ if(ca != cb) {
+ throw new UnificationException(a, b);
+ }
+ if(ca == TApply.class)
+ //unifySkeletons((TApply)a, (TApply)b);
+ Types.unify(a, b);
+ else if(ca == TFun.class)
+ unifySkeletons((TFun)a, (TFun)b);
+ else if(ca == TForAll.class)
+ unifySkeletons((TForAll)a, (TForAll)b);
+ else if(ca == TPred.class)
+ //unifySkeletons((TPred)a, (TPred)b);
+ Types.unify(a, b);
+ else if(ca == TUnion.class)
+ unifySkeletons((TUnion)a, (TUnion)b);
+ else // ca == TCon.class || ca = TVar.class
+ throw new UnificationException(a, b);
+ }
+
+ public static void unifySkeletons(TFun a, TFun b) throws UnificationException {
+ unifySkeletons(a.domain, b.domain);
+ unifySkeletons(a.range, b.range);
+ }
+
+ public static void unifySkeletons(TApply a, TApply b) throws UnificationException {
+ unifySkeletons(a.function, b.function);
+ unifySkeletons(a.parameter, b.parameter);
+ }
+
+ public static void unifySkeletons(TForAll a, TForAll b) throws UnificationException {
+ try {
+ Kinds.unify(a.var.getKind(), b.var.getKind());
+ } catch (KindUnificationException e) {
+ throw new UnificationException(a, b);
+ }
+ TVar newVar = Types.var(a.var.getKind());
+ unifySkeletons(a.type.replace(a.var, newVar), b.type.replace(b.var, newVar));
+ }
+
+ public static void unifySkeletons(TPred a, TPred b) throws UnificationException {
+ if(a.typeClass != b.typeClass
+ || a.parameters.length != b.parameters.length)
+ throw new UnificationException(a, b);
+ for(int i=0;i<a.parameters.length;++i)
+ unifySkeletons(a.parameters[i], b.parameters[i]);
+ }
+
+ public static void unifySkeletons(TUnion a, TUnion b) throws UnificationException {
+ // Nothing to do
+ }
+
+ public static Type commonSkeleton(Environment context, Type[] types) {
+ THashMap<Type[], TMetaVar> metaVarMap = new THashMap<Type[], TMetaVar>() {
+ @Override
+ protected boolean equals(Object a, Object b) {
+ return Types.equals((Type[])a, (Type[])b);
+ }
+ @Override
+ protected int hash(Object a) {
+ Type[] types = (Type[])a;
+ int hash = HashCodeUtils.SEED;
+ for(Type type : types)
+ hash = type.hashCode(hash);
+ return hash;
+ }
+ };
+ return commonSkeleton(context, metaVarMap, types);
+ }
+
+ private static TMetaVar metaVarFor(Environment context, THashMap<Type[], TMetaVar> metaVarMap, Type[] types) {
+ TMetaVar result = metaVarMap.get(types);
+ if(result == null) {
+ try {
+ result = Types.metaVar(types[0].inferKind(context));
+ } catch (KindUnificationException e) {
+ result = Types.metaVar(Kinds.STAR);
+ }
+ metaVarMap.put(types, result);
+ }
+ return result;
+ }
+
+ /**
+ * Finds the most specific type that can be unified with the all the types
+ * given as a parameter.
+ */
+ private static Type commonSkeleton(Environment context, THashMap<Type[], TMetaVar> metaVarMap, Type[] types) {
+ for(int i=0;i<types.length;++i)
+ types[i] = canonicalSkeleton(types[i]);
+
+ Type first = types[0];
+ Class<?> clazz = first.getClass();
+ for(int i=1;i<types.length;++i)
+ if(types[i].getClass() != clazz)
+ return metaVarFor(context, metaVarMap, types);
+
+ if(clazz == TCon.class) {
+ for(int i=1;i<types.length;++i)
+ if(types[i] != first)
+ return metaVarFor(context, metaVarMap, types);
+ return first;
+ }
+ else if(clazz == TApply.class) {
+ Type[] functions = new Type[types.length];
+ Type[] parameters = new Type[types.length];
+ for(int i=0;i<types.length;++i) {
+ TApply apply = (TApply)types[i];
+ functions[i] = apply.function;
+ parameters[i] = apply.parameter;
+ }
+ return Types.apply(
+ commonSkeleton(context, metaVarMap, functions),
+ commonSkeleton(context, metaVarMap, parameters));
+ }
+ else if(clazz == TFun.class) {
+ Type[] domains = new Type[types.length];
+ Type[] effects = new Type[types.length];
+ Type[] ranges = new Type[types.length];
+ for(int i=0;i<types.length;++i) {
+ TFun fun = (TFun)types[i];
+ if(fun.domain instanceof TPred)
+ return metaVarFor(context, metaVarMap, types);
+ domains[i] = fun.domain;
+ effects[i] = fun.effect;
+ ranges[i] = fun.range;
+ }
+ return Types.functionE(
+ commonSkeleton(context, metaVarMap, domains),
+ commonEffect(effects),
+ commonSkeleton(context, metaVarMap, ranges));
+ }
+ else
+ return metaVarFor(context, metaVarMap, types);
+ }
+
+ private static Type commonEffect(Type[] effects) {
+ Type first = effects[0];
+ for(int i=1;i<effects.length;++i)
+ if(!Types.equals(first, effects[i]))
+ return Types.metaVar(Kinds.EFFECT);
+ return first;
+ }
+}