package org.rx.socks; import lombok.extern.slf4j.Slf4j; import org.rx.core.LogWriter; import org.rx.core.NQuery; import org.rx.beans.DateTime; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; import java.util.Map; import java.util.Timer; import java.util.TimerTask; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import static org.rx.core.Contract.require; @Slf4j public final class SocketPool extends Traceable implements AutoCloseable { public static final class PooledSocket implements AutoCloseable { private final SocketPool owner; private DateTime lastActive; public final Socket socket; public boolean isConnected() { return !owner.isClosed() && !socket.isClosed() && socket.isConnected(); } public DateTime getLastActive() { return lastActive; } public void setLastActive(DateTime lastActive) { this.lastActive = lastActive; } private PooledSocket(SocketPool owner, Socket socket) { this.owner = owner; this.socket = socket; lastActive = DateTime.utcNow(); } @Override public void close() { owner.returnSocket(this); } } public static final SocketPool Pool = new SocketPool(); private static final int DefaultConnectTimeout = 30000; private static final int DefaultMaxIdleMillis = 120000; private static final int DefaultMaxSocketsCount = 64; private final ConcurrentHashMap> pool; private volatile int connectTimeout; private volatile int maxIdleMillis; private volatile int maxSocketsCount; private final Timer timer; private volatile boolean isTimerRun; public int getConnectTimeout() { return connectTimeout; } public void setConnectTimeout(int connectTimeout) { this.connectTimeout = connectTimeout; } public int getMaxIdleMillis() { return maxIdleMillis; } public void setMaxIdleMillis(int maxIdleMillis) { if (maxIdleMillis <= 0) { maxIdleMillis = DefaultMaxIdleMillis; } this.maxIdleMillis = maxIdleMillis; } public int getMaxSocketsCount() { return maxSocketsCount; } public void 
setMaxSocketsCount(int maxSocketsCount) { if (maxSocketsCount < 0) { maxSocketsCount = 0; } this.maxSocketsCount = maxSocketsCount; } private SocketPool() { pool = new ConcurrentHashMap<>(); connectTimeout = DefaultConnectTimeout; maxIdleMillis = DefaultMaxIdleMillis; maxSocketsCount = DefaultMaxSocketsCount; String n = "SocketPool"; timer = new Timer(n, true); LogWriter tracer = new LogWriter(); tracer.setPrefix(n + " "); tracer.info("started.."); setTracer(tracer); } @Override protected void freeObjects() { clear(); } private void runTimer() { if (isTimerRun) { return; } synchronized (timer) { if (isTimerRun) { return; } long period = 90000; timer.schedule(new TimerTask() { @Override public void run() { clearIdleSockets(); } }, period, period); isTimerRun = true; } getTracer().info("runTimer.."); } private void clearIdleSockets() { for (Map.Entry > entry : NQuery.of(pool.entrySet())) { ConcurrentLinkedDeque sockets = entry.getValue(); if (sockets == null) { continue; } for (PooledSocket socket : NQuery.of(sockets)) { if (!socket.isConnected() || DateTime.utcNow().subtract(socket.getLastActive()).getTotalMilliseconds() >= maxIdleMillis) { sockets.remove(socket); getTracer().info("clear idle socket[local=%s, remote=%s]..", Sockets.getId(socket.socket, false), Sockets.getId(socket.socket, true)); } } if (sockets.isEmpty()) { pool.remove(entry.getKey()); } } if (pool.size() == 0) { stopTimer(); } } private void stopTimer() { synchronized (timer) { timer.cancel(); timer.purge(); isTimerRun = false; } getTracer().info("stopTimer.."); } private ConcurrentLinkedDeque getSockets(InetSocketAddress remoteAddr) { ConcurrentLinkedDeque sockets = pool.get(remoteAddr); if (sockets == null) { pool.put(remoteAddr, sockets = new ConcurrentLinkedDeque<>()); runTimer(); } return sockets; } public PooledSocket borrowSocket(InetSocketAddress remoteAddr) { checkNotClosed(); require(remoteAddr); boolean isExisted = true; ConcurrentLinkedDeque sockets = getSockets(remoteAddr); 
PooledSocket pooledSocket; if ((pooledSocket = sockets.pollFirst()) == null) { Socket sock = new Socket(); try { sock.connect(remoteAddr, connectTimeout); } catch (IOException ex) { throw new SocketException(remoteAddr, ex); } pooledSocket = new PooledSocket(this, sock); isExisted = false; } if (!pooledSocket.isConnected()) { if (isExisted) { sockets.remove(pooledSocket); } return borrowSocket(remoteAddr); } Socket sock = pooledSocket.socket; getTracer().info("borrow %s socket[local=%s, remote=%s]..", isExisted ? "existed" : "new", Sockets.getId(sock, false), Sockets.getId(sock, true)); return pooledSocket; } public void returnSocket(PooledSocket pooledSocket) { checkNotClosed(); require(pooledSocket); String action = "return"; try { if (!pooledSocket.isConnected()) { action = "discard closed"; return; } pooledSocket.setLastActive(DateTime.utcNow()); ConcurrentLinkedDeque sockets = getSockets( (InetSocketAddress) pooledSocket.socket.getRemoteSocketAddress()); if (sockets.size() >= maxSocketsCount || sockets.contains(pooledSocket)) { action = "discard contains"; return; } sockets.addFirst(pooledSocket); } finally { Socket sock = pooledSocket.socket; getTracer().info("%s socket[local=%s, remote=%s]..", action, Sockets.getId(sock, false), Sockets.getId(sock, true)); } } public void clear() { checkNotClosed(); for (Socket socket : NQuery.of(pool.values()).selectMany(p -> p).select(p -> p.socket)) { try { getTracer().info("clear socket[local=%s, remote=%s]..", Sockets.getId(socket, false), Sockets.getId(socket, true)); Sockets.close(socket); } catch (Exception ex) { log.error("SocketPool clear", ex); } } pool.clear(); } }
package org.rx.socks;

import org.rx.core.Disposable;
import org.rx.core.LogWriter;

import static org.rx.core.Contract.isNull;

/**
 * Base class for disposable socket components that carry a {@link LogWriter} for trace output.
 */
public abstract class Traceable extends Disposable {
    // volatile: getTracer() reads without locking, so a tracer installed by setTracer() on one thread
    // must be visible to readers on other threads.
    private volatile LogWriter tracer;

    /**
     * @return the current tracer; {@code null} until {@link #setTracer(LogWriter)} has been called
     */
    public LogWriter getTracer() {
        return tracer;
    }

    /**
     * Installs {@code tracer}, or a fresh {@link LogWriter} when {@code tracer} is {@code null}.
     */
    public synchronized void setTracer(LogWriter tracer) {
        this.tracer = isNull(tracer, new LogWriter());
    }
}
package org.rx.socks; import lombok.extern.slf4j.Slf4j; import org.rx.beans.$; import org.rx.beans.Tuple; import org.rx.util.BufferSegment; import org.rx.util.BytesSegment; import org.rx.core.*; import org.rx.core.AsyncTask; import org.rx.io.MemoryStream; import java.io.IOException; import java.net.*; import java.util.ArrayList; import java.util.Collections; import java.util.List; import static org.rx.beans.$.$; import static org.rx.core.Contract.isNull; import static org.rx.core.Contract.require; @Slf4j public class DirectSocket extends Traceable implements AutoCloseable { @FunctionalInterface public interface SocketSupplier { Tupleget(MemoryStream pack); } private static class ClientItem { private final DirectSocket owner; private final BufferSegment segment; public final NetworkStream stream; public final AutoCloseable toSock; public final NetworkStream toStream; public ClientItem(Socket client, DirectSocket owner) { this.owner = owner; segment = new BufferSegment(Contract.config.getDefaultBufferSize(), 2); try { stream = new NetworkStream(client, segment.alloc()); if (owner.directAddress != null) { SocketPool.PooledSocket pooledSocket = App.retry(owner.connectRetryCount, p -> SocketPool.Pool.borrowSocket(p.directAddress), owner); toSock = pooledSocket; toStream = new NetworkStream(pooledSocket.socket, segment.alloc(), false); return; } if (owner.directSupplier != null) { MemoryStream firstPack = new MemoryStream(32, true); BytesSegment buffer = stream.getSegment(); int read; while ((read = stream.readSegment()) > 0) { System.out.println("----:" + Bytes.toString(buffer.array, buffer.offset, read)); firstPack.write(buffer.array, buffer.offset, read); Tuple toSocks; if ((toSocks = owner.directSupplier.get(firstPack)) != null) { toSock = toSocks.left; firstPack.writeTo(toStream = new NetworkStream(toSocks.right, segment.alloc(), false)); return; } } log.info("DirectSocket ClientState directSupplier read: {}\ncontent: {}", read, Bytes.toString(firstPack.toArray(), 
0, firstPack.getLength())); } } catch (IOException ex) { throw new SocketException((InetSocketAddress) client.getLocalSocketAddress(), ex); } throw new SocketException((InetSocketAddress) client.getLocalSocketAddress(), "DirectSocket directSupplier error"); } public void closeSocket() { owner.getTracer().info("client close socket[%s->%s]..", Sockets.getId(stream.getSocket(), false), Sockets.getId(stream.getSocket(), true)); owner.clients.remove(this); stream.close(); } public void closeToSocket(boolean pooling) { owner.getTracer().info("client %s socket[%s->%s]..", pooling ? "pooling" : "close", Sockets.getId(toStream.getSocket(), false), Sockets.getId(toStream.getSocket(), true)); if (pooling) { try { toSock.close(); } catch (Exception ex) { ex.printStackTrace(); } } else { Sockets.close(toStream.getSocket()); } } } public static final SocketSupplier HttpSupplier = pack -> { String line = Bytes.readLine(pack.getBuffer()); if (line == null) { return null; } InetSocketAddress authority; try { authority = Sockets.parseEndpoint( new URL(line.split(" ")[1]) .getAuthority()); } catch (MalformedURLException ex) { throw SystemException.wrap(ex); } SocketPool.PooledSocket pooledSocket = App.retry(2, p -> SocketPool.Pool.borrowSocket(p), authority); return Tuple.of(pooledSocket, pooledSocket.socket); }; private static final int DefaultBacklog = 128; private static final int DefaultConnectRetryCount = 4; private final ServerSocket server; private final List clients; private volatile int connectRetryCount; private InetSocketAddress directAddress; private SocketSupplier directSupplier; @Override public boolean isClosed() { return !(!super.isClosed() && !server.isClosed()); } public InetSocketAddress getLocalAddress() { return (InetSocketAddress) server.getLocalSocketAddress(); } public NQuery > getClients() { return NQuery.of(clients).select(p -> Tuple.of(p.stream.getSocket(), p.toStream.getSocket())); } public int getConnectRetryCount() { return connectRetryCount; } public 
void setConnectRetryCount(int connectRetryCount) { if (connectRetryCount <= 0) { connectRetryCount = 1; } this.connectRetryCount = connectRetryCount; } public DirectSocket(int listenPort, InetSocketAddress directAddr) { this(new InetSocketAddress(Sockets.AnyAddress, listenPort), directAddr, null); } public DirectSocket(InetSocketAddress listenAddr, InetSocketAddress directAddr, SocketSupplier directSupplier) { require(listenAddr); require(this, directAddr != null || directSupplier != null); try { server = new ServerSocket(); server.setReuseAddress(true); server.bind(listenAddr, DefaultBacklog); } catch (IOException ex) { throw new SocketException(listenAddr, ex); } directAddress = directAddr; this.directSupplier = directSupplier; clients = Collections.synchronizedList(new ArrayList<>()); connectRetryCount = DefaultConnectRetryCount; String taskName = String.format("DirectSocket[%s->%s]", listenAddr, isNull(directAddress, "autoAddress")); LogWriter tracer = new LogWriter(); tracer.setPrefix(taskName + " "); setTracer(tracer); AsyncTask.TaskFactory.run(() -> { getTracer().info("start.."); while (!isClosed()) { try { ClientItem client = new ClientItem(server.accept(), this); clients.add(client); onReceive(client, taskName); } catch (IOException ex) { log.error(taskName, ex); } } close(); }, taskName); } @Override protected void freeObjects() { try { for (ClientItem client : NQuery.of(clients)) { client.closeSocket(); } clients.clear(); server.close(); } catch (IOException ex) { log.error("DirectSocket close", ex); } getTracer().info("stop.."); } private void onReceive(ClientItem client, String taskName) { AsyncTask.TaskFactory.run(() -> { try { int recv = client.stream.directTo(client.toStream, (p1, p2) -> { getTracer().info("sent %s bytes from %s to %s..", p2, Sockets.getId(client.stream.getSocket(), true), Sockets.getId(client.toStream.getSocket(), false)); return true; }); getTracer().info("socket[%s->%s] closing with %s", Sockets.getId(client.stream.getSocket(), 
false), Sockets.getId(client.stream.getSocket(), true), recv); } catch (SystemException ex) { $ out = $(); if (ex.tryGet(out, java.net.SocketException.class)) { if (out.v.getMessage().contains("Socket closed")) { //ignore log.debug("DirectTo ignore socket closed"); return; } } throw ex; } finally { client.closeSocket(); } }, String.format("%s[networkStream]", taskName)); AsyncTask.TaskFactory.run(() -> { int recv = NetworkStream.StreamEOF; try { recv = client.toStream.directTo(client.stream, (p1, p2) -> { getTracer().info("recv %s bytes from %s to %s..", p2, Sockets.getId(client.toStream.getSocket(), false), Sockets.getId(client.stream.getSocket(), true)); return true; }); getTracer().info("socket[%s->%s] closing with %s", Sockets.getId(client.toStream.getSocket(), false), Sockets.getId(client.toStream.getSocket(), true), recv); } catch (SystemException ex) { $ out = $(); if (ex.tryGet(out, java.net.SocketException.class)) { if (out.v.getMessage().contains("Socket closed")) { //ignore log.debug("DirectTo ignore socket closed"); return; } } throw ex; } finally { client.closeToSocket(recv == NetworkStream.CannotWrite); } }, String.format("%s[toNetworkStream]", taskName)); } }
package org.rx.socks; import lombok.extern.slf4j.Slf4j; import org.rx.util.BytesSegment; import org.rx.io.IOStream; import java.io.IOException; import java.net.Socket; import static org.rx.core.Contract.require; import static org.rx.socks.Sockets.shutdown; @Slf4j public final class NetworkStream extends IOStream { @FunctionalInterface public interface DirectPredicate { boolean test(BytesSegment buffer, int count); } public static final int SocketEOF = 0; public static final int StreamEOF = -1; public static final int CannotWrite = -2; private final boolean ownsSocket; private final Socket socket; private final BytesSegment segment; public boolean isConnected() { return !isClosed() && !socket.isClosed() && socket.isConnected(); } @Override public boolean canRead() { return super.canRead() && checkSocket(socket, false); } @Override public boolean canWrite() { return super.canWrite() && checkSocket(socket, true); } private static boolean checkSocket(Socket sock, boolean isWrite) { return !sock.isClosed() && sock.isConnected() && !(isWrite ? sock.isOutputShutdown() : sock.isInputShutdown()); } public Socket getSocket() { return socket; } public BytesSegment getSegment() { return segment; } public NetworkStream(Socket socket, BytesSegment segment) throws IOException { this(socket, segment, true); } public NetworkStream(Socket socket, BytesSegment segment, boolean ownsSocket) throws IOException { super(socket.getInputStream(), socket.getOutputStream()); this.ownsSocket = ownsSocket; this.socket = socket; this.segment = segment; } @Override protected void freeObjects() { try { log.info("NetworkStream freeObjects ownsSocket={} socket[{}][closed={}]", ownsSocket, Sockets.getId(socket, false), socket.isClosed()); if (ownsSocket) { //super.freeObjects(); Ignore this!! 
Sockets.close(socket, 1); } } finally { segment.close(); } } int readSegment() { return read(segment.array, segment.offset, segment.count); } void writeSegment(int count) { write(segment.array, segment.offset, count); } public int directTo(NetworkStream to, DirectPredicate onEach) { checkNotClosed(); require(to); int recv = StreamEOF; while (canRead() && (recv = read(segment.array, segment.offset, segment.count)) >= -1) { if (recv <= 0) { if (ownsSocket) { log.debug("DirectTo read {} flag and shutdown send", recv); shutdown(socket, 1); } break; } if (!to.canWrite()) { log.debug("DirectTo read {} bytes and can't write", recv); recv = CannotWrite; break; } to.write(segment.array, segment.offset, recv); if (onEach != null && !onEach.test(segment, recv)) { recv = StreamEOF; break; } } if (to.canWrite()) { to.flush(); } return recv; } }
package org.rx.socks;

import org.rx.core.SystemException;

import java.net.InetSocketAddress;

/**
 * {@link SystemException} raised by socket operations, tagged with the endpoint involved.
 */
public class SocketException extends SystemException {
    private final InetSocketAddress localAddress;

    /**
     * @param localAddress endpoint the failure relates to (callers pass either a local or a remote address)
     * @param ex           underlying cause
     */
    public SocketException(InetSocketAddress localAddress, Exception ex) {
        super(ex);
        this.localAddress = localAddress;
    }

    /**
     * @param localAddress endpoint the failure relates to
     * @param msg          description of the failure
     */
    public SocketException(InetSocketAddress localAddress, String msg) {
        super(msg);
        this.localAddress = localAddress;
    }

    public InetSocketAddress getLocalAddress() {
        return localAddress;
    }
}
package org.rx.socks; import java.io.IOException; import java.net.*; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFutureListener; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollServerSocketChannel; import io.netty.channel.epoll.EpollSocketChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import org.rx.core.Strings; import org.rx.core.SystemException; import org.rx.core.WeakCache; import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Properties; import java.util.function.Function; import static org.rx.core.Contract.require; import static org.rx.core.Contract.values; public final class Sockets { public static final InetAddress LocalAddress, AnyAddress; static { LocalAddress = InetAddress.getLoopbackAddress(); try { AnyAddress = InetAddress.getByName("0.0.0.0"); } catch (Exception ex) { throw SystemException.wrap(ex); } } public InetAddress[] getAddresses(String host) { return (InetAddress[]) WeakCache.getOrStore("Sockets.getAddresses", values(host), p -> { try { return InetAddress.getAllByName(host); } catch (UnknownHostException ex) { throw SystemException.wrap(ex); } }); } public static InetSocketAddress getAnyEndpoint(int port) { return new InetSocketAddress(AnyAddress, port); } public static InetSocketAddress parseEndpoint(String endpoint) { require(endpoint); String[] arr = Strings.split(endpoint, ":", 2); return new InetSocketAddress(arr[0], Integer.parseInt(arr[1])); } public static void writeAndFlush(Channel channel, Object... 
packs) { require(channel); channel.eventLoop().execute(() -> { for (Object pack : packs) { channel.write(pack); } channel.flush(); }); } public static EventLoopGroup bossEventLoop() { return eventLoopGroup(1); } public static EventLoopGroup workEventLoop() { return eventLoopGroup(0); } public static EventLoopGroup eventLoopGroup(int threadAmount) { return Epoll.isAvailable() ? new EpollEventLoopGroup(threadAmount) : new NioEventLoopGroup(threadAmount); //NioEventLoopGroup(0, TaskFactory.getExecutor()); } public static Bootstrap bootstrap() { return bootstrap(getChannelClass()); } public static Bootstrap bootstrap(Class extends Channel> channelClass) { require(channelClass); return new Bootstrap().group(channelClass.getName().startsWith("Epoll") ? new EpollEventLoopGroup() : new NioEventLoopGroup()).channel(channelClass); } public static Bootstrap bootstrap(Channel channel) { require(channel); return new Bootstrap().group(channel.eventLoop()).channel(channel.getClass()); } public static Class extends ServerChannel> getServerChannelClass() { return Epoll.isAvailable() ? EpollServerSocketChannel.class : NioServerSocketChannel.class; } public static Class extends Channel> getChannelClass() { return Epoll.isAvailable() ? 
EpollSocketChannel.class : NioSocketChannel.class; } public static void closeOnFlushed(Channel channel) { require(channel); if (!channel.isActive()) { return; } channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); } public static void close(Socket socket) { close(socket, 1 | 2); } public static void close(Socket socket, int flags) { require(socket); if (!socket.isClosed()) { shutdown(socket, flags); try { socket.setSoLinger(true, 2); socket.close(); } catch (IOException ex) { throw SystemException.wrap(ex); } } } /** * @param socket * @param flags Send=1, Receive=2 */ public static void shutdown(Socket socket, int flags) { require(socket); if (!socket.isClosed() && socket.isConnected()) { try { if ((flags & 1) == 1 && !socket.isOutputShutdown()) { socket.shutdownOutput(); } if ((flags & 2) == 2 && !socket.isInputShutdown()) { socket.shutdownInput(); } } catch (IOException ex) { throw SystemException.wrap(ex); } } } public static String getId(Socket sock, boolean isRemote) { require(sock); InetSocketAddress addr = (InetSocketAddress) (isRemote ? 
sock.getRemoteSocketAddress() : sock.getLocalSocketAddress()); return addr.getHostString() + ":" + addr.getPort(); } public staticT httpProxyInvoke(String proxyAddr, Function func) { setHttpProxy(proxyAddr); try { return func.apply(proxyAddr); } finally { clearHttpProxy(); } } public static void setHttpProxy(String proxyAddr) { setHttpProxy(proxyAddr, null, null, null); } public static void setHttpProxy(String proxyAddr, List nonProxyHosts, String userName, String password) { InetSocketAddress ipe = parseEndpoint(proxyAddr); Properties prop = System.getProperties(); prop.setProperty("http.proxyHost", ipe.getAddress().getHostAddress()); prop.setProperty("http.proxyPort", String.valueOf(ipe.getPort())); prop.setProperty("https.proxyHost", ipe.getAddress().getHostAddress()); prop.setProperty("https.proxyPort", String.valueOf(ipe.getPort())); if (!CollectionUtils.isEmpty(nonProxyHosts)) { //如"localhost|192.168.0.*" prop.setProperty("http.nonProxyHosts", String.join("|", nonProxyHosts)); } if (userName != null && password != null) { Authenticator.setDefault(new UserAuthenticator(userName, password)); } } public static void clearHttpProxy() { System.clearProperty("http.proxyHost"); System.clearProperty("http.proxyPort"); System.clearProperty("https.proxyHost"); System.clearProperty("https.proxyPort"); System.clearProperty("http.nonProxyHosts"); } static class UserAuthenticator extends Authenticator { private String userName; private String password; public UserAuthenticator(String userName, String password) { this.userName = userName; this.password = password; } protected PasswordAuthentication getPasswordAuthentication() { return new PasswordAuthentication(userName, password.toCharArray()); } } }