--- /dev/null
+/*******************************************************************************\r
+ * Copyright (c) 2010 Association for Decentralized Information Management in\r
+ * Industry THTH ry.\r
+ * All rights reserved. This program and the accompanying materials\r
+ * are made available under the terms of the Eclipse Public License v1.0\r
+ * which accompanies this distribution, and is available at\r
+ * http://www.eclipse.org/legal/epl-v10.html\r
+ *\r
+ * Contributors:\r
+ * VTT Technical Research Centre of Finland - initial API and implementation\r
+ *******************************************************************************/\r
+package org.simantics.databoard.method;
+
+import gnu.trove.map.hash.TObjectIntHashMap;\r
+\r
+import java.io.EOFException;\r
+import java.io.IOException;\r
+import java.net.Socket;\r
+import java.net.SocketException;\r
+import java.nio.charset.Charset;\r
+import java.util.ArrayList;\r
+import java.util.HashMap;\r
+import java.util.List;\r
+import java.util.Map;\r
+import java.util.concurrent.ConcurrentHashMap;\r
+import java.util.concurrent.CopyOnWriteArrayList;\r
+import java.util.concurrent.ExecutorService;\r
+import java.util.concurrent.Semaphore;\r
+import java.util.concurrent.SynchronousQueue;\r
+import java.util.concurrent.ThreadPoolExecutor;\r
+import java.util.concurrent.TimeUnit;\r
+import java.util.concurrent.atomic.AtomicInteger;\r
+\r
+import org.simantics.databoard.Bindings;\r
+import org.simantics.databoard.annotations.Union;\r
+import org.simantics.databoard.binding.Binding;\r
+import org.simantics.databoard.binding.RecordBinding;\r
+import org.simantics.databoard.binding.UnionBinding;\r
+import org.simantics.databoard.serialization.Serializer;\r
+import org.simantics.databoard.serialization.SerializerConstructionException;\r
+import org.simantics.databoard.util.binary.BinaryReadable;\r
+import org.simantics.databoard.util.binary.BinaryWriteable;\r
+import org.simantics.databoard.util.binary.InputStreamReadable;\r
+import org.simantics.databoard.util.binary.OutputStreamWriteable;\r
+
+/**
+ * Connection is a class that handles request-response communication over a
+ * socket.
+ * <p>
+ * Requests have asynchronous result. The result can be acquired using one of
+ * the three methods:
+ * 1) Blocking read AsyncResult.waitForResponse()
+ * 2) Poll AsyncResult.getResponse()
+ * 3) Listen AsyncResult.setListener()
+ * <p>
+ * The socket must be established before Connection is instantiated.
+ * Closing connection does not close its Socket.
+ * If the socket is closed before connection there an error is thrown.
+ * The error is available by placing listener.
+ * The proper order to close a connection is to close Connection first
+ * and then Socket.
+ *
+ * @author Toni Kalajainen <toni.kalajainen@vtt.fi>
+ */
+public class TcpConnection implements MethodInterface {
+\r
+ public static final ExecutorService SHARED_EXECUTOR_SERVICE = \r
+ new ThreadPoolExecutor(0, Integer.MAX_VALUE, 100L, TimeUnit.MILLISECONDS, new SynchronousQueue<Runnable>());\r
+
+ static final Serializer MESSAGE_SERIALIZER = Bindings.getSerializerUnchecked( Bindings.getBindingUnchecked(Message.class) );
+ static Charset UTF8 = Charset.forName("UTF8");
+
+ Handshake local, remote;
+
+ Interface remoteType;
+ MethodTypeDefinition[] localMethods, remoteMethods;
+ HashMap<MethodTypeDefinition, Integer> localMethodsMap, remoteMethodsMap;
+
+ Socket socket;
+
+ // if false, there is an error in the socket or the connection has been shutdown
+ boolean active = true;
+
+ // Objects used for handling local services
+ MethodInterface methodInterface;
+
+ // Objects used for reading data
+ ConcurrentHashMap<Integer, PendingRequest> requests = new ConcurrentHashMap<Integer, PendingRequest>();
+ List<Object> inIdentities = new ArrayList<Object>();
+ BinaryReadable in;
+ int maxRecvSize;
+
+ // Object used for writing data
+ public ExecutorService writeExecutor = SHARED_EXECUTOR_SERVICE;
+ TObjectIntHashMap<Object> outIdentities = new TObjectIntHashMap<Object>();
+ BinaryWriteable out;
+ AtomicInteger requestCounter = new AtomicInteger(0);
+ int maxSendSize;
+
+ // Cached method descriptions
+ Map<String, MethodType> methodTypes = new ConcurrentHashMap<String, MethodType>();
+
+ /**
+ * Handshake a socket
+ *
+ * @param socket
+ * @param localData local data
+ * @return the remote data
+ * @throws IOException
+ * @throws RuntimeException unexpected error (BindingException or EncodingException)
+ */
+ public static Handshake handshake(final Socket socket, final Handshake localData)
+ throws IOException
+ {
+ final BinaryReadable bin = new InputStreamReadable( socket.getInputStream(), Long.MAX_VALUE );
+ final BinaryWriteable bout = new OutputStreamWriteable( socket.getOutputStream() );
+ ExecutorService writeExecutor = SHARED_EXECUTOR_SERVICE;
+
+ // do hand-shake
+ final Exception[] writeError = new Exception[1];
+ final Semaphore s = new Semaphore(0);
+ writeExecutor.execute(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ TObjectIntHashMap<Object> outIdentities = new TObjectIntHashMap<Object>();
+ Handshake.SERIALIZER.serialize(bout, outIdentities, localData);
+ bout.flush();
+ outIdentities.clear();
+ } catch (IOException e) {
+ writeError[0] = e;
+ } finally {
+ s.release(1);
+ }
+ }});
+
+ // Read remote peer's handshake
+ List<Object> inIdentities = new ArrayList<Object>();
+ Handshake result = (Handshake) Handshake.SERIALIZER.deserialize(bin, inIdentities);
+ inIdentities.clear();
+
+ // Check that write was ok
+ try {
+ s.acquire(1);
+ Exception e = writeError[0];
+ if (e!=null && e instanceof IOException)
+ throw (IOException) e;
+ if (e!=null)
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ } finally {
+// writeExecutor.shutdown();
+ }
+
+ return result;
+ }
+
+ /**
+ * Create a connection to a hand-shaken socket
+ *
+ * @param socket
+ * @param methodInterface local method handler
+ * @param localData
+ * @param remoteData
+ * @throws IOException
+ */
+ public TcpConnection(Socket socket, MethodInterface methodInterface, Handshake localData, Handshake remoteData)
+ throws IOException {
+ if (socket==null || localData==null || remoteData==null)
+ throw new IllegalArgumentException("null arg");
+
+ this.methodInterface = methodInterface;
+ this.socket = socket;
+ this.local = localData;
+ this.remote = remoteData;
+ this.maxSendSize = Math.min(localData.sendMsgLimit, remoteData.recvMsgLimit);
+ this.maxRecvSize = Math.min(localData.recvMsgLimit, remoteData.sendMsgLimit);
+
+ this.localMethods = local.methods;
+ this.remoteMethods = remote.methods;
+ this.remoteType = new Interface(this.remoteMethods);
+ this.localMethodsMap = new HashMap<MethodTypeDefinition, Integer>();
+ this.remoteMethodsMap = new HashMap<MethodTypeDefinition, Integer>();
+ for (int i=0; i<localMethods.length; i++)
+ localMethodsMap.put(localMethods[i], i);
+ for (int i=0; i<remoteMethods.length; i++)
+ remoteMethodsMap.put(remoteMethods[i], i);
+// remoteMethodsMap.trimToSize();
+// localMethodsMap.trimToSize();
+
+ in = new InputStreamReadable( socket.getInputStream(), Long.MAX_VALUE );
+ out = new OutputStreamWriteable( socket.getOutputStream() );\r
+ \r
+ String threadName = "Connection-"+socket.getInetAddress().getHostAddress()+":"+socket.getPort();\r
+ \r
+ thread.setName( threadName );
+ thread.start();
+ }
+
+ @Override
+ public Interface getInterface() {
+ return remoteType;
+ }
+
+ @Override
+ public Method getMethod(MethodTypeBinding binding)
+ throws MethodNotSupportedException {
+ // consumer suggests object bindings
+ MethodTypeDefinition description = binding.getMethodDefinition();
+ \r
+ if (!remoteMethodsMap.containsKey(description)) {\r
+ /*\r
+ System.out.println("Method not found: "+description);\r
+ System.out.println("Existing methods:" );\r
+ for (MethodTypeDefinition k : remoteMethodsMap.keySet()) {\r
+ System.out.print(k);\r
+ if (k.getType().requestType.getComponentCount()>0) {\r
+ System.out.print(System.identityHashCode( k.getType().requestType.getComponentType(0) ) );\r
+ }\r
+ System.out.println(); \r
+ }\r
+*/
+ throw new MethodNotSupportedException(description.getName());
+ }
+
+ int id = remoteMethodsMap.get(description);
+
+ try {
+ return new MethodImpl(id, binding);
+ } catch (SerializerConstructionException e) {
+ throw new MethodNotSupportedException(e);
+ }
+ }
+
+ @Override
+ public Method getMethod(MethodTypeDefinition description)
+ throws MethodNotSupportedException {
+ // producer suggests object bindings
+ if (!remoteMethodsMap.containsKey(description)) {
+ throw new MethodNotSupportedException(description.getName());
+ }
+ int id = remoteMethodsMap.get(description);
+
+ RecordBinding reqBinding = (RecordBinding) Bindings.getMutableBinding(description.getType().getRequestType());
+ Binding resBinding = Bindings.getMutableBinding(description.getType().getResponseType());
+ UnionBinding errBinding = (UnionBinding) Bindings.getMutableBinding(description.getType().getErrorType());
+ MethodTypeBinding binding = new MethodTypeBinding(description, reqBinding, resBinding, errBinding);
+
+ try {
+ return new MethodImpl(id, binding);
+ } catch (SerializerConstructionException e) {
+ // Generic binding should work
+ throw new MethodNotSupportedException(e);
+ }
+ }
+
+ public Socket getSocket()
+ {
+ return socket;
+ }
+
+ public interface ConnectionListener {
+ /**
+ * There was an error and connection was closed
+ *
+ * @param error
+ */
+ void onError(Exception error);
+
+ /**
+ * close() was invoked
+ */
+ void onClosed();
+ }
+ \r
+ CopyOnWriteArrayList<ConnectionListener> listeners = new CopyOnWriteArrayList<ConnectionListener>();
+
+ public synchronized void addConnectionListener(ConnectionListener listener) {
+ listeners.add( listener );
+ }\r
+ \r
+ public void removeConnectionListener(ConnectionListener listener) {\r
+ listeners.remove( listener );\r
+ }
+
+ class MethodImpl implements Method {
+ int methodId;
+ MethodTypeBinding methodBinding;
+ Serializer responseSerializer;
+ Serializer requestSerializer;
+ Serializer errorSerializer;
+
+ MethodImpl(int methodId, MethodTypeBinding methodBinding) throws SerializerConstructionException
+ {
+ this.methodId = methodId;
+ this.methodBinding = methodBinding;
+ this.requestSerializer = Bindings.getSerializer( methodBinding.getRequestBinding() );
+ this.responseSerializer = Bindings.getSerializer( methodBinding.getResponseBinding() );
+ this.errorSerializer = Bindings.getSerializer( methodBinding.getErrorBinding() );
+ }
+
+ @Override
+ public AsyncResult invoke(final Object request) {
+ // Write, async
+ final PendingRequest result = new PendingRequest(this, requestCounter.getAndIncrement());
+ requests.put(result.requestId, result);\r
+
+ if (!active) {
+ result.setInvokeException(new InvokeException(new ConnectionClosedException()));
+ } else {
+ writeExecutor.execute(new Runnable() {
+ @Override
+ public void run() { \r
+ synchronized(TcpConnection.this) {
+ try {
+ int size= requestSerializer.getSize(request, outIdentities);
+ if (size>maxSendSize) {
+ result.setInvokeException(new InvokeException(new MessageOverflowException()));
+ return;
+ }
+ outIdentities.clear();
+
+ RequestHeader reqHeader = new RequestHeader();
+ reqHeader.methodId = methodId;
+ reqHeader.requestId = result.requestId;
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, reqHeader);
+ outIdentities.clear();
+ out.writeInt(size);
+ requestSerializer.serialize(out, outIdentities, request);
+ outIdentities.clear();
+ out.flush();
+ } catch (IOException e) {
+ result.setInvokeException(new InvokeException(e));
+ } catch (RuntimeException e) {
+ result.setInvokeException(new InvokeException(e));
+ }\r
+ }
+ }});
+ }
+ return result;
+ }
+
+ @Override
+ public MethodTypeBinding getMethodBinding() {
+ return methodBinding;
+ }
+ }
+ \r
+ void setClosed() \r
+ {\r
+ for (ConnectionListener listener : listeners)\r
+ listener.onClosed();\r
+ }
+ void setError(Exception e)
+ {\r
+ for (ConnectionListener listener : listeners)
+ listener.onError(e);
+ close();
+ }
+
+ /**
+ * Get method interface that handles services locally (service requests by peer)
+ *
+ * @return local method interface
+ */
+ public MethodInterface getLocalMethodInterface()
+ {
+ return methodInterface;
+ }
+
+ /**
+ * Get method interface that handles services locally (service requests by peer)
+ *
+ * @return local method interface
+ */
+ public MethodTypeDefinition[] getLocalMethodDescriptions()
+ {
+ return localMethods;
+ }
+ \r
+ public MethodInterface getRemoteMethodInterface() {\r
+ return this;\r
+ }\r
+
+ /**
+ * Close the connection. All pending service request are canceled.
+ * The socket is not closed.
+ */
+ public void close() {
+ active = false;
+ // cancel all pending requests
+ ArrayList<PendingRequest> reqs = new ArrayList<PendingRequest>(requests.values());
+ for (PendingRequest pr : reqs) {
+ pr.setInvokeException(new InvokeException(new ConnectionClosedException()));
+ }
+ requests.values().removeAll(reqs);
+ // shutdown inthread
+ thread.interrupt();
+// for (ConnectionListener listener : listeners)\r
+// listener.onClosed();\r
+ }\r
+ \r
+ /**\r
+ * Get the active connection of current thread\r
+ * \r
+ * @return Connection or <code>null</code> if current thread does not run connection\r
+ */\r
+ public static TcpConnection getCurrentConnection() {\r
+ Thread t = Thread.currentThread();\r
+ if (t instanceof ConnectionThread == false) return null;\r
+ ConnectionThread ct = (ConnectionThread) t;\r
+ return ct.getConnection();\r
+ }\r
+ \r
+ /**\r
+ * Connection Thread deserializes incoming messages from the TCP Stream.\r
+ *\r
+ */\r
+ class ConnectionThread extends Thread {\r
+ public ConnectionThread() {\r
+ setDaemon(true);\r
+ }\r
+ \r
+ public TcpConnection getConnection() {\r
+ return TcpConnection.this;\r
+ }\r
+ \r
+ public void run() {\r
+ while (!Thread.interrupted()) {\r
+ try {\r
+ Message header = (Message) MESSAGE_SERIALIZER.deserialize(in, inIdentities);\r
+ if (header instanceof RequestHeader) {\r
+ final RequestHeader reqHeader = (RequestHeader) header;\r
+\r
+ int size = in.readInt();\r
+ if (size>maxRecvSize) {\r
+ setError(new MessageOverflowException());\r
+ return;\r
+ }\r
+ \r
+ int methodId = reqHeader.methodId;\r
+ if (methodId<0||methodId>=localMethods.length) {\r
+ setError(new Exception("ProtocolError"));\r
+ return;\r
+ }\r
+ MethodTypeDefinition methodDescription = localMethods[methodId];\r
+ // Let back-end determine bindings\r
+ try {\r
+ final Method method = methodInterface.getMethod(methodDescription);\r
+ final MethodTypeBinding methodBinding = method.getMethodBinding();\r
+ // Deserialize payload \r
+ final Object request = Bindings.getSerializerUnchecked(methodBinding.getRequestBinding()).deserialize(in, inIdentities);\r
+ inIdentities.clear();\r
+ \r
+ // Invoke method\r
+ method.invoke(request).setListener(new InvokeListener() {\r
+ @Override\r
+ public void onCompleted(final Object response) {\r
+ // Write RESP\r
+ writeExecutor.execute(new Runnable() {\r
+ @Override\r
+ public void run() {\r
+ synchronized(TcpConnection.this) {\r
+ try {\r
+ Serializer serializer = Bindings.getSerializerUnchecked(methodBinding.getResponseBinding());\r
+ int size = serializer.getSize(response, outIdentities);\r
+ outIdentities.clear();\r
+ if (size > maxSendSize) {\r
+ ResponseTooLargeError tooLarge = new ResponseTooLargeError();\r
+ tooLarge.requestId = reqHeader.requestId;\r
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, tooLarge);\r
+ outIdentities.clear(); \r
+ return;\r
+ }\r
+\r
+ ResponseHeader respHeader = new ResponseHeader();\r
+ respHeader.requestId = reqHeader.requestId;\r
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, respHeader);\r
+ outIdentities.clear();\r
+ out.writeInt(size);\r
+ \r
+ serializer.serialize(out, outIdentities, response);\r
+ outIdentities.clear();\r
+ out.flush();\r
+ } catch (IOException e) {\r
+ setError(e);\r
+ } catch (RuntimeException e) {\r
+ setError(e);\r
+ }\r
+ }\r
+ }});\r
+ }\r
+ @Override\r
+ public void onException(final Exception cause) {\r
+ // Write ERRO\r
+ writeExecutor.execute(new Runnable() {\r
+ @Override\r
+ public void run() {\r
+ synchronized(TcpConnection.this) {\r
+ try {\r
+ Exception_ msg = new Exception_();\r
+ msg.message = cause.getClass().getName()+": "+cause.getMessage(); \r
+ \r
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, msg);\r
+ outIdentities.clear();\r
+ out.flush();\r
+ } catch (IOException e) {\r
+ setError(e);\r
+ } catch (RuntimeException e) {\r
+ setError(e);\r
+ }\r
+ }\r
+ }}); \r
+ }\r
+ @Override\r
+ public void onExecutionError(final Object error) {\r
+ // Write ERRO\r
+ writeExecutor.execute(new Runnable() {\r
+ @Override\r
+ public void run() {\r
+ synchronized(TcpConnection.this) {\r
+ try {\r
+ Serializer serializer = Bindings.getSerializerUnchecked(methodBinding.getErrorBinding());\r
+ int size = serializer.getSize(error, outIdentities);\r
+ outIdentities.clear();\r
+ \r
+ if (size > maxSendSize) {\r
+ ResponseTooLargeError tooLarge = new ResponseTooLargeError();\r
+ tooLarge.requestId = reqHeader.requestId;\r
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, tooLarge);\r
+ outIdentities.clear(); \r
+ return;\r
+ }\r
+ \r
+ ExecutionError_ errorHeader = new ExecutionError_();\r
+ errorHeader.requestId = reqHeader.requestId;\r
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, errorHeader);\r
+ outIdentities.clear();\r
+ out.writeInt(size);\r
+ serializer.serialize(out, outIdentities, error);\r
+ outIdentities.clear();\r
+ out.flush();\r
+ } catch (IOException e) {\r
+ setError(e);\r
+ } catch (RuntimeException e) {\r
+ setError(e);\r
+ }\r
+ }\r
+ }}); \r
+ }});\r
+\r
+ } catch (MethodNotSupportedException e) {\r
+ in.skipBytes(size);\r
+ // return with an error\r
+ final InvalidMethodError error = new InvalidMethodError();\r
+ error.requestId = reqHeader.requestId;\r
+ writeExecutor.execute(new Runnable() {\r
+ @Override\r
+ public void run() {\r
+ synchronized(TcpConnection.this) {\r
+ try {\r
+ MESSAGE_SERIALIZER.serialize(out, outIdentities, error);\r
+ outIdentities.clear();\r
+ out.flush();\r
+ } catch (IOException e) {\r
+ setError(e);\r
+ } catch (RuntimeException e) {\r
+ setError(e);\r
+ }\r
+ }\r
+ }}); \r
+ continue; \r
+ } \r
+ \r
+ \r
+ } else if (header instanceof ResponseHeader) {\r
+ int requestId = ((ResponseHeader)header).requestId;\r
+ PendingRequest req = requests.remove(requestId);\r
+ if (req==null) {\r
+ setError(new RuntimeException("Request by id "+requestId+" does not exist"));\r
+ return; \r
+ } \r
+ int size = in.readInt();\r
+ if (size>maxRecvSize) {\r
+ // TODO SOMETHING\r
+ }\r
+ Object response = req.method.responseSerializer.deserialize(in, inIdentities);\r
+ inIdentities.clear();\r
+ req.setResponse(response);\r
+ } else if (header instanceof ExecutionError_) { \r
+ int requestId = ((ExecutionError_)header).requestId;\r
+ PendingRequest req = requests.remove(requestId);\r
+ if (req==null) {\r
+ setError(new RuntimeException("Request by id "+requestId+" does not exist"));\r
+ return;\r
+ }\r
+ int size = in.readInt();\r
+ if (size>maxRecvSize) {\r
+ // TODO SOMETHING\r
+ }\r
+ Object executionError = req.method.errorSerializer.deserialize(in, inIdentities);\r
+ inIdentities.clear();\r
+ req.setExecutionError(executionError);\r
+ } else if (header instanceof Exception_) {\r
+ int requestId = ((Exception_)header).requestId;\r
+ PendingRequest req = requests.remove(requestId);\r
+ req.setExecutionError(new Exception(((Exception_)header).message));\r
+ } else if (header instanceof InvalidMethodError) {\r
+ int requestId = ((InvalidMethodError)header).requestId;\r
+ PendingRequest req = requests.remove(requestId);\r
+ req.setInvokeException(new InvokeException(new MethodNotSupportedException("?")));\r
+ } else if (header instanceof ResponseTooLargeError) {\r
+ int requestId = ((ResponseTooLargeError)header).requestId;\r
+ PendingRequest req = requests.remove(requestId);\r
+ req.setInvokeException(new InvokeException(new MessageOverflowException()));\r
+ }\r
+ \r
+ } catch (EOFException e) {\r
+ setClosed();\r
+ break;\r
+ } catch (SocketException e) {\r
+ if (e.getMessage().equals("Socket Closed"))\r
+ setClosed();\r
+ else\r
+ setError(e);\r
+ break;\r
+ } catch (IOException e) {\r
+ setError(e);\r
+ break;\r
+ }\r
+ }\r
+ try {\r
+ socket.close();\r
+ } catch (IOException e) {\r
+ }\r
+ // Close pending requests\r
+ close();\r
+ };\r
+ }
+
+ // Thread that reads input data
+ ConnectionThread thread = new ConnectionThread();
+
+ class PendingRequest extends AsyncResultImpl {
+
+ MethodImpl method;
+
+ // request id
+ int requestId;
+
+ public PendingRequest(MethodImpl method, int requestId) {
+ this.method = method;
+ this.requestId = requestId;
+ }
+ }
+
+
+ @Union({RequestHeader.class, ResponseHeader.class, ExecutionError_.class, Exception_.class, InvalidMethodError.class, ResponseTooLargeError.class})
+ public static class Message {}
+
+ public static class RequestHeader extends Message {
+ public int requestId;
+ public int methodId;
+ // Request Object
+ public RequestHeader() {}
+ }
+
+ public static class ResponseHeader extends Message {
+ public int requestId;
+ // Response object
+ public ResponseHeader() {}
+ }
+
+ // Error while invoking a method
+ public static class ExecutionError_ extends Message {
+ public int requestId;
+ // Error object
+ public ExecutionError_() {}
+ }
+
+ // MethodName does not exist
+ public static class InvalidMethodError extends Message {
+ public int requestId;
+ public InvalidMethodError() {}
+ }
+
+ // Exception, not in method but somewhere else
+ public static class Exception_ extends Message {
+ public int requestId;
+ public String message;
+ public Exception_() {}
+ }
+
+ public static class ResponseTooLargeError extends Message {
+ public int requestId;
+ public ResponseTooLargeError() {}
+ }
+
+}
+