/*******************************************************************************
* Copyright (c) 2010 Association for Decentralized Information Management in
* Industry THTH ry.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* VTT Technical Research Centre of Finland - initial API and implementation
*******************************************************************************/
package org.simantics.databoard.method;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.simantics.databoard.Bindings;
import org.simantics.databoard.annotations.Union;
import org.simantics.databoard.binding.Binding;
import org.simantics.databoard.binding.RecordBinding;
import org.simantics.databoard.binding.UnionBinding;
import org.simantics.databoard.serialization.Serializer;
import org.simantics.databoard.serialization.SerializerConstructionException;
import org.simantics.databoard.util.binary.BinaryReadable;
import org.simantics.databoard.util.binary.BinaryWriteable;
import org.simantics.databoard.util.binary.InputStreamReadable;
import org.simantics.databoard.util.binary.OutputStreamWriteable;
/**
* Connection is a class that handles request-response communication over a
* socket.
*
* Requests have asynchronous result. The result can be acquired using one of
* the three methods:
* 1) Blocking read AsyncResult.waitForResponse()
* 2) Poll AsyncResult.getResponse()
* 3) Listen AsyncResult.setListener()
*
* The socket must be established before Connection is instantiated.
* Closing connection does not close its Socket.
* If the socket is closed before the connection, an error is raised.
* The error is delivered to listeners registered with addConnectionListener().
* The proper order to close a connection is to close Connection first
* and then Socket.
*
* @author Toni Kalajainen
*/
public class TcpConnection implements MethodInterface {
/**
 * Default executor for write tasks: no core threads, unbounded maximum, a direct
 * hand-off queue and a 100 ms idle time-out.
 */
public static final ExecutorService SHARED_EXECUTOR_SERVICE =
        new ThreadPoolExecutor(0, Integer.MAX_VALUE, 100L, TimeUnit.MILLISECONDS, new SynchronousQueue<Runnable>());

/** Serializer of the {@link Message} union, i.e. of every message header. */
static final Serializer MESSAGE_SERIALIZER = Bindings.getSerializerUnchecked( Bindings.getBindingUnchecked(Message.class) );
static Charset UTF8 = Charset.forName("UTF8");

/** Handshake records of this end and of the peer. */
Handshake local, remote;
/** Interface built from the peer's method list. */
Interface remoteType;
/** Method tables; the array index is the method id used on the wire. */
MethodTypeDefinition[] localMethods, remoteMethods;
/** Reverse look-up: method definition to its index in the corresponding array. */
HashMap<MethodTypeDefinition, Integer> localMethodsMap, remoteMethodsMap;
Socket socket;

// If false, there is an error in the socket or the connection has been shut down.
// Written by close() and read by MethodImpl.invoke() on caller threads, hence volatile.
volatile boolean active = true;

// Objects used for handling local services
MethodInterface methodInterface;

// Objects used for reading data
/** Requests sent to the peer that have not been answered yet, keyed by request id. */
ConcurrentHashMap<Integer, PendingRequest> requests = new ConcurrentHashMap<Integer, PendingRequest>();
/** Identity table handed to deserializers; cleared after each payload. */
List<Object> inIdentities = new ArrayList<Object>();
BinaryReadable in;
/** Upper bound for accepted payload sizes (minimum of both handshake limits). */
int maxRecvSize;

// Objects used for writing data
public ExecutorService writeExecutor = SHARED_EXECUTOR_SERVICE;
/** Identity table handed to serializers; cleared after each use. */
TObjectIntHashMap<Object> outIdentities = new TObjectIntHashMap<Object>();
BinaryWriteable out;
/** Source of request ids. */
AtomicInteger requestCounter = new AtomicInteger(0);
/** Upper bound for payload sizes this end may send (minimum of both handshake limits). */
int maxSendSize;

// Cached method descriptions
// NOTE(review): key/value types are not evident from this file; left raw on purpose.
Map methodTypes = new ConcurrentHashMap();
/**
 * Exchanges handshake records with the peer over an already connected socket.
 *
 * The local record is written on a {@link #SHARED_EXECUTOR_SERVICE} thread while
 * the calling thread reads the peer's record, so neither side blocks the other
 * when both ends write first.
 *
 * @param socket connected socket
 * @param localData local data
 * @return the remote data
 * @throws IOException if reading the remote record or writing the local one fails
 * @throws RuntimeException if the writer fails with an unexpected error, or if
 *         the calling thread is interrupted while waiting for the writer (the
 *         interrupt flag is restored before throwing)
 */
public static Handshake handshake(final Socket socket, final Handshake localData)
        throws IOException
{
    final BinaryReadable bin = new InputStreamReadable( socket.getInputStream(), Long.MAX_VALUE );
    final BinaryWriteable bout = new OutputStreamWriteable( socket.getOutputStream() );

    // Slot for the writer's failure; the semaphore hand-off makes it visible to the reader.
    final Exception[] writeError = new Exception[1];
    final Semaphore written = new Semaphore(0);
    SHARED_EXECUTOR_SERVICE.execute(new Runnable() {
        @Override
        public void run() {
            try {
                TObjectIntHashMap<Object> outIdentities = new TObjectIntHashMap<Object>();
                Handshake.SERIALIZER.serialize(bout, outIdentities, localData);
                bout.flush();
            } catch (IOException e) {
                writeError[0] = e;
            } catch (RuntimeException e) {
                // Report unexpected serialization failures to the caller instead
                // of losing them on the executor thread.
                writeError[0] = e;
            } finally {
                written.release();
            }
        }});

    // Read remote peer's handshake
    List<Object> inIdentities = new ArrayList<Object>();
    Handshake result = (Handshake) Handshake.SERIALIZER.deserialize(bin, inIdentities);

    // Check that write was ok
    try {
        written.acquire();
    } catch (InterruptedException interrupted) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(interrupted);
    }
    Exception failure = writeError[0];
    if (failure instanceof IOException)
        throw (IOException) failure;
    if (failure != null)
        throw new RuntimeException(failure);
    return result;
}
/**
* Create a connection to a hand-shaken socket
*
* @param socket
* @param methodInterface local method handler
* @param localData
* @param remoteData
* @throws IOException
*/
public TcpConnection(Socket socket, MethodInterface methodInterface, Handshake localData, Handshake remoteData)
throws IOException {
if (socket==null || localData==null || remoteData==null)
throw new IllegalArgumentException("null arg");
this.methodInterface = methodInterface;
this.socket = socket;
this.local = localData;
this.remote = remoteData;
this.maxSendSize = Math.min(localData.sendMsgLimit, remoteData.recvMsgLimit);
this.maxRecvSize = Math.min(localData.recvMsgLimit, remoteData.sendMsgLimit);
this.localMethods = local.methods;
this.remoteMethods = remote.methods;
this.remoteType = new Interface(this.remoteMethods);
this.localMethodsMap = new HashMap();
this.remoteMethodsMap = new HashMap();
for (int i=0; i0) {
System.out.print(System.identityHashCode( k.getType().requestType.getComponentType(0) ) );
}
System.out.println();
}
*/
throw new MethodNotSupportedException(description.getName());
}
int id = remoteMethodsMap.get(description);
try {
return new MethodImpl(id, binding);
} catch (SerializerConstructionException e) {
throw new MethodNotSupportedException(e);
}
}
/**
 * Creates a method handle for a remote method, using mutable bindings derived
 * from the method's request, response and error types.
 *
 * @param description definition of the requested method
 * @return handle that invokes the method on the peer
 * @throws MethodNotSupportedException if the peer does not offer the method or
 *         serializers cannot be constructed for its bindings
 */
@Override
public Method getMethod(MethodTypeDefinition description)
        throws MethodNotSupportedException {
    // Guard: the peer must have announced this method.
    boolean offeredByPeer = remoteMethodsMap.containsKey(description);
    if (!offeredByPeer) {
        throw new MethodNotSupportedException(description.getName());
    }

    // The producer suggests object bindings.
    MethodTypeBinding typeBinding = new MethodTypeBinding(
            description,
            (RecordBinding) Bindings.getMutableBinding(description.getType().getRequestType()),
            Bindings.getMutableBinding(description.getType().getResponseType()),
            (UnionBinding) Bindings.getMutableBinding(description.getType().getErrorType()));

    int remoteId = remoteMethodsMap.get(description);
    try {
        return new MethodImpl(remoteId, typeBinding);
    } catch (SerializerConstructionException cause) {
        // Generic binding should work
        throw new MethodNotSupportedException(cause);
    }
}
/**
 * Get the socket this connection was created with.
 *
 * @return the socket
 */
public Socket getSocket()
{
return socket;
}
/**
 * Observer of connection life-cycle events.
 */
public interface ConnectionListener {

    /**
     * The connection was closed without an error.
     */
    void onClosed();

    /**
     * The connection failed and was closed.
     *
     * @param error the failure that ended the connection
     */
    void onError(Exception error);

}
/** Registered listeners; the copy-on-write list makes add/remove/iterate safe without extra locking. */
CopyOnWriteArrayList<ConnectionListener> listeners = new CopyOnWriteArrayList<ConnectionListener>();

/**
 * Registers a listener for close and error notifications.
 *
 * Deliberately not synchronized on this connection: the connection monitor is
 * held by writer tasks for the duration of a socket write, so locking here
 * could stall registration behind I/O. The list itself is thread-safe.
 *
 * @param listener listener to add
 */
public void addConnectionListener(ConnectionListener listener) {
    listeners.add( listener );
}

/**
 * Unregisters a listener.
 *
 * @param listener listener to remove
 */
public void removeConnectionListener(ConnectionListener listener) {
    listeners.remove( listener );
}
class MethodImpl implements Method {
int methodId;
MethodTypeBinding methodBinding;
Serializer responseSerializer;
Serializer requestSerializer;
Serializer errorSerializer;
MethodImpl(int methodId, MethodTypeBinding methodBinding) throws SerializerConstructionException
{
this.methodId = methodId;
this.methodBinding = methodBinding;
this.requestSerializer = Bindings.getSerializer( methodBinding.getRequestBinding() );
this.responseSerializer = Bindings.getSerializer( methodBinding.getResponseBinding() );
this.errorSerializer = Bindings.getSerializer( methodBinding.getErrorBinding() );
}
@Override
public AsyncResult invoke(final Object request) {
// Write, async
final PendingRequest result = new PendingRequest(this, requestCounter.getAndIncrement());
requests.put(result.requestId, result);
if (!active) {
result.setInvokeException(new InvokeException(new ConnectionClosedException()));
} else {
writeExecutor.execute(new Runnable() {
@Override
public void run() {
synchronized(TcpConnection.this) {
try {
int size= requestSerializer.getSize(request, outIdentities);
if (size>maxSendSize) {
result.setInvokeException(new InvokeException(new MessageOverflowException()));
return;
}
outIdentities.clear();
RequestHeader reqHeader = new RequestHeader();
reqHeader.methodId = methodId;
reqHeader.requestId = result.requestId;
MESSAGE_SERIALIZER.serialize(out, outIdentities, reqHeader);
outIdentities.clear();
out.writeInt(size);
requestSerializer.serialize(out, outIdentities, request);
outIdentities.clear();
out.flush();
} catch (IOException e) {
result.setInvokeException(new InvokeException(e));
} catch (RuntimeException e) {
result.setInvokeException(new InvokeException(e));
}
}
}});
}
return result;
}
@Override
public MethodTypeBinding getMethodBinding() {
return methodBinding;
}
}
/**
 * Tells every registered listener that the connection was closed.
 */
void setClosed() {
    for (ConnectionListener each : listeners) {
        each.onClosed();
    }
}
/**
 * Reports a failure to every registered listener and then closes the connection.
 *
 * @param e the failure to report
 */
void setError(Exception e)
{
    try {
        for (ConnectionListener listener : listeners) {
            listener.onError(e);
        }
    } finally {
        // Pending requests must be released even if a listener throws.
        close();
    }
}
/**
 * Get method interface that handles services locally (service requests by peer).
 * This is the handler given to the constructor.
 *
 * @return local method interface
 */
public MethodInterface getLocalMethodInterface()
{
return methodInterface;
}
/**
 * Get the definitions of the methods offered locally, as announced in the
 * local handshake. The internal array is returned as is, not a copy.
 *
 * @return local method definitions
 */
public MethodTypeDefinition[] getLocalMethodDescriptions()
{
return localMethods;
}
/**
 * Get the method interface for invoking the peer's methods. The connection
 * itself implements it.
 *
 * @return this connection
 */
public MethodInterface getRemoteMethodInterface() {
return this;
}
/**
 * Closes the connection: marks it inactive, fails every pending request with
 * a {@link ConnectionClosedException} and interrupts the reader thread.
 * The socket is not closed, and listeners are not notified from here.
 */
public void close() {
    active = false;

    // Work on a snapshot so that concurrently registered requests do not disturb the loop.
    List<PendingRequest> pending = new ArrayList<PendingRequest>(requests.values());
    for (PendingRequest request : pending) {
        request.setInvokeException(new InvokeException(new ConnectionClosedException()));
    }
    requests.values().removeAll(pending);

    // Ask the reader thread to stop.
    thread.interrupt();
}
/**
 * Looks up the connection whose reader thread is the calling thread.
 *
 * @return the connection, or null if the current thread is not a connection's
 *         reader thread
 */
public static TcpConnection getCurrentConnection() {
    Thread current = Thread.currentThread();
    return (current instanceof ConnectionThread)
            ? ((ConnectionThread) current).getConnection()
            : null;
}
/**
 * Connection Thread deserializes incoming messages from the TCP Stream.
 *
 * Requests from the peer are dispatched to the local method interface and the
 * outcome is written back on {@link TcpConnection#writeExecutor}. Responses and
 * error messages from the peer complete the matching entry of the pending
 * request table.
 */
class ConnectionThread extends Thread {

    public ConnectionThread() {
        setDaemon(true);
    }

    public TcpConnection getConnection() {
        return TcpConnection.this;
    }

    /**
     * Read loop. Ends on interrupt, end of stream, I/O failure, an unexpected
     * runtime failure or a protocol violation. Protocol violations (reported
     * through setError by the handlers) leave the loop without closing the
     * socket; every other exit closes the socket and the connection.
     */
    @Override
    public void run() {
        while (!Thread.interrupted()) {
            try {
                Message header = (Message) MESSAGE_SERIALIZER.deserialize(in, inIdentities);
                if (!handleMessage(header)) {
                    return;
                }
            } catch (EOFException e) {
                setClosed();
                break;
            } catch (SocketException e) {
                // The JDK reports a locally closed socket as "Socket closed";
                // compare null-safely and without regard to case.
                if ("Socket closed".equalsIgnoreCase(e.getMessage()))
                    setClosed();
                else
                    setError(e);
                break;
            } catch (IOException e) {
                setError(e);
                break;
            } catch (RuntimeException e) {
                // Do not let the reader die silently: report the failure so that
                // pending requests are released.
                setError(e);
                break;
            }
        }
        try {
            socket.close();
        } catch (IOException ignored) {
            // Best effort: the connection is being torn down anyway.
        }
        // Close pending requests
        close();
    }

    /**
     * Handles one message whose header has just been read.
     *
     * @return false if a protocol violation was reported and reading must stop
     */
    private boolean handleMessage(Message header) throws IOException {
        if (header instanceof RequestHeader) {
            return handleRequest((RequestHeader) header);
        } else if (header instanceof ResponseHeader) {
            PendingRequest req = takePending(((ResponseHeader) header).requestId);
            if (req == null || !readPayloadSize(req)) return false;
            Object response = req.method.responseSerializer.deserialize(in, inIdentities);
            inIdentities.clear();
            req.setResponse(response);
        } else if (header instanceof ExecutionError_) {
            PendingRequest req = takePending(((ExecutionError_) header).requestId);
            if (req == null || !readPayloadSize(req)) return false;
            Object executionError = req.method.errorSerializer.deserialize(in, inIdentities);
            inIdentities.clear();
            req.setExecutionError(executionError);
        } else if (header instanceof Exception_) {
            PendingRequest req = takePending(((Exception_) header).requestId);
            if (req == null) return false;
            req.setExecutionError(new Exception(((Exception_) header).message));
        } else if (header instanceof InvalidMethodError) {
            PendingRequest req = takePending(((InvalidMethodError) header).requestId);
            if (req == null) return false;
            req.setInvokeException(new InvokeException(new MethodNotSupportedException("?")));
        } else if (header instanceof ResponseTooLargeError) {
            PendingRequest req = takePending(((ResponseTooLargeError) header).requestId);
            if (req == null) return false;
            req.setInvokeException(new InvokeException(new MessageOverflowException()));
        }
        // Unknown message kinds are ignored.
        return true;
    }

    /**
     * Removes and returns the pending request with the given id. If there is
     * none, reports a protocol error through setError and returns null.
     */
    private PendingRequest takePending(int requestId) {
        PendingRequest req = requests.remove(requestId);
        if (req == null) {
            setError(new RuntimeException("Request by id "+requestId+" does not exist"));
        }
        return req;
    }

    /**
     * Reads the declared payload size of a response and enforces the receive
     * limit, in the same way as for incoming requests.
     *
     * @return false if the limit was exceeded; the request has then been
     *         failed and the error reported
     */
    private boolean readPayloadSize(PendingRequest req) throws IOException {
        int size = in.readInt();
        if (size > maxRecvSize) {
            req.setInvokeException(new InvokeException(new MessageOverflowException()));
            setError(new MessageOverflowException());
            return false;
        }
        return true;
    }

    /**
     * Serves a request of the peer: validates size and method id, lets the
     * local method interface pick bindings, deserializes the payload, invokes
     * the method and arranges for the outcome to be written back.
     *
     * @return false if a protocol violation was reported and reading must stop
     */
    private boolean handleRequest(final RequestHeader reqHeader) throws IOException {
        final int requestId = reqHeader.requestId;
        int size = in.readInt();
        if (size > maxRecvSize) {
            setError(new MessageOverflowException());
            return false;
        }
        int methodId = reqHeader.methodId;
        if (methodId < 0 || methodId >= localMethods.length) {
            setError(new Exception("ProtocolError"));
            return false;
        }
        MethodTypeDefinition methodDescription = localMethods[methodId];
        // Let back-end determine bindings
        try {
            final Method method = methodInterface.getMethod(methodDescription);
            final MethodTypeBinding methodBinding = method.getMethodBinding();
            // Deserialize payload
            final Object request = Bindings.getSerializerUnchecked(methodBinding.getRequestBinding()).deserialize(in, inIdentities);
            inIdentities.clear();
            // Invoke method
            method.invoke(request).setListener(new InvokeListener() {
                @Override
                public void onCompleted(final Object response) {
                    // Write RESP
                    ResponseHeader respHeader = new ResponseHeader();
                    respHeader.requestId = requestId;
                    writePayloadMessage(respHeader, requestId, methodBinding.getResponseBinding(), response);
                }
                @Override
                public void onException(final Exception cause) {
                    // Write ERRO
                    Exception_ msg = new Exception_();
                    // The peer looks the request up by this id.
                    msg.requestId = requestId;
                    msg.message = cause.getClass().getName()+": "+cause.getMessage();
                    writeHeaderMessage(msg);
                }
                @Override
                public void onExecutionError(final Object error) {
                    // Write ERRO
                    ExecutionError_ errorHeader = new ExecutionError_();
                    errorHeader.requestId = requestId;
                    writePayloadMessage(errorHeader, requestId, methodBinding.getErrorBinding(), error);
                }});
        } catch (MethodNotSupportedException e) {
            // Drop the payload of the request, then return with an error.
            in.skipBytes(size);
            InvalidMethodError error = new InvalidMethodError();
            error.requestId = requestId;
            writeHeaderMessage(error);
        }
        return true;
    }

    /**
     * Queues a message that consists of a header only. Write failures are
     * reported through setError.
     */
    private void writeHeaderMessage(final Message message) {
        writeExecutor.execute(new Runnable() {
            @Override
            public void run() {
                // The connection monitor serializes all writers on the shared stream.
                synchronized(TcpConnection.this) {
                    try {
                        MESSAGE_SERIALIZER.serialize(out, outIdentities, message);
                        outIdentities.clear();
                        out.flush();
                    } catch (IOException e) {
                        setError(e);
                    } catch (RuntimeException e) {
                        setError(e);
                    }
                }
            }});
    }

    /**
     * Queues a message that consists of a header, a payload size and a payload.
     * If the payload exceeds the send limit, a ResponseTooLargeError for the
     * request is sent instead. Write failures are reported through setError.
     *
     * @param header header to send when the payload fits
     * @param requestId id of the request being answered
     * @param payloadBinding binding used to serialize the payload
     * @param payload payload object
     */
    private void writePayloadMessage(final Message header, final int requestId,
            final Binding payloadBinding, final Object payload) {
        writeExecutor.execute(new Runnable() {
            @Override
            public void run() {
                synchronized(TcpConnection.this) {
                    try {
                        Serializer serializer = Bindings.getSerializerUnchecked(payloadBinding);
                        int size = serializer.getSize(payload, outIdentities);
                        outIdentities.clear();
                        if (size > maxSendSize) {
                            ResponseTooLargeError tooLarge = new ResponseTooLargeError();
                            tooLarge.requestId = requestId;
                            MESSAGE_SERIALIZER.serialize(out, outIdentities, tooLarge);
                            outIdentities.clear();
                            // Without a flush the peer may never see the error.
                            out.flush();
                            return;
                        }
                        MESSAGE_SERIALIZER.serialize(out, outIdentities, header);
                        outIdentities.clear();
                        out.writeInt(size);
                        serializer.serialize(out, outIdentities, payload);
                        outIdentities.clear();
                        out.flush();
                    } catch (IOException e) {
                        setError(e);
                    } catch (RuntimeException e) {
                        setError(e);
                    }
                }
            }});
    }
}
// Thread that reads input data. It is created together with the connection
// and interrupted by close().
ConnectionThread thread = new ConnectionThread();
/**
 * Result of a request that has been sent to the peer. The reader thread uses
 * the method's serializers to decode the answer that carries the same id.
 */
class PendingRequest extends AsyncResultImpl {
// Method handle the request was invoked through
MethodImpl method;
// request id
int requestId;
/**
 * @param method method handle the request was invoked through
 * @param requestId id written into the request header
 */
public PendingRequest(MethodImpl method, int requestId) {
this.method = method;
this.requestId = requestId;
}
}
/**
 * Base type of all message headers. The headers are serialized as a union of
 * the classes listed in the annotation.
 */
@Union({RequestHeader.class, ResponseHeader.class, ExecutionError_.class, Exception_.class, InvalidMethodError.class, ResponseTooLargeError.class})
public static class Message {}
/** Header of a request; followed on the wire by the payload size and the request object. */
public static class RequestHeader extends Message {
public int requestId;
// Index of the method in the receiver's method table
public int methodId;
// Request Object
public RequestHeader() {}
}
/** Header of a successful response; followed by the payload size and the response object. */
public static class ResponseHeader extends Message {
public int requestId;
// Response object
public ResponseHeader() {}
}
// Error while invoking a method; followed by the payload size and the error object
public static class ExecutionError_ extends Message {
public int requestId;
// Error object
public ExecutionError_() {}
}
// MethodName does not exist
public static class InvalidMethodError extends Message {
public int requestId;
public InvalidMethodError() {}
}
// Exception, not in method but somewhere else
public static class Exception_ extends Message {
public int requestId;
// Class name and message of the exception
public String message;
public Exception_() {}
}
// The answer to the request exceeded the sender's size limit
public static class ResponseTooLargeError extends Message {
public int requestId;
public ResponseTooLargeError() {}
}
}