自学内容网 自学内容网

Java NIO实现高性能HTTP代理

NIO采用多路复用IO模型,相比传统BIO(阻塞IO),通过轮询机制检测注册的Channel是否有事件发生,可以实现一个线程处理客户端的多个连接,极大提升了并发性能。
5年前,本人出于对HTTP正向代理的好奇心,加上那时候也在学习JAVA、了解到了NIO,就想用NIO写一个正向代理软件。当时虽然实现了正向代理,但是代码逻辑极其混乱,而且没有经过测试,也许有不少的bug。
近期因为找工作,又复习起了以往的一些JAVA知识,包括JVM内存模型、GC垃圾回收机制等等,其中也包括NIO。现在回头再看NIO,理解也更深刻了一点。
在多路复用IO模型中,会有一个线程不断去轮询多个socket的状态,只有当socket真正有读写事件时,才真正调用实际的IO读写操作。因为在多路复用IO模型中,只需要使用一个线程就可以管理多个socket,系统不需要建立新的进程或者线程,也不必维护这些线程和进程,并且只有在真正有socket 读写事件进行时,才会使用IO资源,所以它大大减少了资源占用。在Java NIO中,是通过selector.select()去查询每个通道是否有到达事件,如果没有事件,则一直阻塞在那里,因此这种方式会导致用户线程的阻塞。多路复用IO模式,通过一个线程就可以管理多个socket,只有当socket 真正有读写事件发生才会占用资源来进行实际的读写操作。因此,多路复用IO比较适合连接数比较多的情况。
本HTTP代理软件只能代理HTTP和HTTPS协议,分享出来供广大网友参考和学习。
1.Bootstrap类
此类用于创建和启动一个HTTP代理服务
package org.example;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Wires up and launches the HTTP proxy.
 *
 * <p>Constructing a bootstrap creates the {@link ServerEventLoop} (whose constructor submits the
 * selector loop to an executor). {@link #bindPort(int)} opens the listening socket, and
 * {@link #start()} wakes the selector so that the new registration is picked up.
 */
public class Bootstrap {
    private final Logger logger = LogManager.getLogger(Bootstrap.class);
    private final AbstractEventLoop serverEventLoop;
    private int port;

    /** Creates a bootstrap with the default port 8888 and its server event loop. */
    public Bootstrap() {
        port = 8888;
        serverEventLoop = new ServerEventLoop(this);
    }

    /**
     * Records {@code port} and asks the event loop to listen on it.
     *
     * <p>A bind failure is logged, not thrown, so the returned bootstrap may not be listening.
     *
     * @param port TCP port to listen on
     * @return this bootstrap, for call chaining
     */
    public Bootstrap bindPort(int port) {
        this.port = port;
        try {
            serverEventLoop.bind(port);
        } catch (Exception bindFailure) {
            logger.error("open server socket channel error.", bindFailure);
        }
        return this;
    }

    /** Wakes the selector so a channel registered by {@link #bindPort(int)} is selected, then logs. */
    public void start() {
        var selector = serverEventLoop.getSelector();
        selector.wakeup();
        logger.info("Proxy server started at port {}.", port);
    }

    /** Returns the event loop that serves the listening socket and all proxied connections. */
    public AbstractEventLoop getServerEventLoop() {
        return serverEventLoop;
    }
}


2.ServerEventLoop
事件循环类,在单个线程中运行事件循环:客户端的连接和读写请求、与目标服务器的连接和读写事件,都在同一个事件循环中处理。
package org.example;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.example.common.HttpRequestParser;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;

public class ServerEventLoop extends AbstractEventLoop {
    private final Logger logger = LogManager.getLogger(ServerEventLoop.class);

    public ServerEventLoop(Bootstrap bootstrap) {
        super(bootstrap);
    }

    @Override
    protected void processSelectedKey(SelectionKey key) {
        if (key.isValid() && key.isAcceptable()) {
            if (key.attachment() instanceof Acceptor acceptor) {
                acceptor.accept();
            }
        }
        if (key.isValid() && key.isReadable()) {
            if (key.attachment() instanceof ChannelHandler channelHandler) {
                channelHandler.handleRead();
            }
        }
        if (key.isValid() && key.isConnectable()) {
            key.interestOpsAnd(~SelectionKey.OP_CONNECT);
            if (key.attachment() instanceof ChannelHandler channelHandler) {
                channelHandler.handleConnect();
            }
        }
        if (key.isValid() && key.isWritable()) {
            key.interestOpsAnd(~SelectionKey.OP_WRITE);
            if (key.attachment() instanceof ChannelHandler channelHandler) {
                channelHandler.handleWrite();
            }
        }
    }

    @Override
    public void bind(int port) throws Exception {
        ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
        serverSocketChannel.configureBlocking(false);
        SelectionKey key = serverSocketChannel.register(this.selector, SelectionKey.OP_ACCEPT);
        key.attach(new Acceptor(serverSocketChannel));
        serverSocketChannel.bind(new InetSocketAddress(port));
    }

    class Acceptor {
        ServerSocketChannel ssc;

        public Acceptor(ServerSocketChannel ssc) {
            this.ssc = ssc;
        }

        public void accept() {
            try {
                SocketChannel socketChannel = ssc.accept();
                socketChannel.configureBlocking(false);
                socketChannel.register(selector, SelectionKey.OP_READ, new ClientChannelHandler(socketChannel));
                logger.info("accept client connection");
            } catch (IOException e) {
                logger.error("accept error");
            }
        }
    }

    abstract class ChannelHandler {
        Logger logger;
        SocketChannel channel;
        ByteBuffer writeBuffer;

        public ChannelHandler(SocketChannel channel) {
            this.logger = LogManager.getLogger(this.getClass());
            this.channel = channel;
            this.writeBuffer = null;
        }

        abstract void handleRead();

        public void handleWrite() {
            doWrite();
        }

        public abstract void onChannelClose();

        public ByteBuffer doRead() {
            ByteBuffer buffer = ByteBuffer.allocate(4096);
            try {
                int len = channel.read(buffer);
                if (len == -1) {
                    logger.info("read end-of-stream, close channel {}", channel);
                    channel.close();
                    onChannelClose();
                }
                if (len > 0) {
                    buffer.flip();
                }
            } catch (IOException e) {
                logger.error("read channel error");
                try {
                    channel.close();
                    onChannelClose();
                } catch (IOException ex) {
                    logger.error("close channel error.");
                }
            }
            return buffer;
        }

        public void doWrite() {
            if (writeBuffer != null) {
                try {
                    while (writeBuffer.hasRemaining()) {
                        channel.write(writeBuffer);
                    }
                } catch (IOException e) {
                    logger.error("write channel error.");
                    try {
                        channel.close();
                        onChannelClose();
                    } catch (IOException ex) {
                        logger.error("close channel error");
                    }
                }
                writeBuffer = null;
            }
        }

        public void handleConnect() {
        }
    }

    class ClientChannelHandler extends ChannelHandler {
        HttpRequestParser requestParser;
        private SelectableChannel proxyChannel;

        public ClientChannelHandler(SocketChannel sc) {
            super(sc);
            this.channel = sc;
            this.requestParser = new HttpRequestParser();
            this.proxyChannel = null;
        }

        @Override
        public void handleRead() {
            if (requestParser.isParsed()) {
                if (proxyChannel != null) {
                    SelectionKey proxyKey = proxyChannel.keyFor(selector);
                    if (proxyKey != null && proxyKey.isValid() && proxyKey.attachment() instanceof ProxyChannelHandler proxyHandler) {
                        //需要等待ProxyHandler的写入缓存为空后才可读取客户端的数据
                        if (proxyHandler.writeBuffer == null) {
                            ByteBuffer buffer = doRead();
                            if (buffer.hasRemaining() && proxyKey.isValid()) {
                                proxyHandler.writeBuffer = buffer;
                                proxyKey.interestOpsOr(SelectionKey.OP_WRITE);
                            }
                        }
                    }
                }
            } else {
                ByteBuffer buffer = doRead();
                requestParser.putFromByteBuffer(buffer);
                if (requestParser.isParsed()) {
                    //连接到目标服务器
                    ByteBuffer buf = null;
                    if (requestParser.getMethod().equals(HttpRequestParser.HTTP_METHOD_CONNECT)) {
                        //回写客户端连接成功
                        SelectionKey clientKey = channel.keyFor(selector);
                        if (clientKey != null && clientKey.isValid() && clientKey.attachment() instanceof ClientChannelHandler clientHandler) {
                            clientHandler.writeBuffer = ByteBuffer.wrap((requestParser.getProtocol() + " 200 Connection Established\r\n\r\n").getBytes());
                            clientKey.interestOpsOr(SelectionKey.OP_WRITE);
                        }
                    } else {
                        //将缓存的客户端的数据通过代理转发
                        byte[] allBytes = requestParser.getAllBytes();
                        buf = ByteBuffer.wrap(allBytes);
                    }
                    this.proxyChannel = connect(requestParser.getAddress(), buf);
                }
            }

        }

        @Override
        public void onChannelClose() {
            try {
                if (proxyChannel != null) {
                    proxyChannel.close();
                }
            } catch (IOException e) {
                logger.error("close channel error");
            }
        }

        private SocketChannel connect(String address, ByteBuffer buffer) {
            String host = address;
            int port = 80;
            if (address.contains(":")) {
                host = address.split(":")[0].trim();
                port = Integer.parseInt(address.split(":")[1].trim());
            }
            SocketAddress target = new InetSocketAddress(host, port);
            SocketChannel socketChannel = null;
            SelectionKey proxyKey = null;
            int step = 0;
            try {
                socketChannel = SocketChannel.open();
                socketChannel.configureBlocking(false);
                step = 1;
                ProxyChannelHandler proxyHandler = new ProxyChannelHandler(socketChannel);
                proxyHandler.setClientChannel(channel);
                proxyHandler.writeBuffer = buffer;
                proxyKey = socketChannel.register(selector, SelectionKey.OP_CONNECT, proxyHandler);
                proxyKey.interestOpsOr(SelectionKey.OP_WRITE);
                step = 2;
                socketChannel.connect(target);
            } catch (IOException e) {
                logger.error("connect error.");
                switch (step) {
                    case 2:
                        proxyKey.cancel();
                    case 1:
                        try {
                            socketChannel.close();
                        } catch (IOException ex) {
                            logger.error("close channel error.");
                        }
                        socketChannel = null;
                        break;
                }
            }
            return socketChannel;
        }
    }

    class ProxyChannelHandler extends ChannelHandler {
        private SelectableChannel clientChannel;

        public ProxyChannelHandler(SocketChannel sc) {
            super(sc);
            clientChannel = null;
        }

        @Override
        public void handleConnect() {
            try {
                if (channel.isConnectionPending() && channel.finishConnect()) {
                    SelectionKey proxyKey = channel.keyFor(selector);
                    proxyKey.interestOpsOr(SelectionKey.OP_READ);
                }
            } catch (IOException e) {
                try {
                    channel.close();
                    onChannelClose();
                } catch (IOException ex) {
                    logger.error("close channel error.");
                }
                logger.error("finish connection error.");
            }
        }

        @Override
        public void handleRead() {
            if (clientChannel != null) {
                SelectionKey clientKey = clientChannel.keyFor(selector);
                if (clientKey != null && clientKey.isValid() && clientKey.attachment() instanceof ClientChannelHandler clientHandler) {
                    if (clientHandler.writeBuffer == null) {
                        ByteBuffer buffer = doRead();
                        if (buffer.hasRemaining() && clientKey.isValid()) {
                            clientHandler.writeBuffer = buffer;
                            clientKey.interestOpsOr(SelectionKey.OP_WRITE);
                        }
                    }
                }
            }
        }

        @Override
        public void onChannelClose() {
            try {
                if (clientChannel != null) {
                    clientChannel.close();
                }
            } catch (IOException e) {
                logger.error("close channel error");
            }
        }

        public void setClientChannel(SocketChannel client) {
            this.clientChannel = client;
        }
    }
}
3.AbstractEventLoop
事件循环的抽象类
package org.example;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.Executors;

/**
 * Base class for a selector-driven event loop that runs on its own single-thread executor.
 *
 * <p>The loop thread is started from this constructor, i.e. before any subclass constructor body
 * has run. Subclasses must therefore not rely on their own instance fields from
 * {@link #processSelectedKey(SelectionKey)} until construction has finished.
 */
public abstract class AbstractEventLoop implements Runnable {
    private final Logger logger = LogManager.getLogger(AbstractEventLoop.class);
    protected Selector selector;
    protected Bootstrap bootstrap;

    /**
     * Opens the selector and starts the loop thread.
     *
     * @throws IllegalStateException if the selector cannot be opened; without one, {@link #run()}
     *     would fail on every iteration and spin forever
     */
    public AbstractEventLoop(Bootstrap bootstrap) {
        this.bootstrap = bootstrap;
        if (openSelector() == null) {
            throw new IllegalStateException("open selector error.");
        }
        Executors.newSingleThreadExecutor().submit(this);
    }

    /**
     * Binds a listening socket; unsupported unless a subclass overrides it.
     *
     * @throws Exception always, as {@link UnsupportedOperationException}, in this base class
     */
    public void bind(int port) throws Exception {
        throw new UnsupportedOperationException("not support");
    }

    /**
     * Selects and dispatches ready keys until the selector is closed. Failures are logged and the
     * loop continues, so one bad iteration cannot kill the thread.
     */
    @Override
    public void run() {
        while (selector.isOpen()) {
            try {
                selector.select();
                // Process even when select() returns 0: keys left over from an earlier round (or a
                // wakeup) may still be in the selected-key set.
                processSelectedKeys();
            } catch (Exception e) {
                // Includes ClosedSelectorException; the loop condition then ends the thread.
                logger.error("select error.", e);
            }
        }
    }

    /** Dispatches every selected key, isolating failures so one key cannot starve the others. */
    private void processSelectedKeys() {
        Set<SelectionKey> keys = selector.selectedKeys();
        Iterator<SelectionKey> iterator = keys.iterator();
        while (iterator.hasNext()) {
            SelectionKey key = iterator.next();
            iterator.remove();
            try {
                processSelectedKey(key);
            } catch (RuntimeException e) {
                logger.error("process selected key error.", e);
            }
        }
    }

    /** Handles one ready key; called on the loop thread. */
    protected abstract void processSelectedKey(SelectionKey key);

    /**
     * Opens a new selector and stores it in {@link #selector}.
     *
     * @return the selector, or null if it could not be opened (the failure is logged)
     */
    public Selector openSelector() {
        try {
            this.selector = Selector.open();
            return this.selector;
        } catch (IOException e) {
            logger.error("open selector error.", e);
        }
        return null;
    }

    public Selector getSelector() {
        return selector;
    }
}

4.HttpRequestParser
用于解析HTTP请求报文中的请求头,可以获取主机和端口号
package org.example.common;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

/**
 * Incrementally parses the header section of an HTTP request.
 *
 * <p>Bytes are fed through {@link #putFromByteBuffer(ByteBuffer)} as they arrive, and every
 * consumed byte is also kept so that the original header can be replayed via
 * {@link #getAllBytes()}. Parsing stops right after the blank line that ends the header; any
 * following bytes are left unconsumed in the caller's buffer. The first non-empty line is treated
 * as the request line ({@code METHOD request-target HTTP-version}). The {@code Host} header
 * supplies the target address.
 *
 * <p>Instances hold mutable parse state and are not synchronized.
 */
public class HttpRequestParser {
    private final Logger logger = LogManager.getLogger(HttpRequestParser.class);
    public static final String COLON = ":";
    public static final String REQUEST_HEADER_HOST_PREFIX = "host:";
    private UnboundedByteBuffer requestBytes = new UnboundedByteBuffer();
    private List<String> headers = new ArrayList<>();
    public static final String HTTP_METHOD_GET = "GET";
    public static final String HTTP_METHOD_POST = "POST";
    public static final String HTTP_METHOD_PUT = "PUT";
    public static final String HTTP_METHOD_DELETE = "DELETE";
    public static final String HTTP_METHOD_TRACE = "TRACE";
    public static final String HTTP_METHOD_OPTIONS = "OPTIONS";
    public static final String HTTP_METHOD_HEAD = "HEAD";
    public static final String HTTP_METHOD_CONNECT = "CONNECT";

    private String address;

    private String protocol;
    private String method;
    /** Second token of the request line, e.g. {@code "host:443"} for CONNECT. */
    private String requestTarget;

    private boolean parsed = false;

    /** Characters of the header line currently being assembled. */
    private final StringBuilder reqHeaderBuffer = new StringBuilder();

    /**
     * Consumes bytes from {@code buffer} until the header is complete or the buffer is exhausted.
     * Once the header is complete, later calls consume nothing.
     *
     * @param buffer data flipped for reading; its position is advanced past the consumed bytes
     */
    public void putFromByteBuffer(ByteBuffer buffer) {
        while (!parsed && buffer.hasRemaining()) {
            byte b = buffer.get();
            requestBytes.addByte(b);
            // Decode as ISO-8859-1: one byte maps to one char without sign extension.
            reqHeaderBuffer.append((char) (b & 0xFF));
            int length = reqHeaderBuffer.length();
            // A line ends at CRLF; the length check avoids charAt(-1) on a leading bare LF.
            if (b != '\n' || length < 2 || reqHeaderBuffer.charAt(length - 2) != '\r') {
                continue;
            }
            String headerLine = reqHeaderBuffer.substring(0, length - 2);
            reqHeaderBuffer.setLength(0);
            if (headerLine.isEmpty()) {
                if (headers.isEmpty()) {
                    // RFC 9112 §2.2: ignore empty lines received before the request line.
                    continue;
                }
                parsed = true;
                logger.debug("Request header line end.");
                return;
            }
            logger.debug("Request header line parsed {}", headerLine);
            boolean isRequestLine = headers.isEmpty();
            headers.add(headerLine);
            if (isRequestLine) {
                parseRequestLine(headerLine);
            } else if (headerLine.regionMatches(true, 0, REQUEST_HEADER_HOST_PREFIX, 0,
                    REQUEST_HEADER_HOST_PREFIX.length())) {
                this.address = normalizeHost(headerLine.substring(REQUEST_HEADER_HOST_PREFIX.length()).trim());
            }
        }
    }

    /**
     * Splits the request line into method, request-target and version. A missing version (as in an
     * HTTP/0.9-style line) leaves {@code protocol} null instead of throwing.
     */
    private void parseRequestLine(String requestLine) {
        String[] parts = requestLine.trim().split("\\s+");
        this.method = parts[0];
        this.requestTarget = parts.length > 1 ? parts[1] : null;
        this.protocol = parts.length > 2 ? parts[2] : null;
    }

    /** Lower-cases a host name independently of the default locale (e.g. the Turkish dotless i). */
    private static String normalizeHost(String host) {
        return host.toLowerCase(java.util.Locale.ROOT);
    }

    public boolean isParsed() {
        return parsed;
    }

    /**
     * Returns the lower-cased target address ({@code host} or {@code host:port}).
     *
     * <p>The Host header is used when present. Otherwise the value falls back to the authority of a
     * CONNECT or absolute-form ({@code scheme://authority/...}) request-target.
     *
     * @return the address, or null if none is available
     */
    public String getAddress() {
        return address != null ? address : addressFromRequestTarget();
    }

    /** Extracts the authority from the request-target, or returns null if it has none. */
    private String addressFromRequestTarget() {
        if (requestTarget == null) {
            return null;
        }
        if (HTTP_METHOD_CONNECT.equals(method)) {
            return normalizeHost(requestTarget);
        }
        int schemeEnd = requestTarget.indexOf("://");
        if (schemeEnd < 0) {
            return null;
        }
        int start = schemeEnd + 3;
        int end = start;
        while (end < requestTarget.length() && "/?#".indexOf(requestTarget.charAt(end)) < 0) {
            end++;
        }
        return end > start ? normalizeHost(requestTarget.substring(start, end)) : null;
    }

    public String getProtocol() {
        return protocol;
    }

    public String getMethod() {
        return method;
    }

    /** Returns a copy of every byte consumed so far, i.e. the raw header up to and including CRLFCRLF. */
    public byte[] getAllBytes() {
        return requestBytes.toByteArray();
    }
}

5.UnboundedByteBuffer
无界的字节缓冲区,容量不足时以两倍的容量扩容,用于追加存入客户端分多次到达的请求数据,从而处理请求头被拆分到多次读取中的半包(拆包)问题
package org.example.common;

/**
 * Growable byte array used to accumulate data that arrives over several reads.
 *
 * <p>Capacity starts at {@value #DEFAULT_CAP} bytes and doubles whenever more room is needed, up
 * to {@value #MAX_CAP} bytes.
 */
public class UnboundedByteBuffer {
    private static final int DEFAULT_CAP = 4096;
    private static final int MAX_CAP = 1 << 30;

    private byte[] bytes;
    private int size;
    private int cap;

    public UnboundedByteBuffer() {
        this.cap = DEFAULT_CAP;
        this.bytes = new byte[this.cap];
        this.size = 0;
    }

    /**
     * Appends all of {@code data}.
     *
     * @throws NullPointerException if {@code data} is null
     * @throws IllegalStateException if the content would exceed {@value #MAX_CAP} bytes
     */
    public void addBytes(byte[] data) {
        ensureCapacity(data.length);
        System.arraycopy(data, 0, bytes, size, data.length);
        this.size += data.length;
    }

    /**
     * Grows the backing array so that {@code extra} more bytes fit.
     *
     * <p>Uses long arithmetic so that neither {@code size + extra} nor the doubling can overflow.
     * The new capacity is recorded, so later appends reuse the spare room instead of reallocating.
     */
    private void ensureCapacity(int extra) {
        long required = (long) this.size + extra;
        if (required <= this.cap) {
            return;
        }
        if (required > MAX_CAP) {
            throw new IllegalStateException(
                    "UnboundedByteBuffer cannot hold " + required + " bytes (limit " + MAX_CAP + ")");
        }
        long newCap = this.cap;
        while (newCap < required) {
            newCap <<= 1;
        }
        newCap = Math.min(newCap, MAX_CAP); // still >= required because required <= MAX_CAP
        byte[] newBytes = new byte[(int) newCap];
        System.arraycopy(this.bytes, 0, newBytes, 0, this.size);
        this.bytes = newBytes;
        this.cap = (int) newCap;
    }

    /** Returns a copy of the bytes appended so far. */
    public byte[] toByteArray() {
        byte[] ret = new byte[this.size];
        System.arraycopy(this.bytes, 0, ret, 0, this.size);
        return ret;
    }

    /**
     * Appends a single byte.
     *
     * @throws IllegalStateException if the content would exceed {@value #MAX_CAP} bytes
     */
    public void addByte(byte b) {
        ensureCapacity(1);
        this.bytes[this.size++] = b;
    }
}

以上实现是在单个事件循环线程中处理所有事件,一个更好的方案是将客户端的Channel和代理服务器与目标服务器的Channel区分开,分别在两个事件循环中处理。基本实现也和本文中的代码大体一致,两者在理论上应该存在性能差距,实际经过本人测试可以每秒处理客户端的上千个连接。代码传送门

原文地址:https://blog.csdn.net/Yuriey/article/details/143632928

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!