diff --git a/netty-rpc-client/src/main/java/com/netty/rpc/client/connect/ConnectionManager.java b/netty-rpc-client/src/main/java/com/netty/rpc/client/connect/ConnectionManager.java index edd6168..9669a0b 100644 --- a/netty-rpc-client/src/main/java/com/netty/rpc/client/connect/ConnectionManager.java +++ b/netty-rpc-client/src/main/java/com/netty/rpc/client/connect/ConnectionManager.java @@ -32,6 +32,7 @@ public class ConnectionManager { 600L, TimeUnit.SECONDS, new LinkedBlockingQueue(1000)); private Map connectedServerNodes = new ConcurrentHashMap<>(); + private CopyOnWriteArraySet rpcProtocolSet = new CopyOnWriteArraySet<>(); private ReentrantLock lock = new ReentrantLock(); private Condition connected = lock.newCondition(); private long waitTimeout = 5000; @@ -50,53 +51,54 @@ public static ConnectionManager getInstance() { } public void updateConnectedServer(List serviceList) { - threadPoolExecutor.submit(new Runnable() { - @Override - public void run() { - if (serviceList != null) { - if (serviceList.size() > 0) { - // Update local serverNodes cache - HashSet serviceSet = new HashSet<>(serviceList.size()); - for (int i = 0; i < serviceList.size(); ++i) { - RpcProtocol rpcProtocol = serviceList.get(i); - serviceSet.add(rpcProtocol); - } + // Now using 2 collections to manage the service info and TCP connections because making the connection is async + // Once service info is updated on ZK, will trigger this function + // Actually client should only care about the service it is using + if (serviceList != null && serviceList.size() > 0) { + // Update local serverNodes cache + HashSet serviceSet = new HashSet<>(serviceList.size()); + for (int i = 0; i < serviceList.size(); ++i) { + RpcProtocol rpcProtocol = serviceList.get(i); + serviceSet.add(rpcProtocol); + } - // Add new server info - for (final RpcProtocol rpcProtocol : serviceSet) { - if (!connectedServerNodes.keySet().contains(rpcProtocol)) { - connectServerNode(rpcProtocol); - } - } + // Add new server info + for (final RpcProtocol rpcProtocol : serviceSet) { + if (!rpcProtocolSet.contains(rpcProtocol)) { + rpcProtocolSet.add(rpcProtocol); + connectServerNode(rpcProtocol); + } + } - // Close and remove invalid server nodes - for (RpcProtocol rpcProtocol : connectedServerNodes.keySet()) { - if (!serviceSet.contains(rpcProtocol)) { - logger.info("Remove invalid service: " + rpcProtocol.toJson()); - RpcClientHandler handler = connectedServerNodes.get(rpcProtocol); - if (handler != null) { - handler.close(); - } - connectedServerNodes.remove(rpcProtocol); - } - } - } else { - // No available service - logger.error("No available service!"); - for (RpcProtocol rpcProtocol : connectedServerNodes.keySet()) { - RpcClientHandler handler = connectedServerNodes.get(rpcProtocol); - handler.close(); - connectedServerNodes.remove(rpcProtocol); - } + // Close and remove invalid server nodes + for (RpcProtocol rpcProtocol : rpcProtocolSet) { + if (!serviceSet.contains(rpcProtocol)) { + logger.info("Remove invalid service: " + rpcProtocol.toJson()); + RpcClientHandler handler = connectedServerNodes.get(rpcProtocol); + if (handler != null) { + handler.close(); } + connectedServerNodes.remove(rpcProtocol); + rpcProtocolSet.remove(rpcProtocol); } } - }); + } else { + // No available service + logger.error("No available service!"); + for (RpcProtocol rpcProtocol : rpcProtocolSet) { + RpcClientHandler handler = connectedServerNodes.get(rpcProtocol); + if (handler != null) { + handler.close(); + } + connectedServerNodes.remove(rpcProtocol); + rpcProtocolSet.remove(rpcProtocol); + } + } } private void connectServerNode(RpcProtocol rpcProtocol) { - logger.info("New service: {}, version:{}, uuid: {}, host: {}, port:{}", rpcProtocol.getServiceName(), - rpcProtocol.getVersion(), rpcProtocol.getUuid(), rpcProtocol.getHost(), rpcProtocol.getPort()); + logger.info("New service: {}, version:{}, host: {}, port:{}", rpcProtocol.getServiceName(), + rpcProtocol.getVersion(), rpcProtocol.getHost(), rpcProtocol.getPort()); final InetSocketAddress remotePeer = new InetSocketAddress(rpcProtocol.getHost(), rpcProtocol.getPort()); threadPoolExecutor.submit(new Runnable() { @Override @@ -111,10 +113,13 @@ public void run() { @Override public void operationComplete(final ChannelFuture channelFuture) throws Exception { if (channelFuture.isSuccess()) { - logger.info("Successfully connect to remote server. remote peer = " + remotePeer); + logger.info("Successfully connect to remote server, remote peer = " + remotePeer); RpcClientHandler handler = channelFuture.channel().pipeline().get(RpcClientHandler.class); connectedServerNodes.put(rpcProtocol, handler); + handler.setRpcProtocol(rpcProtocol); signalAvailableHandler(); + } else { + logger.error("Can not connect to remote server, remote peer = " + remotePeer); } } }); @@ -152,15 +157,28 @@ public RpcClientHandler chooseHandler(String serviceKey) throws Exception { } } RpcProtocol rpcProtocol = loadBalance.route(serviceKey, connectedServerNodes); - return connectedServerNodes.get(rpcProtocol); + RpcClientHandler handler = connectedServerNodes.get(rpcProtocol); + if (handler != null) { + return handler; + } else { + throw new Exception("Can not get available connection"); + } + } + + public void removeHandler(RpcProtocol rpcProtocol) { + rpcProtocolSet.remove(rpcProtocol); + connectedServerNodes.remove(rpcProtocol); } public void stop() { isRunning = false; - for (RpcProtocol rpcProtocol : connectedServerNodes.keySet()) { + for (RpcProtocol rpcProtocol : rpcProtocolSet) { RpcClientHandler handler = connectedServerNodes.get(rpcProtocol); - handler.close(); + if (handler != null) { + handler.close(); + } connectedServerNodes.remove(rpcProtocol); + rpcProtocolSet.remove(rpcProtocol); } signalAvailableHandler(); threadPoolExecutor.shutdown(); diff --git a/netty-rpc-client/src/main/java/com/netty/rpc/client/discovery/ServiceDiscovery.java b/netty-rpc-client/src/main/java/com/netty/rpc/client/discovery/ServiceDiscovery.java index 08059ba..f248946 100644 --- a/netty-rpc-client/src/main/java/com/netty/rpc/client/discovery/ServiceDiscovery.java +++ b/netty-rpc-client/src/main/java/com/netty/rpc/client/discovery/ServiceDiscovery.java @@ -29,8 +29,8 @@ public ServiceDiscovery(String registryAddress) { private void discoveryService() { try { - // Get init service info - logger.info("Get init service info"); + // Get initial service info + logger.info("Get initial service info"); getServiceAndUpdateServer(); // Add watch listener curatorClient.watchPathChildrenNode(Constant.ZK_REGISTRY_PATH, new PathChildrenCacheListener() { @@ -38,10 +38,14 @@ private void discoveryService() { public void childEvent(CuratorFramework curatorFramework, PathChildrenCacheEvent pathChildrenCacheEvent) throws Exception { PathChildrenCacheEvent.Type type = pathChildrenCacheEvent.getType(); switch (type) { + case CONNECTION_RECONNECTED: + logger.info("Reconnected to zk, try to get latest service list"); + getServiceAndUpdateServer(); + break; case CHILD_ADDED: case CHILD_UPDATED: case CHILD_REMOVED: - logger.info("Service info updated, try to get latest service list"); + logger.info("Service info changed, try to get latest service list"); getServiceAndUpdateServer(); break; } @@ -63,8 +67,7 @@ private void getServiceAndUpdateServer() { RpcProtocol rpcProtocol = RpcProtocol.fromJson(json); dataList.add(rpcProtocol); } - logger.debug("Node data: {}", dataList); - logger.debug("Service discovery triggered updating connected server node."); + logger.debug("Service node data: {}", dataList); //Update the service info based on the latest data UpdateConnectedServer(dataList); } catch (Exception e) { diff --git a/netty-rpc-client/src/main/java/com/netty/rpc/client/handler/RpcClientHandler.java b/netty-rpc-client/src/main/java/com/netty/rpc/client/handler/RpcClientHandler.java index 1f4b7d1..65cb95d 100644 --- a/netty-rpc-client/src/main/java/com/netty/rpc/client/handler/RpcClientHandler.java +++ b/netty-rpc-client/src/main/java/com/netty/rpc/client/handler/RpcClientHandler.java @@ -1,8 +1,10 @@ package com.netty.rpc.client.handler; +import com.netty.rpc.client.connect.ConnectionManager; import com.netty.rpc.codec.Beat; import com.netty.rpc.codec.RpcRequest; import com.netty.rpc.codec.RpcResponse; +import com.netty.rpc.protocol.RpcProtocol; import io.netty.buffer.Unpooled; import io.netty.channel.*; import io.netty.handler.timeout.IdleStateEvent; @@ -21,6 +23,7 @@ public class RpcClientHandler extends SimpleChannelInboundHandler { private ConcurrentHashMap pendingRPC = new ConcurrentHashMap<>(); private volatile Channel channel; private SocketAddress remotePeer; + private RpcProtocol rpcProtocol; @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { @@ -82,4 +85,14 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc super.userEventTriggered(ctx, evt); } } + + public void setRpcProtocol(RpcProtocol rpcProtocol) { + this.rpcProtocol = rpcProtocol; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + ConnectionManager.getInstance().removeHandler(rpcProtocol); + } } diff --git a/netty-rpc-common/src/main/java/com/netty/rpc/protocol/RpcProtocol.java b/netty-rpc-common/src/main/java/com/netty/rpc/protocol/RpcProtocol.java index 57455d1..5d5b54d 100644 --- a/netty-rpc-common/src/main/java/com/netty/rpc/protocol/RpcProtocol.java +++ b/netty-rpc-common/src/main/java/com/netty/rpc/protocol/RpcProtocol.java @@ -7,8 +7,9 @@ public class RpcProtocol implements Serializable { private static final long serialVersionUID = -1102180003395190700L; - private String uuid; + // service host private String host; + // service port private int port; // interface name private String serviceName; @@ -30,15 +31,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; RpcProtocol that = (RpcProtocol) o; return port == that.port && - uuid.equals(that.uuid) && host.equals(that.host) && serviceName.equals(that.serviceName) && - version.equals(this.version); + version.equals(that.version); } @Override public int hashCode() { - return Objects.hash(uuid, host, port, serviceName, version); + return Objects.hash(host, port, serviceName, version); } @Override @@ -46,14 +46,6 @@ public String toString() { return toJson(); } - public String getUuid() { - return uuid; - } - - public void setUuid(String uuid) { - this.uuid = uuid; - } - public String getHost() { return host; } diff --git a/netty-rpc-common/src/main/java/com/netty/rpc/zookeeper/CuratorClient.java b/netty-rpc-common/src/main/java/com/netty/rpc/zookeeper/CuratorClient.java index 0619504..2cebb83 100644 --- a/netty-rpc-common/src/main/java/com/netty/rpc/zookeeper/CuratorClient.java +++ b/netty-rpc-common/src/main/java/com/netty/rpc/zookeeper/CuratorClient.java @@ -7,6 +7,7 @@ import org.apache.curator.framework.recipes.cache.PathChildrenCacheListener; import org.apache.curator.framework.recipes.cache.TreeCache; import org.apache.curator.framework.recipes.cache.TreeCacheListener; +import org.apache.curator.framework.state.ConnectionStateListener; import org.apache.curator.retry.ExponentialBackoffRetry; import org.apache.zookeeper.CreateMode; import org.apache.zookeeper.Watcher; @@ -19,7 +20,8 @@ public class CuratorClient { public CuratorClient(String connectString, String namespace, int sessionTimeout, int connectionTimeout) { client = CuratorFrameworkFactory.builder().namespace(namespace).connectString(connectString) .sessionTimeoutMs(sessionTimeout).connectionTimeoutMs(connectionTimeout) - .retryPolicy(new ExponentialBackoffRetry(2000, 10)).build(); + .retryPolicy(new ExponentialBackoffRetry(1000, 10)) + .build(); client.start(); } @@ -35,6 +37,10 @@ public CuratorFramework getClient() { return client; } + public void addConnectionStateListener(ConnectionStateListener connectionStateListener) { + client.getConnectionStateListenable().addListener(connectionStateListener); + } + public void createPathData(String path, byte[] data) throws Exception { client.create().creatingParentsIfNeeded() .withMode(CreateMode.EPHEMERAL_SEQUENTIAL) diff --git a/netty-rpc-server/src/main/java/com/netty/rpc/server/core/NettyServer.java b/netty-rpc-server/src/main/java/com/netty/rpc/server/core/NettyServer.java index a816230..5236abf 100644 --- a/netty-rpc-server/src/main/java/com/netty/rpc/server/core/NettyServer.java +++ b/netty-rpc-server/src/main/java/com/netty/rpc/server/core/NettyServer.java @@ -63,9 +63,9 @@ public void run() { future.channel().closeFuture().sync(); } catch (Exception e) { if (e instanceof InterruptedException) { - logger.info("Rpc server remoting server stop."); + logger.info("Rpc server remoting server stop"); } else { - logger.error("Rpc server remoting server error.", e); + logger.error("Rpc server remoting server error", e); } } finally { try { diff --git a/netty-rpc-server/src/main/java/com/netty/rpc/server/core/RpcServerHandler.java b/netty-rpc-server/src/main/java/com/netty/rpc/server/core/RpcServerHandler.java index 4608d42..26de790 100644 --- a/netty-rpc-server/src/main/java/com/netty/rpc/server/core/RpcServerHandler.java +++ b/netty-rpc-server/src/main/java/com/netty/rpc/server/core/RpcServerHandler.java @@ -37,7 +37,7 @@ public RpcServerHandler(Map handlerMap, final ThreadPoolExecutor public void channelRead0(final ChannelHandlerContext ctx, final RpcRequest request) { // filter beat ping if (Beat.BEAT_ID.equalsIgnoreCase(request.getRequestId())) { - logger.info("Server read beat-ping."); + logger.info("Server read heartbeat ping"); return; } diff --git a/netty-rpc-server/src/main/java/com/netty/rpc/server/registry/ServiceRegistry.java b/netty-rpc-server/src/main/java/com/netty/rpc/server/registry/ServiceRegistry.java index 06b9ed4..a368d04 100644 --- a/netty-rpc-server/src/main/java/com/netty/rpc/server/registry/ServiceRegistry.java +++ b/netty-rpc-server/src/main/java/com/netty/rpc/server/registry/ServiceRegistry.java @@ -1,10 +1,12 @@ package com.netty.rpc.server.registry; -import cn.hutool.core.util.IdUtil; import com.netty.rpc.config.Constant; import com.netty.rpc.protocol.RpcProtocol; import com.netty.rpc.util.ServiceUtil; import com.netty.rpc.zookeeper.CuratorClient; +import org.apache.curator.framework.CuratorFramework; +import org.apache.curator.framework.state.ConnectionState; +import org.apache.curator.framework.state.ConnectionStateListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,37 +30,42 @@ public ServiceRegistry(String registryAddress) { } public void registerService(String host, int port, Map serviceMap) { - //register service info, format uuid:ip:port - if (serviceMap.size() > 0) { - for (String key : serviceMap.keySet()) { - try { - RpcProtocol rpcProtocol = new RpcProtocol(); - //Add an uuid when register the service so we can distinguish the same ip:port service - String uuid = IdUtil.objectId(); - rpcProtocol.setUuid(uuid); - rpcProtocol.setHost(host); - rpcProtocol.setPort(port); - String[] serviceInfo = key.split(ServiceUtil.SERVICE_CONCAT_TOKEN); - if (serviceInfo.length > 0) { - rpcProtocol.setServiceName(serviceInfo[0]); - if (serviceInfo.length == 2) { - rpcProtocol.setVersion(serviceInfo[1]); - } else { - rpcProtocol.setVersion(""); - } - String serviceData = rpcProtocol.toJson(); - byte[] bytes = serviceData.getBytes(); - String path = Constant.ZK_DATA_PATH + "-" + uuid; - this.curatorClient.createPathData(path, bytes); - pathList.add(path); - logger.info("Registry new service:{}, host:{}, port:{}", key, host, port); + // Register service info + for (String key : serviceMap.keySet()) { + try { + RpcProtocol rpcProtocol = new RpcProtocol(); + rpcProtocol.setHost(host); + rpcProtocol.setPort(port); + String[] serviceInfo = key.split(ServiceUtil.SERVICE_CONCAT_TOKEN); + if (serviceInfo.length > 0) { + rpcProtocol.setServiceName(serviceInfo[0]); + if (serviceInfo.length == 2) { + rpcProtocol.setVersion(serviceInfo[1]); } else { - logger.warn("Can not get service name and version"); + rpcProtocol.setVersion(""); } - } catch (Exception e) { - logger.error("Register service {} fail, exception:{}", key, e.getMessage()); + String serviceData = rpcProtocol.toJson(); + byte[] bytes = serviceData.getBytes(); + String path = Constant.ZK_DATA_PATH + "-" + rpcProtocol.hashCode(); + this.curatorClient.createPathData(path, bytes); + pathList.add(path); + logger.info("Register new service: {}, host: {}, port: {}", key, host, port); + } else { + logger.warn("Can not get service name and version: {}" + key); } + } catch (Exception e) { + logger.error("Register service {} fail, exception: {}", key, e.getMessage()); } + + curatorClient.addConnectionStateListener(new ConnectionStateListener() { + @Override + public void stateChanged(CuratorFramework curatorFramework, ConnectionState connectionState) { + if (connectionState == ConnectionState.RECONNECTED) { + logger.info("Connection state: {}, register service after reconnected", connectionState); + registerService(host, port, serviceMap); + } + } + }); } } diff --git a/netty-rpc-test/src/main/java/com/app/test/client/RpcAsyncTest.java b/netty-rpc-test/src/main/java/com/app/test/client/RpcAsyncTest.java index 4575fa0..405a3fc 100644 --- a/netty-rpc-test/src/main/java/com/app/test/client/RpcAsyncTest.java +++ b/netty-rpc-test/src/main/java/com/app/test/client/RpcAsyncTest.java @@ -15,7 +15,7 @@ public static void main(String[] args) throws InterruptedException { final RpcClient rpcClient = new RpcClient("10.217.59.164:2181"); int threadNum = 1; - final int requestNum = 10; + final int requestNum = 100; Thread[] threads = new Thread[threadNum]; long startTime = System.currentTimeMillis(); @@ -34,6 +34,11 @@ public void run() { } else { System.out.println("result = " + result); } + try { + Thread.sleep(5 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } } catch (Exception e) { System.out.println(e.toString()); } diff --git a/netty-rpc-test/src/main/java/com/app/test/client/RpcTest.java b/netty-rpc-test/src/main/java/com/app/test/client/RpcTest.java index 3df547a..7222431 100644 --- a/netty-rpc-test/src/main/java/com/app/test/client/RpcTest.java +++ b/netty-rpc-test/src/main/java/com/app/test/client/RpcTest.java @@ -12,7 +12,7 @@ public static void main(String[] args) throws InterruptedException { final RpcClient rpcClient = new RpcClient("10.217.59.164:2181"); int threadNum = 1; - final int requestNum = 5; + final int requestNum = 50; Thread[] threads = new Thread[threadNum]; long startTime = System.currentTimeMillis(); @@ -22,17 +22,21 @@ public static void main(String[] args) throws InterruptedException { @Override public void run() { for (int i = 0; i < requestNum; i++) { - final HelloService syncClient = rpcClient.createService(HelloService.class, "1.0"); - String result = syncClient.hello(Integer.toString(i)); - if (!result.equals("Hello " + i)) { - System.out.println("error = " + result); - } else { - System.out.println("result = " + result); - } try { - Thread.sleep(5 * 1000); - } catch (InterruptedException e) { - e.printStackTrace(); + final HelloService syncClient = rpcClient.createService(HelloService.class, "1.0"); + String result = syncClient.hello(Integer.toString(i)); + if (!result.equals("Hello " + i)) { + System.out.println("error = " + result); + } else { + System.out.println("result = " + result); + } + try { + Thread.sleep(5 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } catch (Exception ex) { + System.out.println(ex.toString()); } } }