自学内容网 自学内容网

gRPC-拦截器

简介

在构建 gRPC 应用程序时,无论是客户端应用程序,还是服务器端应用程序,在远程方法执行之前或之后,都可能需要执行一些通用逻辑。

gRPC 提供了简单的 API,用来在客户端和服务器端的 gRPC 应用程序中实现并安装拦截器。拦截器是 gRPC 的核心扩展机制之一:它会拦截每个 RPC 调用的执行,因此非常适合实现日志记录、身份验证/授权、性能指标收集、链路跟踪,以及其他可以跨 RPC 共享的自定义通用功能。

在 gRPC 应用程序中,拦截器根据所拦截的 RPC 调用类型可以分为以下两大类:
第一类是一元拦截器(unary interceptor),它拦截一元 RPC 的调用;
第二类是流拦截器(streaming interceptor),它拦截流式 RPC 的调用。
客户端和服务端都可以使用一元拦截器和流拦截器。

一元拦截器

客户端
/**
 * Unary client demo: opens a plaintext channel to localhost:50050 with
 * {@code ClientLoggingInterceptor} installed, performs one blocking
 * {@code getProduct} call and prints the returned product name.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 50050)
            .usePlaintext()
            .intercept(new ClientLoggingInterceptor())
            .build();
    try {
        ProductInfoGrpc.ProductInfoBlockingStub stub = ProductInfoGrpc.newBlockingStub(channel);
        ProductId productId = ProductId.newBuilder().setValue("1").build();
        Product product = stub.getProduct(productId);
        System.out.println("product.getName() = " + product.getName());
    } finally {
        // Shut the channel down even when the RPC throws, so a failed call
        // does not leak the channel.
        channel.shutdown();
    }
}
/**
 * Client interceptor that logs the full method name when a call starts and
 * logs every response message before handing it to the original listener.
 */
public static class ClientLoggingInterceptor implements ClientInterceptor{
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> methodDescriptor, CallOptions callOptions, Channel next) {
        System.out.println("执行ClientLoggingInterceptor拦截器...");
        // Hand the call to the next element of the chain and wrap what it returns.
        final ClientCall<ReqT, RespT> delegate = next.newCall(methodDescriptor, callOptions);
        return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(delegate) {
            @Override
            public void start(Listener<RespT> responseListener, Metadata headers) {
                // Runs before the call is started.
                System.out.println("客户端调用:" + methodDescriptor.getFullMethodName());
                Listener<RespT> loggingListener =
                        new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {
                            @Override
                            public void onMessage(RespT response) {
                                // Runs when a response message arrives, before delegating.
                                System.out.println("服务端返回:" + response);
                                super.onMessage(response);
                            }
                        };
                super.start(loggingListener, headers);
            }
        };
    }
}
服务端
/**
 * Builds the server on port 50050 with {@code ProductInfoImpl} as its service
 * and {@code ServerExecuteTimeInterceptor} installed, starts it, and registers
 * a JVM shutdown hook that calls {@code stop()}.
 *
 * @throws IOException propagated from starting the server
 */
public void start() throws IOException {
    final int listenPort = 50050;
    ServerBuilder<?> builder = ServerBuilder.forPort(listenPort)
            .addService(new ProductInfoImpl())
            .intercept(new ServerExecuteTimeInterceptor());
    server = builder.build().start();

    Thread shutdownHook = new Thread(() -> ProductInfoServer.this.stop());
    Runtime.getRuntime().addShutdownHook(shutdownHook);
    System.out.println("server start on port 50050");
}
/**
 * Server interceptor that logs the lifecycle events of each call (message
 * received, half close, cancel, complete) and, on completion, the elapsed
 * time of the call in milliseconds.
 *
 * <p>The start time is captured once, when the call is intercepted. It is not
 * taken inside {@code onMessage}: a call that completes without delivering a
 * message would then be measured from an unset start value, and a call that
 * delivers several messages would restart the clock on every message.
 */
public static class ServerExecuteTimeInterceptor implements ServerInterceptor{
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
        final String methodName = serverCall.getMethodDescriptor().getFullMethodName();
        System.out.println("receive request :" + methodName);
        // Monotonic clock: immune to wall-clock adjustments while the call runs.
        final long startNanos = System.nanoTime();
        ServerCall.Listener<ReqT> listener = next.startCall(serverCall, headers);
        return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(listener) {
            @Override
            public void onHalfClose() {
                System.out.println("client half close");
                super.onHalfClose();
            }

            @Override
            public void onCancel() {
                System.out.println("client cancel");
                super.onCancel();
            }

            @Override
            public void onComplete() {
                System.out.println("call complete");
                super.onComplete();
                // Elapsed time since interception, converted from nanoseconds to milliseconds.
                long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
                System.out.println("请求:" + methodName + "耗时:" + elapsedMillis);
            }

            @Override
            public void onMessage(ReqT message) {
                System.out.println("收到客户端消息:" + message);
                super.onMessage(message);
            }
        };
    }
}

流拦截器

客户端
/**
 * Client-streaming demo: opens a plaintext channel to localhost:50050 with
 * {@code ClientStreamingInterceptor} installed, streams ten products through
 * {@code saveProductBatch} and waits until the server finishes the call.
 *
 * <p>The latch is released on both completion and error, so a failed RPC
 * cannot block {@code main} forever, and the channel is always shut down.
 *
 * @param args command-line arguments (unused)
 * @throws Exception if waiting for the call to finish is interrupted
 */
public static void main(String[] args) throws Exception{
    CountDownLatch countDownLatch = new CountDownLatch(1);
    ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 50050)
            .usePlaintext()
            .intercept(new ClientStreamingInterceptor())
            .build();
    try {
        ProductInfoGrpc.ProductInfoStub stub = ProductInfoGrpc.newStub(channel);
        StreamObserver<Product> requestObserver = stub.saveProductBatch(new StreamObserver<ProductResult>() {
            @Override
            public void onNext(ProductResult productResult) {
                System.out.println("服务端返回:" + productResult.getSuccess());
            }
            @Override
            public void onError(Throwable throwable) {
                // Report the failure and release the latch; otherwise await()
                // below would never return when the call fails.
                System.err.println("客户端调用失败:" + throwable);
                countDownLatch.countDown();
            }
            @Override
            public void onCompleted() {
                countDownLatch.countDown();
            }
        });
        for (int i = 0; i < 10; i++) {
            Product p = Product.newBuilder().setId(""+i).setName("p"+i).build();
            System.out.println("客户端发送:" + p);
            requestObserver.onNext(p);
        }
        requestObserver.onCompleted();
        System.out.println("客户端发送完成");
        // Block until the response observer reports completion or an error.
        countDownLatch.await();
    } finally {
        // Shut the channel down on every exit path.
        channel.shutdown();
    }
}
/**
 * Client interceptor that attaches a logging {@code ClientStreamTracer} to each
 * call through the call options. Every tracer hook prints a line and then
 * delegates to the superclass implementation.
 */
public static class ClientStreamingInterceptor implements ClientInterceptor{
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
        System.out.println("执行ClientStreamingInterceptor拦截器...");

        // Factory that supplies the custom tracer for the call's streams.
        ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() {
            @Override
            public ClientStreamTracer newClientStreamTracer(ClientStreamTracer.StreamInfo info, Metadata headers) {
                return new ClientStreamTracer() {
                    // ---- outbound hooks: the request side of the stream ----

                    // Request headers hook.
                    @Override
                    public void outboundHeaders() {
                        System.out.println("client: 用于输出请求头.....");
                        super.outboundHeaders();
                    }

                    // Receives the sequence number of an outbound message.
                    @Override
                    public void outboundMessage(int sequence) {
                        System.out.println("client: 设置流消息的编号: " + sequence);
                        super.outboundMessage(sequence);
                    }

                    // Receives the uncompressed size of outbound data.
                    @Override
                    public void outboundUncompressedSize(long size) {
                        System.out.println("client: 获得未压缩消息的大小:" + size);
                        super.outboundUncompressedSize(size);
                    }

                    // Receives the wire size of outbound data.
                    @Override
                    public void outboundWireSize(long size) {
                        System.out.println("client: 用于获得 输出消息的大小:" + size);
                        super.outboundWireSize(size);
                    }

                    // Outbound message sent hook.
                    @Override
                    public void outboundMessageSent(int sequence, long wireSize, long uncompressedSize) {
                        System.out.println("client: 监控请求操作 outboundMessageSent:" + sequence);
                        super.outboundMessageSent(sequence, wireSize, uncompressedSize);
                    }

                    // ---- inbound hooks: the response side of the stream ----

                    // Response headers hook.
                    @Override
                    public void inboundHeaders() {
                        System.out.println("用于获得响应头....");
                        super.inboundHeaders();
                    }

                    // Receives the sequence number of an inbound message.
                    @Override
                    public void inboundMessage(int sequence) {
                        System.out.println("获得响应消息的编号..." + sequence);
                        super.inboundMessage(sequence);
                    }

                    // Receives the wire size of inbound data.
                    @Override
                    public void inboundWireSize(long size) {
                        System.out.println("获得响应消息的大小... " + size);
                        super.inboundWireSize(size);
                    }

                    // Receives sequence number, wire size and uncompressed size together.
                    @Override
                    public void inboundMessageRead(int sequence, long wireSize, long uncompressedSize) {
                        System.out.println("集中获得消息的编号 ,大小 ,未压缩大小..." + sequence +" " + wireSize +" "+ uncompressedSize);
                        super.inboundMessageRead(sequence, wireSize, uncompressedSize);
                    }

                    // Receives the uncompressed size of inbound data.
                    @Override
                    public void inboundUncompressedSize(long size) {
                        System.out.println("获得响应消息未压缩大小..." + size);
                        super.inboundUncompressedSize(size);
                    }

                    // Trailers hook; logged as the end of the response.
                    @Override
                    public void inboundTrailers(Metadata trailers) {
                        System.out.println("响应结束..");
                        super.inboundTrailers(trailers);
                    }
                };
            }
        };

        // Plug the tracer factory into the call options used for this call.
        CallOptions tracedOptions = callOptions.withStreamTracerFactory(tracerFactory);
        return next.newCall(method, tracedOptions);
    }
}
服务端
/**
 * Builds the server on port 50050 with {@code ProductInfoImpl} as its service
 * and {@code ServerStreamingInterceptor} registered as a stream tracer factory,
 * starts it, and registers a JVM shutdown hook that calls {@code stop()}.
 *
 * @throws IOException propagated from starting the server
 */
public void start() throws IOException {
    final int listenPort = 50050;
    ServerBuilder<?> builder = ServerBuilder.forPort(listenPort)
            .addService(new ProductInfoImpl())
            .addStreamTracerFactory(new ServerStreamingInterceptor());
    server = builder.build().start();

    Thread shutdownHook = new Thread(() -> ProductInfoServer.this.stop());
    Runtime.getRuntime().addShutdownHook(shutdownHook);
    System.out.println("server start on port 50050");
}
/**
 * Stream tracer factory for the server side. The tracer it creates prints a
 * line in {@code inboundMessageRead} and {@code outboundMessageSent}; the
 * remaining hooks are overridden only to show where they are and simply
 * delegate to the superclass.
 */
public static class ServerStreamingInterceptor extends ServerStreamTracer.Factory{
    @Override
    public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
        ServerStreamTracer tracer = new ServerStreamTracer(){
            // ---- inbound hooks: messages coming from the client ----

            @Override
            public void inboundMessage(int sequence) {
                super.inboundMessage(sequence);
            }

            @Override
            public void inboundWireSize(long size) {
                super.inboundWireSize(size);
            }

            @Override
            public void inboundUncompressedSize(long size) {
                super.inboundUncompressedSize(size);
            }

            // Logs sequence number, wire size and uncompressed size of a request message.
            @Override
            public void inboundMessageRead(int sequence, long wireSize, long uncompressedSize) {
                System.out.println("server: 获得client发送的请求消息 ..." + sequence+","+wireSize+","+uncompressedSize);
                super.inboundMessageRead(sequence, wireSize, uncompressedSize);
            }

            // ---- outbound hooks: messages sent back to the client ----

            @Override
            public void outboundMessage(int sequence) {
                super.outboundMessage(sequence);
            }

            @Override
            public void outboundWireSize(long size) {
                super.outboundWireSize(size);
            }

            @Override
            public void outboundUncompressedSize(long size) {
                super.outboundUncompressedSize(size);
            }

            // Logs sequence number, wire size and uncompressed size of a response message.
            @Override
            public void outboundMessageSent(int sequence, long wireSize, long uncompressedSize) {
                System.out.println("server: 响应数据的拦截 ..." + sequence+","+wireSize+","+uncompressedSize);
                super.outboundMessageSent(sequence, wireSize, uncompressedSize);
            }
        };
        return tracer;
    }
}

完整的源码下载:https://github.com/xjs1919/learning-demo/tree/master/grpc-demo


原文地址:https://blog.csdn.net/goldenfish1919/article/details/143483597

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!