]> gerrit.simantics Code Review - simantics/python.git/blobdiff - org.simantics.pythonlink.win32.x86_64/src/sclpy.c
Plug leaks from PySequence_GetItem.
[simantics/python.git] / org.simantics.pythonlink.win32.x86_64 / src / sclpy.c
index 4881cc7d7c26d4a20979542530230724dc1df5f2..caed56dc8c8ab30e3cc9c0d2437969eca7470571 100644 (file)
@@ -48,50 +48,89 @@ static PyObject *
 writeToSCL(PyObject *self, PyObject *args)
 {
     if (currentEnv != NULL && sclWriter != NULL) {
-       JNIEnv *env = currentEnv;
-
                Py_UNICODE *what;
                Py_ssize_t length;
+       JNIEnv *env = currentEnv;
+
                if (!PyArg_ParseTuple(args, "u#", &what, &length))
-                       return Py_BuildValue("");
+                       Py_RETURN_NONE;
 
                {
-               jclass writerClass = (*env)->FindClass(env, WRITER_CLASS);
-               jmethodID writeMethod = (*env)->GetMethodID(env, writerClass, "write", "([CII)V");
-               jcharArray chars = (*env)->NewCharArray(env, (jsize)length);
-
-               (*env)->SetCharArrayRegion(env, chars, 0, length, what);
-               (*env)->CallVoidMethod(env, sclWriter, writeMethod, chars, 0, length);
+                       PyThreadState *my_ts = PyThreadState_Get();
+                       if (my_ts == main_ts) {
+                               jclass writerClass = (*env)->FindClass(env, WRITER_CLASS);
+                               jmethodID writeMethod = (*env)->GetMethodID(env, writerClass, "write", "([CII)V");
+                               jcharArray chars = (*env)->NewCharArray(env, (jsize)length);
+
+                               (*env)->SetCharArrayRegion(env, chars, 0, length, what);
+                               Py_BEGIN_ALLOW_THREADS
+                               (*env)->CallVoidMethod(env, sclWriter, writeMethod, chars, 0, length);
+                               Py_END_ALLOW_THREADS
+                       } else {
+                               //TODO
+                       }
                }
     }
 
-    return Py_BuildValue("");
+       Py_RETURN_NONE;
+}
+
+static PyObject *
+flushSCL(PyObject *self, PyObject *args)
+{
+    if (currentEnv != NULL && sclWriter != NULL) {
+       JNIEnv *env = currentEnv;
+       PyThreadState *my_ts = PyThreadState_Get();
+       if (my_ts != main_ts) {
+               // TODO: Process calls from other threads
+                       Py_RETURN_NONE;
+       }
+
+       {
+                       jclass writerClass = (*env)->FindClass(env, WRITER_CLASS);
+                       jmethodID flushMethod = (*env)->GetMethodID(env, writerClass, "flush", "()V");
+
+               Py_BEGIN_ALLOW_THREADS
+                       (*env)->CallVoidMethod(env, sclWriter, flushMethod);
+                       Py_END_ALLOW_THREADS
+       }
+    }
+
+    Py_RETURN_NONE;
 }
 
 static PyMethodDef sclWriterMethods[] = {
     {"write", writeToSCL, METH_VARARGS, "Write something."},
+       {"flush", flushSCL, METH_VARARGS, "Flush output."},
     {NULL, NULL, 0, NULL}
 };
 
+JNIEXPORT void JNICALL Java_org_simantics_pythonlink_PythonContext_initializePython(JNIEnv *env, jobject thisObj, jobject writer) {
+    Py_Initialize();
 
-JNIEXPORT jlong JNICALL Java_org_simantics_pythonlink_PythonContext_createContextImpl(JNIEnv *env, jobject thisObj) {
-       char name[16];
+    {
+       static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "sclwriter", NULL, -1, sclWriterMethods, };
+               PyObject *m = PyModule_Create(&moduledef);
 
-       if (!main_ts) {
-        Py_Initialize();
+       sclWriter = (*env)->NewGlobalRef(env, writer);
 
-        {
-               static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "sclwriter", NULL, -1, sclWriterMethods, };
-                       PyObject *m = PyModule_Create(&moduledef);
+       if (m == NULL) throwException(env, PYTHON_EXCEPTION, "Failed to create SCL writer module");
 
-               if (m == NULL) throwException(env, PYTHON_EXCEPTION, "Failed to create SCL writer module");
-                       PySys_SetObject("stdout", m);
-                       PySys_SetObject("stderr", m);
-        }
+               PySys_SetObject("stdout", m);
+               PySys_SetObject("stderr", m);
+    }
+
+       hasNumpy = _import_array();
+       hasNumpy = hasNumpy != -1;
 
-       hasNumpy = _import_array();
-       hasNumpy = hasNumpy != -1;
-       main_ts = PyEval_SaveThread();
+       main_ts = PyEval_SaveThread();
+}
+
+JNIEXPORT jlong JNICALL Java_org_simantics_pythonlink_PythonContext_createContextImpl(JNIEnv *env, jobject thisObj) {
+       char name[16];
+
+       if (!main_ts) {
+               return 0;
        }
 
        sprintf(name, "SCL_%d", ++moduleCount);
@@ -115,6 +154,7 @@ JNIEXPORT jlong JNICALL Java_org_simantics_pythonlink_PythonContext_createContex
                        PyDict_SetItemString(dict, "__builtin__", builtin);
                        PyDict_SetItemString(dict, "__builtins__", builtin);
 
+                       PyEval_SaveThread();
                        return (jlong)modDef;
                }
        }
@@ -650,9 +690,11 @@ jobjectArray pythonSequenceAsStringArray(JNIEnv *env, PyObject *seq) {
                PyObject *item = PySequence_GetItem(seq, i);
                if (PyUnicode_Check(item)) {
                        jstring value = pythonStringAsJavaString(env, item);
+                       Py_DECREF(item);
                        (*env)->SetObjectArrayElement(env, array, i, value);
                }
                else {
+                       Py_DECREF(item);
                        throwPythonException(env, "List item not a string");
                        return NULL;
                }
@@ -671,10 +713,12 @@ jdoubleArray pythonSequenceAsDoubleArray(JNIEnv *env, PyObject *seq) {
        for (i = 0; i < jlen; i++) {
                PyObject *item = PySequence_GetItem(seq, i);
                if (PyFloat_Check(item)) {
-                       double value = PyFloat_AsDouble(item);
+                       jdouble value = PyFloat_AsDouble(item);
+                       Py_DECREF(item);
                        (*env)->SetDoubleArrayRegion(env, array, i, 1, &value);
                }
                else {
+                       Py_DECREF(item);
                        throwPythonException(env, "List item not a floating point value");
                        return NULL;
                }
@@ -714,6 +758,7 @@ jobjectArray pythonSequenceAsObjectArray(JNIEnv *env, PyObject *seq) {
        for (i = 0; i < jlen; i++) {
                PyObject *item = PySequence_GetItem(seq, i);
                jobject object = pythonObjectAsObject(env, item);
+               Py_DECREF(item);
                (*env)->SetObjectArrayElement(env, array, i, object);
        }
 
@@ -731,9 +776,11 @@ jbooleanArray pythonSequenceAsBooleanArray(JNIEnv *env, PyObject *seq) {
                PyObject *item = PySequence_GetItem(seq, i);
                if (PyBool_Check(item)) {
                        jboolean value = item == Py_True;
+                       Py_DECREF(item);
                        (*env)->SetBooleanArrayRegion(env, array, i, 1, &value);
                }
                else {
+                       Py_DECREF(item);
                        throwPythonException(env, "List item not a boolean");
                        return NULL;
                }
@@ -753,9 +800,11 @@ jintArray pythonSequenceAsIntegerArray(JNIEnv *env, PyObject *seq) {
                PyObject *item = PySequence_GetItem(seq, i);
                if (PyLong_Check(item)) {
                        jint value = PyLong_AsLong(item);
+                       Py_DECREF(item);
                        (*env)->SetIntArrayRegion(env, array, i, 1, &value);
                }
                else {
+                       Py_DECREF(item);
                        throwPythonException(env, "List item not an integer");
                        return NULL;
                }
@@ -775,9 +824,11 @@ jlongArray pythonSequenceAsLongArray(JNIEnv *env, PyObject *seq) {
                PyObject *item = PySequence_GetItem(seq, i);
                if (PyLong_Check(item)) {
                        jlong value = PyLong_AsLongLong(item);
+                       Py_DECREF(item);
                        (*env)->SetLongArrayRegion(env, array, i, 1, &value);
                }
                else {
+                       Py_DECREF(item);
                        throwPythonException(env, "List item not an integer");
                        return NULL;
                }
@@ -856,9 +907,10 @@ jobject pythonDictionaryAsMap(JNIEnv *env, PyObject *dict) {
     ##VariableImpl(                                                     \
         JNIEnv *env, jobject thisObj, jlong contextID, jstring variableName, \
         jtype value) {                                                  \
-            PyObject *module = getModule(contextID);                    \
+            PyObject *module;                                           \
                                                                         \
             PyEval_RestoreThread(main_ts);                              \
+            module = getModule(contextID);                              \
             setPythonVariable(module, getPythonString(env, variableName), \
                               j2py(env, value));                        \
             PyEval_SaveThread();                                        \
@@ -883,8 +935,6 @@ JNIEXPORT void JNICALL
 Java_org_simantics_pythonlink_PythonContext_setPythonNDArrayVariableImpl(
                JNIEnv *env, jobject thisObj, jlong contextID, jstring variableName,
                jobject value) {
-       PyObject *module = getModule(contextID);
-
        if (!hasNumpy) {
                throwPythonException(env, "Importing numpy failed");
                return;
@@ -892,6 +942,7 @@ Java_org_simantics_pythonlink_PythonContext_setPythonNDArrayVariableImpl(
 
        PyEval_RestoreThread(main_ts);
        {
+               PyObject *module = getModule(contextID);
                PyObject *pythonName = getPythonString(env, variableName);
                PyObject *val = getPythonNDArray(env, value);
 
@@ -904,81 +955,88 @@ JNIEXPORT void JNICALL
 Java_org_simantics_pythonlink_PythonContext_setPythonVariantVariableImpl(
                JNIEnv *env, jobject thisObj, jlong contextID, jstring variableName,
                jobject value, jobject binding) {
-       PyObject *module = getModule(contextID);
+       PyObject *module;
 
        PyEval_RestoreThread(main_ts);
+       module = getModule(contextID);
        setPythonVariable(module, getPythonString(env, variableName),
                                          getPythonObject(env, value, binding));
        PyEval_SaveThread();
 }
 
+static PyObject *getExceptionMessage(PyObject *exceptionType, PyObject *exception, PyObject *traceback) {
+       PyObject *formatExc = NULL, *args = NULL;
+       PyObject *tracebackModule = PyImport_ImportModule("traceback");
+       if (!tracebackModule) {
+               return NULL;
+       }
+
+       if (exception && traceback) {
+               formatExc = PyDict_GetItemString(PyModule_GetDict(tracebackModule), "format_exception");
+               args = PyTuple_Pack(3, exceptionType, exception, traceback);
+       }
+       else if (exception) {
+               formatExc = PyDict_GetItemString(PyModule_GetDict(tracebackModule), "format_exception_only");
+               args = PyTuple_Pack(2, exceptionType, exception);
+       }
+
+       Py_DECREF(tracebackModule);
+
+       if (formatExc != NULL && args != NULL) {
+               PyObject *result = PyObject_CallObject(formatExc, args);
+               Py_XDECREF(args);
+               Py_XDECREF(formatExc);
+               return result;
+       }
+       else {
+               Py_XDECREF(args);
+               Py_XDECREF(formatExc);
+               return NULL;
+       }
+}
+
 JNIEXPORT jint JNICALL
 Java_org_simantics_pythonlink_PythonContext_executePythonStatementImpl(
                JNIEnv *env, jobject thisObj, jlong contextID, jstring statement) {
-       PyObject *module = getModule(contextID);
-
        const char *utfchars = (*env)->GetStringUTFChars(env, statement, NULL);
 
        PyEval_RestoreThread(main_ts);
        PyErr_Clear();
        {
-               jclass sclReportingWriterClass = (*env)->FindClass(env, SCL_REPORTING_WRITER_CLASS);
-               jmethodID constructor = (*env)->GetMethodID(env, sclReportingWriterClass, "<init>", "()V");
-               jmethodID flushMethod = (*env)->GetMethodID(env, sclReportingWriterClass, "flush", "()V");
+               PyObject *module = getModule(contextID);
 
                PyObject *globals;
 
                globals = PyModule_GetDict(module);
 
                currentEnv = env;
-               if (sclReportingWriterClass && constructor)
-                       sclWriter = (*env)->NewObject(env, sclReportingWriterClass, constructor);
-               else
-                       sclWriter = NULL;
 
                {
                        PyObject *result = PyRun_String(utfchars, Py_file_input, globals, globals);
 
                        PyObject *exceptionType = PyErr_Occurred();
                        if (exceptionType != NULL) {
-                               PyObject *exception, *traceback;
+                               PyObject *exception, *traceback, *message;
                                PyErr_Fetch(&exceptionType, &exception, &traceback);
 
-                               {
-                                       PyObject *tracebackModule = PyImport_ImportModule("traceback");
-                                       if (tracebackModule != NULL) {
-                                               PyObject *formatExc = PyDict_GetItemString(PyModule_GetDict(tracebackModule), "format_exception");
-                                               if (formatExc != NULL) {
-                                                       PyObject *args = PyTuple_Pack(3, exceptionType, exception, traceback);
-                                                       PyObject *message = PyObject_CallObject(formatExc, args);
-                                                       if (message != NULL) {
-                                                               PyObject *emptyStr = PyUnicode_FromString("");
-                                                               PyObject *joined = PyUnicode_Join(emptyStr, message);
-                                                               char *messageStr = PyUnicode_AsUTF8(joined);
-                                                               throwPythonException(env, messageStr);
-                                                               Py_DECREF(joined);
-                                                               Py_DECREF(emptyStr);
-                                                               Py_DECREF(message);
-                                                       }
-                                                       else {
-                                                               PyTypeObject
-                                                                       *ty = (PyTypeObject *)exceptionType;
-                                                               throwPythonException(
-                                                                               env, ty ? ty->tp_name
-                                                                                               : "Internal error, null exception type");
-                                                       }
-                                                       Py_DECREF(args);
-                                                       Py_DECREF(formatExc);
-                                               }
-                                               else {
-                                                       throwPythonException(env, "Internal error, no format_exc function");
-                                               }
-                                               Py_DECREF(tracebackModule);
-                                       }
-                                       else {
-                                               throwPythonException(env, "Internal error, no traceback module");
-                                       }
+                               message = getExceptionMessage(exceptionType, exception, traceback);
+                               if (message != NULL) {
+                                       PyObject *emptyStr = PyUnicode_FromString("");
+                                       PyObject *joined = PyUnicode_Join(emptyStr, message);
+                                       char *messageStr = PyUnicode_AsUTF8(joined);
+                                       throwPythonException(env, messageStr);
+                                       Py_DECREF(joined);
+                                       Py_DECREF(emptyStr);
+                                       Py_DECREF(message);
+                               }
+                               else {
+                                       PyTypeObject
+                                               *ty = (PyTypeObject *)exceptionType;
+                                       throwPythonException(
+                                                       env, ty ? ty->tp_name
+                                                                       : "Internal error, null exception type");
                                }
+
                                Py_XDECREF(exceptionType);
                                Py_XDECREF(exception);
                                Py_XDECREF(traceback);
@@ -987,12 +1045,7 @@ Java_org_simantics_pythonlink_PythonContext_executePythonStatementImpl(
                        PyEval_SaveThread();
                        (*env)->ReleaseStringUTFChars(env, statement, utfchars);
 
-                       if (sclWriter != NULL) {
-                               (*env)->CallVoidMethod(env, sclWriter, flushMethod);
-                       }
-
                        currentEnv = NULL;
-                       sclWriter = NULL;
 
                        return result != NULL ? 0 : 1;
                }