package org.simantics.scl.compiler.runtime;

import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.commons.ClassRemapper;
import org.objectweb.asm.commons.Remapper;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.elaboration.modules.SCLValue;
import org.simantics.scl.compiler.elaboration.rules.TransformationRule;
import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.environment.GlobalOnlyEnvironment;
import org.simantics.scl.compiler.internal.codegen.types.JavaTypeTranslator;
import org.simantics.scl.compiler.internal.codegen.utils.JavaNamingPolicy;
import org.simantics.scl.compiler.internal.codegen.utils.TransientClassBuilder;
import org.simantics.scl.compiler.internal.decompilation.DecompilerFactory;
import org.simantics.scl.compiler.internal.decompilation.IDecompiler;
import org.simantics.scl.compiler.module.ConcreteModule;
import org.simantics.scl.compiler.module.Module;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.top.ValueNotFound;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import gnu.trove.map.hash.THashMap;

/**
 * Runtime counterpart of a compiled SCL {@link Module}.
 *
 * <p>An instance bundles the module, a dedicated {@link ModuleClassLoader} that
 * defines the module's classes from bytecode, a cache of realized values and a
 * {@link TransientClassBuilder} used when realizing them. Classes belonging to
 * other modules are resolved through the {@link RuntimeModuleMap} given at
 * construction time.
 *
 * <p>After {@link #dispose()} the instance must not be used any more: the
 * internal references are set to {@code null}.
 */
public class RuntimeModule {
    private static final Logger LOGGER = LoggerFactory.getLogger(RuntimeModule.class);

    public static final boolean VALIDATE_CLASS_NAMES = true;
    public static final boolean TRACE_CLASS_CREATION = true;

    /** ASM API level handed to the class visitors (major 5, minor 0, patch 0). */
    private static final int ASM_API_VERSION = 5 << 16 | 0 << 8 | 0;

    Module module;
    ModuleClassLoader classLoader;
    THashMap<String, Object> valueCache = new THashMap<String, Object>();
    TransientClassBuilder classBuilder;
    RuntimeModuleMap parentModuleMap;

    /**
     * Class loader of this runtime module.
     *
     * <p>Names that do not start with {@code SCL_PACKAGE_PREFIX} are delegated
     * to the parent class loader(s). For all other names the owning module is
     * derived with {@link RuntimeModule#extractClassLoaderId(String)}; the class
     * is then defined either locally or by the class loader of the matching
     * entry in {@code parentModuleMap}.
     */
    class ModuleClassLoader extends ClassLoader implements MutableClassLoader {
        String moduleName;
        /** Bytecode added at runtime, keyed by the name given to {@code addClass}. */
        THashMap<String, byte[]> localClasses = new THashMap<String, byte[]>();
        int transientPackageId = 0;

        public ModuleClassLoader(ClassLoader parent) {
            super(parent);
            this.moduleName = module.getName();
        }

        /**
         * Registers bytecode for a class so that it can be defined later on demand.
         *
         * @param name   key under which the bytecode is stored
         * @param class_ bytecode of the class
         */
        public synchronized void addClass(String name, byte[] class_) {
            if (TRACE_CLASS_CREATION) {
                traceAddedClass(name, class_);
            }
            if (VALIDATE_CLASS_NAMES) {
                validateClassName(name);
            }
            localClasses.put(name, class_);
        }

        /**
         * Registers bytecode for several classes at once.
         *
         * @param classes bytecode keyed by class name
         */
        public synchronized void addClasses(Map<String, byte[]> classes) {
            if (TRACE_CLASS_CREATION) {
                for (Map.Entry<String, byte[]> entry : classes.entrySet()) {
                    traceAddedClass(entry.getKey(), entry.getValue());
                }
            }
            if (VALIDATE_CLASS_NAMES) {
                for (String name : classes.keySet()) {
                    validateClassName(name);
                }
            }
            localClasses.putAll(classes);
        }

        /**
         * Hook for checking names of added classes. The check is currently
         * disabled, so this method accepts every name.
         */
        private void validateClassName(String name) {
            // Disabled check, kept for reference:
            //   if(!name.startsWith(SCL_PACKAGE_PREFIX) || !extractClassLoaderId(name).equals(moduleName))
            //       throw new IllegalArgumentException("Class name " + name + " does not start with '"
            //               + SCL_PACKAGE_PREFIX + moduleName + "$'.");
        }

        /**
         * Returns the bytecode of an SCL-generated class.
         *
         * @param name class name using {@code '.'} as separator
         * @return the bytecode, or {@code null} if the name is not an SCL class
         *         name, the owning module is unknown, or no bytecode is stored
         */
        @Override
        public byte[] getBytes(String name) {
            // Non-SCL classes are not handled here
            if (!name.startsWith(SCL_PACKAGE_PREFIX)) {
                return null;
            }

            // Determine the module whose class loader is responsible for the class
            String requestedModuleName = extractClassLoaderId(name);

            if (requestedModuleName.equals(this.moduleName)) {
                // Compiled classes of the module take precedence over runtime additions
                String internalName = name.replace('.', '/');
                byte[] bytes = module.getClass(internalName);
                if (bytes != null) {
                    return bytes;
                }
                return localClasses.get(internalName);
            }

            RuntimeModule parentModule = parentModuleMap.get(requestedModuleName);
            if (parentModule == null) {
                return null;
            }
            return parentModule.classLoader.getBytes(name);
        }

        /**
         * Returns the class with the given name from this loader, defining it
         * from the module's bytecode (or from bytecode added at runtime) when it
         * has not been loaded yet.
         *
         * @throws ClassNotFoundException if no bytecode is available for the name
         */
        synchronized Class<?> getLocalClass(String name) throws ClassNotFoundException {
            Class<?> alreadyLoaded = findLoadedClass(name);
            if (alreadyLoaded != null) {
                return alreadyLoaded;
            }

            String internalName = name.replace('.', '/');
            byte[] bytes = module.getClass(internalName);
            if (bytes == null) {
                bytes = localClasses.get(internalName);
                if (bytes == null) {
                    throw new ClassNotFoundException(name);
                }
            }
            if (SCLCompilerConfiguration.SHOW_DECOMPILED_BYTECODE
                    && SCLCompilerConfiguration.debugFilter(moduleName)) {
                showDecompiledBytecode(internalName);
            }
            return defineClass(name, bytes, 0, bytes.length);
        }

        /** Resolves a class name to a class; see the class comment for the lookup order. */
        private Class<?> getClass(String name) throws ClassNotFoundException {
            LOGGER.trace("{}:getClass {}", moduleName, name);

            // If the class is not generated from SCL, use parent class loaders
            if (!name.startsWith(SCL_PACKAGE_PREFIX)) {
                try {
                    return getParent().loadClass(name);
                } catch (ClassNotFoundException primaryFailure) {
                    for (RuntimeModule parentModule : parentModuleMap.values()) {
                        try {
                            return parentModule.classLoader.getParent().loadClass(name);
                        } catch (ClassNotFoundException ignored) {
                            // This candidate does not know the class either; try the next one.
                        }
                    }
                    throw new ClassNotFoundException(name, primaryFailure);
                }
            }

            // Determine the module whose class loader is responsible for the class
            String requestedModuleName = extractClassLoaderId(name);
            if (requestedModuleName.equals(this.moduleName)) {
                return getLocalClass(name);
            }

            RuntimeModule parentModule = parentModuleMap.get(requestedModuleName);
            if (parentModule == null) {
                LOGGER.error("requestedModuleName = {}", requestedModuleName);
                LOGGER.error("this.moduleName = {}", this.moduleName);
                throw new ClassNotFoundException(name);
            }
            return parentModule.classLoader.getLocalClass(name);
        }

        @Override
        public synchronized Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
            Class<?> clazz = getClass(name);
            if (resolve) {
                resolveClass(clazz);
            }
            return clazz;
        }

        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException {
            return super.loadClass(name);
        }

        /**
         * Returns the module with the given name: the module of this loader or
         * one of the modules in {@code parentModuleMap}.
         *
         * @throws RuntimeException if the name matches neither
         */
        public Module getModule(String moduleName) {
            if (moduleName.equals(this.moduleName)) {
                return module;
            }
            RuntimeModule parentModule = parentModuleMap.get(moduleName);
            if (parentModule == null) {
                throw new RuntimeException("Didn't find module " + moduleName + ".");
            }
            return parentModule.module;
        }

        public String getModuleName() {
            return moduleName;
        }

        /** Returns {@code moduleName + "$" + n} where {@code n} grows by one on every call. */
        public synchronized String getFreshPackageName() {
            return moduleName + "$" + (++transientPackageId);
        }

        // NOTE(review): the type arguments declared by MutableClassLoader are not
        // visible in this file, so the return type is left raw - confirm and tighten.
        @SuppressWarnings("rawtypes")
        @Override
        public THashMap getConstantCache() {
            return null;
        }

        @Override
        public ClassLoader getClassLoader() {
            return this;
        }

        /** Writes a decompiled view of the class to standard output, if a decompiler is available. */
        private void showDecompiledBytecode(String className) {
            IDecompiler decompiler = DecompilerFactory.getDecompiler();
            if (decompiler == null) {
                return;
            }
            OutputStreamWriter writer = new OutputStreamWriter(System.out);
            decompiler.decompile(this, className, writer);
            try {
                // The writer buffers internally; flush so the output is not lost.
                // It is deliberately not closed because that would close System.out.
                writer.flush();
            } catch (IOException e) {
                LOGGER.warn("Failed to flush decompiled bytecode of {}", className, e);
            }
        }
    }

    /** Environment exposing this module together with all modules of {@code parentModuleMap}. */
    public Environment moduleEnvironment = new GlobalOnlyEnvironment() {
        @Override
        protected Collection<Module> getModules() {
            List<Module> result = new ArrayList<Module>(parentModuleMap.size() + 1);
            result.add(module);
            for (RuntimeModule rm : parentModuleMap.values()) {
                result.add(rm.module);
            }
            return result;
        }

        @Override
        protected Module getModule(String name) {
            return classLoader.getModule(name);
        }

        // NOTE(review): element type of the supertype signature is not visible
        // here (presumably TransformationRule, which is imported) - confirm.
        @SuppressWarnings("rawtypes")
        @Override
        public void collectRules(Collection rules) {
        }

        // NOTE(review): element type of the supertype signature is not visible
        // here - confirm and tighten.
        @SuppressWarnings("rawtypes")
        @Override
        public List getFieldAccessors(String name) {
            // TODO Not clear if this is needed.
            return null;
        }
    };

    /**
     * Creates the runtime representation of a module.
     *
     * @param module            the module to wrap
     * @param parentModuleMap   runtime modules consulted for classes and modules
     *                          that are not local to {@code module}
     * @param parentClassLoader parent of the created class loader
     * @throws NullPointerException if {@code parentClassLoader} is {@code null}
     */
    public RuntimeModule(Module module, RuntimeModuleMap parentModuleMap, ClassLoader parentClassLoader) {
        Objects.requireNonNull(parentClassLoader, "parentClassLoader");
        // The class loader reads this.module in its constructor, so assign it first.
        this.module = module;
        this.parentModuleMap = parentModuleMap;
        this.classLoader = new ModuleClassLoader(parentClassLoader);
        this.classBuilder = new TransientClassBuilder(classLoader, new JavaTypeTranslator(moduleEnvironment));
    }

    /**
     * Returns the realized value with the given name, realizing and caching it
     * on first access.
     *
     * @throws ValueNotFound if the module does not contain a value with that name
     */
    public Object getValue(String name) throws ValueNotFound {
        // containsKey is used instead of a null check so that a cached null is honored
        if (valueCache.containsKey(name)) {
            return valueCache.get(name);
        }

        SCLValue valueConstructor = module.getValue(name);
        if (valueConstructor == null) {
            throw new ValueNotFound(module.getName() + "/" + name);
        }

        Object value = valueConstructor.realizeValue(getClassBuilder());
        valueCache.put(name, value);
        return value;
    }

    private TransientClassBuilder getClassBuilder() {
        return classBuilder;
    }

    public Module getModule() {
        return module;
    }

    public MutableClassLoader getMutableClassLoader() {
        return classLoader;
    }

    /**
     * Derives the module name from the name of an SCL-generated class.
     *
     * <p>The part of {@code className} between the SCL package prefix and the
     * first {@code '$'} following it (or the end of the string when there is no
     * {@code '$'}) is passed to {@link JavaNamingPolicy#classNameToModuleName}.
     *
     * @param className class name that starts with the SCL package prefix
     */
    public static String extractClassLoaderId(String className) {
        int prefixLength = MutableClassLoader.SCL_PACKAGE_PREFIX_LENGTH;
        int dollar = className.indexOf('$', prefixLength);
        String encodedModuleName = dollar < 0
                ? className.substring(prefixLength)
                : className.substring(prefixLength, dollar);
        return JavaNamingPolicy.classNameToModuleName(encodedModuleName);
    }

    /**
     * Disposes the wrapped module and drops all references held by this
     * instance. Calling the method again has no further effect.
     */
    public void dispose() {
        if (module != null) {
            module.dispose();
            module = null;
        }
        valueCache.clear();
        parentModuleMap = null;
        classLoader = null;
        classBuilder = null;
    }

    /**
     * {@link Remapper} that leaves every name unchanged and records the type
     * names it is asked to map via {@link #map}, {@link #mapType} and
     * {@link #mapTypes}.
     */
    public class ClassNameRecordingRemapper extends Remapper {
        private final Set<String> classNames;

        public ClassNameRecordingRemapper(Set<String> classNames) {
            this.classNames = classNames;
        }

        /** Records a name unless it is {@code null} (e.g. the super name of a class without superclass). */
        private void record(String typeName) {
            if (typeName != null) {
                classNames.add(typeName);
            }
        }

        @Override
        public String map(String typeName) {
            record(typeName);
            return super.map(typeName);
        }

        @Override
        public String mapType(String type) {
            record(type);
            return super.mapType(type);
        }

        @Override
        public String[] mapTypes(String[] types) {
            for (String type : types) {
                record(type);
            }
            return super.mapTypes(types);
        }
    }

    /**
     * Collects the type names referenced by the bytecode of a class of this module.
     *
     * @param className name under which the module stores the bytecode
     * @return the recorded names in the form reported by ASM (using {@code '/'}
     *         as separator), or {@code null} if the bytecode could not be
     *         analyzed; the failure is logged
     */
    public Set<String> classReferences(String className) {
        try {
            Set<String> referencedClasses = new HashSet<>();
            ClassNameRecordingRemapper recorder = new ClassNameRecordingRemapper(referencedClasses);
            ClassReader reader = new ClassReader(module.getClass(className));
            // The sink visitor discards everything; only the remapper callbacks matter.
            ClassVisitor sink = new ClassVisitor(ASM_API_VERSION) {};
            ClassVisitor recordingVisitor = new ClassRemapper(sink, recorder);
            reader.accept(recordingVisitor, ClassReader.SKIP_DEBUG);
            LOGGER.debug("{} refs: {}", className, referencedClasses);
            return referencedClasses;
        } catch (Exception e) {
            LOGGER.error("Failed to collect class references of {}", className, e);
        }
        return null;
    }

    /**
     * Loads, on a best-effort basis, every class referenced by the classes of
     * this module. Failures are logged and do not stop the processing of the
     * remaining classes.
     *
     * @throws ClassCastException if the wrapped module is not a {@link ConcreteModule}
     */
    public void loadReferences() {
        ConcreteModule cm = (ConcreteModule) module;
        try {
            for (String className : cm.getClasses().keySet()) {
                Set<String> refs = classReferences(className);
                if (refs == null) {
                    // Analysis failed and was already logged; carry on with the next class.
                    continue;
                }
                for (String ref : refs) {
                    String binaryName = ref.replace('/', '.');
                    try {
                        classLoader.loadClass(binaryName);
                    } catch (ClassNotFoundException | LinkageError | RuntimeException e) {
                        LOGGER.warn("Failed to load class {} referenced by {}", binaryName, className, e);
                    }
                }
            }
        } catch (RuntimeException e) {
            LOGGER.error("Failed to load references of module {}", cm.getName(), e);
        }
    }

    /** Logs the registration of runtime bytecode using the historical trace format. */
    private static void traceAddedClass(String name, byte[] bytecode) {
        LOGGER.info("addClass {} ({} bytes)", name, bytecode.length);
    }
}