/******************************************************************************* * Copyright (c) 2010 Association for Decentralized Information Management in * Industry THTH ry. * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * VTT Technical Research Centre of Finland - initial API and implementation *******************************************************************************/ package org.simantics.databoard.method; import gnu.trove.map.hash.TObjectIntHashMap; import java.io.EOFException; import java.io.IOException; import java.net.Socket; import java.net.SocketException; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.simantics.databoard.Bindings; import org.simantics.databoard.annotations.Union; import org.simantics.databoard.binding.Binding; import org.simantics.databoard.binding.RecordBinding; import org.simantics.databoard.binding.UnionBinding; import org.simantics.databoard.serialization.Serializer; import org.simantics.databoard.serialization.SerializerConstructionException; import org.simantics.databoard.util.binary.BinaryReadable; import org.simantics.databoard.util.binary.BinaryWriteable; import org.simantics.databoard.util.binary.InputStreamReadable; import org.simantics.databoard.util.binary.OutputStreamWriteable; /** * Connection is a class that handles request-response communication over a * 
socket. *

* Requests have asynchronous results. A result can be acquired in one of
 * three ways:
 * 1) Blocking read: AsyncResult.waitForResponse()
 * 2) Polling: AsyncResult.getResponse()
 * 3) Listening: AsyncResult.setListener()
 *

* The socket must be established before the Connection is instantiated.
 * Closing the connection does not close its Socket.
 * If the socket is closed before the connection is closed, an error is
 * raised. The error can be observed by registering a listener.
 * The proper order to close a connection is to close the Connection first
 * and then the Socket.
 *
 * @author Toni Kalajainen
 */
public class TcpConnection implements MethodInterface {

    // NOTE(review): this file appears to have been damaged when it was
    // collapsed onto a few long lines: text between '<' and '>' seems to have
    // been stripped. Generic type arguments are missing (e.g. `requests` is a
    // raw ConcurrentHashMap, yet `PendingRequest req = requests.remove(...)`
    // requires a typed map), and the tail of the constructor plus the head of
    // a getMethod(MethodTypeBinding) overload are gone (see the NOTE at the
    // corrupted `for` loop in the constructor). Restore from version control
    // before building.

    /**
     * Executor shared by all connections for asynchronous writes (handshake
     * and messages). Configured with core size 0, no practical upper bound on
     * threads, a 100 ms idle keep-alive and a SynchronousQueue, so each task
     * is handed to an idle or newly created thread instead of being queued.
     */
    public static final ExecutorService SHARED_EXECUTOR_SERVICE = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 100L, TimeUnit.MILLISECONDS, new SynchronousQueue());

    /** Serializer for the {@link Message} union that heads every frame. */
    static final Serializer MESSAGE_SERIALIZER = Bindings.getSerializerUnchecked( Bindings.getBindingUnchecked(Message.class) );

    // NOTE(review): not referenced in the visible part of this file and not
    // final; consider `static final` or removing it.
    static Charset UTF8 = Charset.forName("UTF8");

    // Handshake sent by this end (local) and received from the peer (remote).
    Handshake local, remote;
    // Interface built from the remote peer's method list.
    Interface remoteType;
    // Method lists of the two handshakes. Incoming requests index localMethods
    // by their methodId (see ConnectionThread.run).
    MethodTypeDefinition[] localMethods, remoteMethods;
    // remoteMethodsMap supplies the wire id of a remote method (see getMethod);
    // the loop that fills these maps lies in the corrupted constructor span.
    HashMap localMethodsMap, remoteMethodsMap;
    Socket socket;

    // if false, there is an error in the socket or the connection has been shutdown
    // NOTE(review): written by close() and read by invoke() from other threads
    // without synchronization; consider making it volatile.
    boolean active = true;

    // Objects used for handling local services
    MethodInterface methodInterface;

    // Objects used for reading data
    // Outstanding outgoing requests keyed by request id. Answered entries are
    // removed by the reader thread; close() fails whatever remains.
    ConcurrentHashMap requests = new ConcurrentHashMap();
    // Deserialization identity table; in the visible code only the reader
    // thread uses it, clearing it after each deserialize call.
    List inIdentities = new ArrayList();
    BinaryReadable in;
    // Largest payload accepted: min(local recv limit, remote send limit).
    int maxRecvSize;

    // Object used for writing data
    // Runs write tasks; every write task serializes onto `out` while holding
    // the TcpConnection monitor.
    public ExecutorService writeExecutor = SHARED_EXECUTOR_SERVICE;
    // Serialization identity table; only touched inside synchronized(TcpConnection.this).
    TObjectIntHashMap outIdentities = new TObjectIntHashMap();
    BinaryWriteable out;
    // Source of outgoing request ids.
    AtomicInteger requestCounter = new AtomicInteger(0);
    // Largest payload sent: min(local send limit, remote recv limit).
    int maxSendSize;

    // Cached method descriptions (not referenced in the visible code)
    Map methodTypes = new ConcurrentHashMap();

    /**
     * Exchange handshakes over a connected socket.
     *
     * The local handshake is written on {@link #SHARED_EXECUTOR_SERVICE}
     * while the calling thread reads the peer's handshake, presumably so both
     * ends can send at the same time without blocking each other.
     *
     * @param socket connected socket; its streams are wrapped but not closed
     * @param localData local data
     * @return the remote data
     * @throws IOException if reading the remote handshake or writing the local one fails
     * @throws RuntimeException unexpected error (BindingException or EncodingException)
     */
    public static Handshake handshake(final Socket socket, final Handshake localData) throws IOException {
        final BinaryReadable bin = new InputStreamReadable( socket.getInputStream(), Long.MAX_VALUE );
        final BinaryWriteable bout = new OutputStreamWriteable( socket.getOutputStream() );
        ExecutorService writeExecutor = SHARED_EXECUTOR_SERVICE;

        // do hand-shake
        // writeError[0] carries a write failure back to this thread; the
        // semaphore is released once the write task has finished.
        final Exception[] writeError = new Exception[1];
        final Semaphore s = new Semaphore(0);
        writeExecutor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    TObjectIntHashMap outIdentities = new TObjectIntHashMap();
                    Handshake.SERIALIZER.serialize(bout, outIdentities, localData);
                    bout.flush();
                    outIdentities.clear();
                } catch (IOException e) {
                    // NOTE(review): only IOException is recorded. A
                    // RuntimeException from serialize() escapes to the executor,
                    // writeError stays null, and handshake() returns normally
                    // even though the local handshake was not sent.
                    writeError[0] = e;
                } finally {
                    s.release(1);
                }
            }
        }});

        // Read remote peer's handshake
        List inIdentities = new ArrayList();
        Handshake result = (Handshake) Handshake.SERIALIZER.deserialize(bin, inIdentities);
        inIdentities.clear();

        // Check that write was ok
        try {
            // Wait for the write task to finish before inspecting writeError.
            s.acquire(1);
            Exception e = writeError[0];
            if (e!=null && e instanceof IOException) throw (IOException) e;
            if (e!=null) throw new RuntimeException(e);
        } catch (InterruptedException e) {
            // NOTE(review): the interrupt status is lost; consider calling
            // Thread.currentThread().interrupt() before rethrowing.
            throw new RuntimeException(e);
        } finally {
            // writeExecutor.shutdown();
        }
        return result;
    }

    /**
     * Create a connection to a hand-shaken socket.
     * The message size limits are the minimum of the local and remote
     * handshake limits in each direction.
     *
     * @param socket socket on which the handshake has completed
     * @param methodInterface local method handler (serves the peer's requests);
     *        not checked for null
     * @param localData handshake sent by this end
     * @param remoteData handshake received from the peer
     * @throws IOException
     * @throws IllegalArgumentException if socket, localData or remoteData is null
     */
    public TcpConnection(Socket socket, MethodInterface methodInterface, Handshake localData, Handshake remoteData) throws IOException {
        if (socket==null || localData==null || remoteData==null)
            throw new IllegalArgumentException("null arg");
        this.methodInterface = methodInterface;
        this.socket = socket;
        this.local = localData;
        this.remote = remoteData;
        this.maxSendSize = Math.min(localData.sendMsgLimit, remoteData.recvMsgLimit);
        this.maxRecvSize = Math.min(localData.recvMsgLimit, remoteData.sendMsgLimit);
        this.localMethods = local.methods;
        this.remoteMethods = remote.methods;
        this.remoteType = new Interface(this.remoteMethods);
        this.localMethodsMap = new HashMap();
        this.remoteMethodsMap = new HashMap();
        // NOTE(review): CORRUPTED SPAN. The text between "i" and "0)" on the
        // next line was lost. It held the rest of this constructor and the
        // beginning of a getMethod(MethodTypeBinding binding) overload,
        // including a commented-out debug block whose closing terminator is
        // still visible. `in` and `out` are never assigned and `thread` is
        // never started anywhere in the visible code, so that presumably
        // happened in the lost part. Restore from version control.
        for (int i=0; i0) { System.out.print(System.identityHashCode( k.getType().requestType.getComponentType(0) ) ); } System.out.println(); } */
            // NOTE(review): end of corrupted span. The lines below are the tail
            // of the lost getMethod(MethodTypeBinding binding) overload.
            throw new MethodNotSupportedException(description.getName());
        }
        int id = remoteMethodsMap.get(description);
        try {
            return new MethodImpl(id, binding);
        } catch (SerializerConstructionException e) {
            throw new MethodNotSupportedException(e);
        }
    }

    /**
     * Get a handle for invoking a method on the remote peer, using mutable
     * bindings derived from the method's request, response and error types.
     *
     * @param description method to look up among the peer's methods
     * @return handle whose invoke() sends requests to the peer
     * @throws MethodNotSupportedException if the peer does not offer the method,
     *         or serializers cannot be constructed for its bindings
     */
    @Override
    public Method getMethod(MethodTypeDefinition description) throws MethodNotSupportedException {
        // producer suggests object bindings
        if (!remoteMethodsMap.containsKey(description)) {
            throw new MethodNotSupportedException(description.getName());
        }
        int id = remoteMethodsMap.get(description);
        RecordBinding reqBinding = (RecordBinding) Bindings.getMutableBinding(description.getType().getRequestType());
        Binding resBinding = Bindings.getMutableBinding(description.getType().getResponseType());
        UnionBinding errBinding = (UnionBinding) Bindings.getMutableBinding(description.getType().getErrorType());
        MethodTypeBinding binding = new MethodTypeBinding(description, reqBinding, resBinding, errBinding);
        try {
            return new MethodImpl(id, binding);
        } catch (SerializerConstructionException e) {
            // Generic binding should work
            throw new MethodNotSupportedException(e);
        }
    }

    /** @return the socket this connection communicates over */
    public Socket getSocket() {
        return socket;
    }

    /** Receives connection life-cycle notifications. */
    public interface ConnectionListener {

        /**
         * An error occurred. close() is invoked right after all listeners
         * have been notified.
         *
         * @param error the cause
         */
        void onError(Exception error);

        /**
         * The connection ended without an error: the reader thread reached end
         * of stream or saw a "Socket Closed" socket exception. Note that
         * close() itself does not invoke this callback.
         */
        void onClosed();
    }

    // Registered listeners; a copy-on-write list, so notification can iterate
    // while listeners are added or removed.
    CopyOnWriteArrayList listeners = new CopyOnWriteArrayList();

    /** Register a listener for connection errors and closure. */
    public synchronized void addConnectionListener(ConnectionListener listener) {
        // NOTE(review): `synchronized` is redundant with CopyOnWriteArrayList
        // and inconsistent with removeConnectionListener.
        listeners.add( listener );
    }

    /** Unregister a listener previously added with addConnectionListener. */
    public void removeConnectionListener(ConnectionListener listener) {
        listeners.remove( listener );
    }

    /**
     * Handle for one method of the remote peer. invoke() serializes requests
     * to the peer; ConnectionThread uses the response and error serializers to
     * decode the peer's answers.
     */
    class MethodImpl implements Method {
        // Wire id of the method (the value taken from remoteMethodsMap).
        int methodId;
        MethodTypeBinding methodBinding;
        Serializer responseSerializer;
        Serializer requestSerializer;
        Serializer errorSerializer;

        MethodImpl(int methodId, MethodTypeBinding methodBinding) throws SerializerConstructionException {
            this.methodId = methodId;
            this.methodBinding = methodBinding;
            this.requestSerializer = Bindings.getSerializer( methodBinding.getRequestBinding() );
            this.responseSerializer = Bindings.getSerializer( methodBinding.getResponseBinding() );
            this.errorSerializer = Bindings.getSerializer( methodBinding.getErrorBinding() );
        }

        /**
         * Send a request to the peer asynchronously.
         * The request is registered under a fresh request id and written by a
         * task on writeExecutor. The reader thread completes the result when
         * the peer answers. The result is failed instead if the connection is
         * inactive, the request exceeds the send limit, or the write throws.
         *
         * @param request request object for the method's request binding
         * @return the pending result
         */
        @Override
        public AsyncResult invoke(final Object request) {
            // Write, async
            final PendingRequest result = new PendingRequest(this, requestCounter.getAndIncrement());
            // NOTE(review): the entry is registered before the `active` check and
            // is not removed when the request fails locally (inactive, overflow,
            // write error); it stays in `requests` until close() processes it.
            requests.put(result.requestId, result);
            if (!active) {
                result.setInvokeException(new InvokeException(new ConnectionClosedException()));
            } else {
                writeExecutor.execute(new Runnable() {
                    @Override
                    public void run() {
                        synchronized(TcpConnection.this) {
                            try {
                                int size= requestSerializer.getSize(request, outIdentities);
                                if (size>maxSendSize) {
                                    // NOTE(review): returns without outIdentities.clear(),
                                    // so identities recorded by getSize() may leak into
                                    // the next write on this connection.
                                    result.setInvokeException(new InvokeException(new MessageOverflowException()));
                                    return;
                                }
                                outIdentities.clear();
                                // Frame: RequestHeader, payload size, request payload.
                                RequestHeader reqHeader = new RequestHeader();
                                reqHeader.methodId = methodId;
                                reqHeader.requestId = result.requestId;
                                MESSAGE_SERIALIZER.serialize(out, outIdentities, reqHeader);
                                outIdentities.clear();
                                out.writeInt(size);
                                requestSerializer.serialize(out, outIdentities, request);
                                outIdentities.clear();
                                out.flush();
                            } catch (IOException e) {
                                result.setInvokeException(new InvokeException(e));
                            } catch (RuntimeException e) {
                                result.setInvokeException(new InvokeException(e));
                            }
                        }
                    }
                });
            }
            return result;
        }

        @Override
        public MethodTypeBinding getMethodBinding() {
            return methodBinding;
        }
    }

    /** Notify listeners that the connection ended without error. Does not call close(). */
    void setClosed() {
        for (ConnectionListener listener : listeners) listener.onClosed();
    }

    /** Notify listeners of {@code e}, then close() the connection. */
    void setError(Exception e) {
        for (ConnectionListener listener : listeners) listener.onError(e);
        close();
    }

    /**
     * Get method interface that handles services locally (service requests by peer)
     *
     * @return local method interface
     */
    public MethodInterface getLocalMethodInterface() {
        return methodInterface;
    }

    /**
     * Get the descriptions of the methods this end serves to the peer, i.e.
     * the method list of the local handshake.
     *
     * @return local method descriptions (the internal array, not a copy)
     */
    public MethodTypeDefinition[] getLocalMethodDescriptions() {
        return localMethods;
    }

    /**
     * Get the method interface whose methods are executed by the remote peer.
     * This connection itself implements that interface.
     *
     * @return this connection
     */
    public MethodInterface getRemoteMethodInterface() {
        return this;
    }

    /**
     * Close the connection. All pending service requests are canceled with a
     * ConnectionClosedException, and the reader thread is interrupted.
     * The socket is not closed here. The reader thread closes it when its read
     * loop ends, except on its early-return error paths.
     */
    public void close() {
        active = false;
        // cancel all pending requests
        ArrayList reqs = new ArrayList(requests.values());
        for (PendingRequest pr : reqs) {
            pr.setInvokeException(new InvokeException(new ConnectionClosedException()));
        }
        requests.values().removeAll(reqs);
        // shutdown inthread
        // NOTE(review): the reader checks Thread.interrupted() only between
        // messages. A read blocked on the socket stream is generally not
        // interruptible, so the thread may keep running until data arrives or
        // the socket is closed.
        thread.interrupt();
        // for (ConnectionListener listener : listeners)
        //     listener.onClosed();
    }

    /**
     * Get the active connection of the current thread.
     * Local service methods are invoked from the connection's reader thread, so
     * code running inside invoke() can use this to find the connection that
     * delivered the request.
     *
     * @return Connection or null if current thread does not run connection
     */
    public static TcpConnection getCurrentConnection() {
        Thread t = Thread.currentThread();
        if (t instanceof ConnectionThread == false) return null;
        ConnectionThread ct = (ConnectionThread) t;
        return ct.getConnection();
    }

    /**
     * Connection Thread deserializes incoming messages from the TCP Stream.
     * It serves the peer's requests through methodInterface and completes this
     * end's pending requests from the peer's answers. Replies are written
     * asynchronously through writeExecutor.
     */
    class ConnectionThread extends Thread {

        public ConnectionThread() {
            // Daemon, so an open connection does not keep the JVM alive.
            setDaemon(true);
        }

        public TcpConnection getConnection() {
            return TcpConnection.this;
        }

        public void run() {
            while (!Thread.interrupted()) {
                try {
                    // Every frame starts with a Message header.
                    Message header = (Message) MESSAGE_SERIALIZER.deserialize(in, inIdentities);
                    if (header instanceof RequestHeader) {
                        // The peer invokes one of our local methods.
                        final RequestHeader reqHeader = (RequestHeader) header;
                        int size = in.readInt();
                        // NOTE(review): this and the other error exits in this loop
                        // use `return`, which skips the socket.close() at the end
                        // of run(); `break` would run the same cleanup as I/O errors.
                        if (size>maxRecvSize) {
                            setError(new MessageOverflowException());
                            return;
                        }
                        int methodId = reqHeader.methodId;
                        if (methodId<0||methodId>=localMethods.length) {
                            setError(new Exception("ProtocolError"));
                            return;
                        }
                        MethodTypeDefinition methodDescription = localMethods[methodId];
                        // Let back-end determine bindings
                        // NOTE(review): methodInterface is not null-checked by the
                        // constructor. A NullPointerException here is not caught
                        // below and would end this thread without cleanup.
                        try {
                            final Method method = methodInterface.getMethod(methodDescription);
                            final MethodTypeBinding methodBinding = method.getMethodBinding();
                            // Deserialize payload
                            final Object request = Bindings.getSerializerUnchecked(methodBinding.getRequestBinding()).deserialize(in, inIdentities);
                            inIdentities.clear();
                            // Invoke method
                            // invoke() runs on this reader thread: if the local
                            // implementation blocks inside it, no further frames
                            // are read until it returns.
                            method.invoke(request).setListener(new InvokeListener() {
                                @Override
                                public void onCompleted(final Object response) {
                                    // Write RESP
                                    writeExecutor.execute(new Runnable() {
                                        @Override
                                        public void run() {
                                            synchronized(TcpConnection.this) {
                                                try {
                                                    Serializer serializer = Bindings.getSerializerUnchecked(methodBinding.getResponseBinding());
                                                    int size = serializer.getSize(response, outIdentities);
                                                    outIdentities.clear();
                                                    if (size > maxSendSize) {
                                                        // Tell the peer the response was dropped.
                                                        // NOTE(review): no out.flush() on this path.
                                                        ResponseTooLargeError tooLarge = new ResponseTooLargeError();
                                                        tooLarge.requestId = reqHeader.requestId;
                                                        MESSAGE_SERIALIZER.serialize(out, outIdentities, tooLarge);
                                                        outIdentities.clear();
                                                        return;
                                                    }
                                                    // Frame: ResponseHeader, payload size, response payload.
                                                    ResponseHeader respHeader = new ResponseHeader();
                                                    respHeader.requestId = reqHeader.requestId;
                                                    MESSAGE_SERIALIZER.serialize(out, outIdentities, respHeader);
                                                    outIdentities.clear();
                                                    out.writeInt(size);
                                                    serializer.serialize(out, outIdentities, response);
                                                    outIdentities.clear();
                                                    out.flush();
                                                } catch (IOException e) {
                                                    setError(e);
                                                } catch (RuntimeException e) {
                                                    setError(e);
                                                }
                                            }
                                        }
                                    });
                                }

                                @Override
                                public void onException(final Exception cause) {
                                    // Write ERRO
                                    writeExecutor.execute(new Runnable() {
                                        @Override
                                        public void run() {
                                            synchronized(TcpConnection.this) {
                                                try {
                                                    // NOTE(review): msg.requestId is never set, so it
                                                    // keeps its default 0 instead of
                                                    // reqHeader.requestId; the peer will fail the
                                                    // wrong pending request (or none).
                                                    Exception_ msg = new Exception_();
                                                    msg.message = cause.getClass().getName()+": "+cause.getMessage();
                                                    MESSAGE_SERIALIZER.serialize(out, outIdentities, msg);
                                                    outIdentities.clear();
                                                    out.flush();
                                                } catch (IOException e) {
                                                    setError(e);
                                                } catch (RuntimeException e) {
                                                    setError(e);
                                                }
                                            }
                                        }
                                    });
                                }

                                @Override
                                public void onExecutionError(final Object error) {
                                    // Write ERRO
                                    writeExecutor.execute(new Runnable() {
                                        @Override
                                        public void run() {
                                            synchronized(TcpConnection.this) {
                                                try {
                                                    Serializer serializer = Bindings.getSerializerUnchecked(methodBinding.getErrorBinding());
                                                    int size = serializer.getSize(error, outIdentities);
                                                    outIdentities.clear();
                                                    if (size > maxSendSize) {
                                                        // NOTE(review): no out.flush() on this path.
                                                        ResponseTooLargeError tooLarge = new ResponseTooLargeError();
                                                        tooLarge.requestId = reqHeader.requestId;
                                                        MESSAGE_SERIALIZER.serialize(out, outIdentities, tooLarge);
                                                        outIdentities.clear();
                                                        return;
                                                    }
                                                    // Frame: ExecutionError_, payload size, error payload.
                                                    ExecutionError_ errorHeader = new ExecutionError_();
                                                    errorHeader.requestId = reqHeader.requestId;
                                                    MESSAGE_SERIALIZER.serialize(out, outIdentities, errorHeader);
                                                    outIdentities.clear();
                                                    out.writeInt(size);
                                                    serializer.serialize(out, outIdentities, error);
                                                    outIdentities.clear();
                                                    out.flush();
                                                } catch (IOException e) {
                                                    setError(e);
                                                } catch (RuntimeException e) {
                                                    setError(e);
                                                }
                                            }
                                        }
                                    });
                                }
                            });
                        } catch (MethodNotSupportedException e) {
                            // getMethod() failed before the payload was read; skip
                            // it so the stream stays aligned on the next frame.
                            in.skipBytes(size);
                            // return with an error
                            final InvalidMethodError error = new InvalidMethodError();
                            error.requestId = reqHeader.requestId;
                            writeExecutor.execute(new Runnable() {
                                @Override
                                public void run() {
                                    synchronized(TcpConnection.this) {
                                        try {
                                            MESSAGE_SERIALIZER.serialize(out, outIdentities, error);
                                            outIdentities.clear();
                                            out.flush();
                                        } catch (IOException e) {
                                            setError(e);
                                        } catch (RuntimeException e) {
                                            setError(e);
                                        }
                                    }
                                }
                            });
                            continue;
                        }
                    } else if (header instanceof ResponseHeader) {
                        // The peer answers one of our requests.
                        int requestId = ((ResponseHeader)header).requestId;
                        PendingRequest req = requests.remove(requestId);
                        if (req==null) {
                            setError(new RuntimeException("Request by id "+requestId+" does not exist"));
                            return;
                        }
                        int size = in.readInt();
                        if (size>maxRecvSize) {
                            // TODO SOMETHING
                            // NOTE(review): the limit is not enforced; the payload is
                            // deserialized regardless of its size.
                        }
                        Object response = req.method.responseSerializer.deserialize(in, inIdentities);
                        inIdentities.clear();
                        req.setResponse(response);
                    } else if (header instanceof ExecutionError_) {
                        // The peer's method failed with an error object of the method's error type.
                        int requestId = ((ExecutionError_)header).requestId;
                        PendingRequest req = requests.remove(requestId);
                        if (req==null) {
                            setError(new RuntimeException("Request by id "+requestId+" does not exist"));
                            return;
                        }
                        int size = in.readInt();
                        if (size>maxRecvSize) {
                            // TODO SOMETHING
                            // NOTE(review): limit not enforced, as above.
                        }
                        Object executionError = req.method.errorSerializer.deserialize(in, inIdentities);
                        inIdentities.clear();
                        req.setExecutionError(executionError);
                    } else if (header instanceof Exception_) {
                        // NOTE(review): this branch and the two below do not check
                        // req for null. An unknown request id raises a
                        // NullPointerException, which none of the catch clauses
                        // handle, so the thread ends without cleanup.
                        int requestId = ((Exception_)header).requestId;
                        PendingRequest req = requests.remove(requestId);
                        req.setExecutionError(new Exception(((Exception_)header).message));
                    } else if (header instanceof InvalidMethodError) {
                        int requestId = ((InvalidMethodError)header).requestId;
                        PendingRequest req = requests.remove(requestId);
                        req.setInvokeException(new InvokeException(new MethodNotSupportedException("?")));
                    } else if (header instanceof ResponseTooLargeError) {
                        int requestId = ((ResponseTooLargeError)header).requestId;
                        PendingRequest req = requests.remove(requestId);
                        req.setInvokeException(new InvokeException(new MessageOverflowException()));
                    }
                } catch (EOFException e) {
                    // End of stream is treated as a clean close.
                    setClosed();
                    break;
                } catch (SocketException e) {
                    // NOTE(review): getMessage() may be null. JDK sockets typically
                    // report "Socket closed" (lower-case c), so a locally closed
                    // socket is likely reported through setError instead of setClosed.
                    if (e.getMessage().equals("Socket Closed")) setClosed(); else setError(e);
                    break;
                } catch (IOException e) {
                    setError(e);
                    break;
                }
            }
            try {
                socket.close();
            } catch (IOException e) {
                // Ignored: best-effort close while the reader shuts down.
            }
            // Close pending requests
            close();
        };
    }

    // Thread that reads input data
    ConnectionThread thread = new ConnectionThread();

    /** An outgoing request awaiting the peer's answer. */
    class PendingRequest extends AsyncResultImpl {
        // Method whose serializers decode the answer.
        MethodImpl method;
        // request id
        int requestId;

        public PendingRequest(MethodImpl method, int requestId) {
            this.method = method;
            this.requestId = requestId;
        }
    }

    /**
     * Header of every frame on the wire. RequestHeader, ResponseHeader and
     * ExecutionError_ are followed by an int payload size and the payload; the
     * other messages carry no payload. Presumably the order of the
     * {@code @Union} list defines each message's wire tag; confirm before
     * reordering.
     */
    @Union({RequestHeader.class, ResponseHeader.class, ExecutionError_.class, Exception_.class, InvalidMethodError.class, ResponseTooLargeError.class})
    public static class Message {}

    /** The peer invokes local method {@code methodId}; answers echo {@code requestId}. */
    public static class RequestHeader extends Message {
        public int requestId;
        public int methodId;
        // Request Object (follows this header as a size-prefixed payload)
        public RequestHeader() {}
    }

    /** Successful answer to request {@code requestId}. */
    public static class ResponseHeader extends Message {
        public int requestId;
        // Response object (follows this header as a size-prefixed payload)
        public ResponseHeader() {}
    }

    // Error while invoking a method
    public static class ExecutionError_ extends Message {
        public int requestId;
        // Error object (follows this header as a size-prefixed payload)
        public ExecutionError_() {}
    }

    // MethodName does not exist (the receiver's getMethod threw MethodNotSupportedException)
    public static class InvalidMethodError extends Message {
        public int requestId;
        public InvalidMethodError() {}
    }

    // Exception, not in method but somewhere else; carries only a text message
    public static class Exception_ extends Message {
        public int requestId;
        public String message;
        public Exception_() {}
    }

    // Sent instead of a response or execution error whose size exceeds the send limit
    public static class ResponseTooLargeError extends Message {
        public int requestId;
        public ResponseTooLargeError() {}
    }
}