自学内容网

netty使用redis发布订阅实现消息推送

netty使用redis发布订阅实现消息推送

场景

项目中需要给用户推送消息:

（原文此处为一张推送流程示意图，转载时未能保留。）

接口

@RestController
public class PushApi {

    @Autowired
    private PushService pushService;

    /**
     * HTTP entry point for pushing a message; the actual routing is done by {@link PushService}.
     *
     * @param query push request: module, type, content and (for per-user pushes) the target users
     * @return always the literal {@code "success"} once the push has been dispatched
     */
    @PostMapping("/push/message")
    public String push(@RequestBody MessagePushConfigDto query) {
        final String ack = "success";
        pushService.push(query);
        return ack;
    }
}

@Component
@Slf4j
public class PushService {

    @Autowired
    private StringRedisTemplate redisTemplate;

    @Autowired
    private MessageService messageService;

    /**
     * Pushes a message either to a whole module or to a list of users.
     *
     * <p>A new message number (random UUID) is generated per call and shared by every recipient.
     *
     * @param query push request; {@code type == Constants.MSG_TYPE_ALL} broadcasts to the module,
     *              any other value (including {@code null}) pushes to {@code identityList}
     */
    public void push(MessagePushConfigDto query) {
        String messageNo = UUID.randomUUID().toString();
        Integer type = query.getType();
        // Null-safe: comparing a null Integer with == would throw NullPointerException on unboxing.
        if (type != null && type == Constants.MSG_TYPE_ALL) {
            doPushGroup(query, messageNo);
        } else {
            doPushToUser(query, messageNo);
        }
    }

    /** Publishes a module broadcast on TOPIC_MODULE so every subscribed node can deliver it. */
    private void doPushGroup(MessagePushConfigDto query, String messageNo) {
        MessageDto dto = buildMessage(query, messageNo, null);
        // Forward to all nodes via Redis pub/sub.
        redisTemplate.convertAndSend(Constants.TOPIC_MODULE, JSON.toJSONString(dto));
    }

    /**
     * Delivers the message to each listed user: directly when the user's Redis routing entry points
     * to this node, otherwise by publishing it on TOPIC_USER. Users without a routing entry are skipped.
     */
    private void doPushToUser(MessagePushConfigDto query, String messageNo) {
        if (query.getIdentityList() == null || query.getIdentityList().isEmpty()) {
            log.info("no target users for module: {}", query.getModule());
            return;
        }
        for (String identityNo : query.getIdentityList()) {
            MessageDto dto = buildMessage(query, messageNo, identityNo);

            String key = MessageFormat.format(Constants.USER_KEY, query.getModule(), identityNo);
            String nodeIp = redisTemplate.opsForValue().get(key);
            if (StrUtil.isBlank(nodeIp)) {
                log.info("no user found: {}-{}", identityNo, key);
                // Skip only this user; returning here would drop the message for all remaining users.
                continue;
            }
            if (NodeConfig.node.equals(nodeIp)) {
                log.info("send from local: {}", identityNo);
                messageService.sendToUser(dto.getMessageNo(), dto.getModule(), dto.getIdentityNo(), dto.getContent());
            } else {
                // Forward to the other nodes; the node holding the session is expected to deliver it.
                redisTemplate.convertAndSend(Constants.TOPIC_USER, JSON.toJSONString(dto));
            }
        }
    }

    /** Builds the payload shared by both push paths; identityNo is null for module broadcasts. */
    private static MessageDto buildMessage(MessagePushConfigDto query, String messageNo, String identityNo) {
        MessageDto dto = new MessageDto();
        dto.setModule(query.getModule());
        dto.setType(query.getType());
        dto.setMessageNo(messageNo);
        dto.setContent(query.getContent());
        dto.setIdentityNo(identityNo);
        return dto;
    }
}

实体

// Message delivered to users; also the payload published (as JSON) on the Redis push topics.
@Data
public class MessageDto {

    // Module (session group) the message belongs to.
    private String module;
    /**
     * Push type. The original comment read "1 = specific users, 2 = all users", which does not
     * match the code: Constants defines MSG_TYPE_ALL = 1 and MSG_TYPE_SINGLE = 0, and PushService
     * treats 1 as "all users of the module" and any other value as "specific users".
     * NOTE(review): confirm which numbering clients actually send.
     */
    private Integer type;
    // Message number; PushService sets it to a random UUID shared by all recipients of one push.
    private String messageNo;
    // Text delivered to the client.
    private String content;
    // Target user id; set only for per-user pushes (left null for module broadcasts).
    private String identityNo;

}

// Push request accepted by PushApi (POST /push/message).
@Data
public class MessagePushConfigDto {

    // Module (session group) to push to.
    private String module;
    /**
     * Push type. The original comment read "1 = specific users, 2 = all users", which does not
     * match the code: Constants defines MSG_TYPE_ALL = 1 and MSG_TYPE_SINGLE = 0, and PushService
     * treats 1 as "all users of the module" and any other value as "specific users".
     * NOTE(review): confirm which numbering clients actually send.
     */
    private Integer type;
    // Message text.
    private String content;
    // Target user ids; only used when the push is not a module-wide broadcast.
    private List<String> identityList;

}

/**
 * Shared constants for the push feature.
 */
public interface Constants {

    /** Push type: specific users listed in the request. */
    int MSG_TYPE_SINGLE = 0;
    /** Push type: every user of the module. */
    int MSG_TYPE_ALL = 1;

    /** Redis pub/sub channel used for module-wide broadcasts. */
    String TOPIC_MODULE = "topic:module";
    /** Redis pub/sub channel used to forward a single-user message to the other nodes. */
    String TOPIC_USER = "topic:module:user";

    /**
     * {@link java.text.MessageFormat} pattern of the Redis key that stores the node a user is
     * connected to: {0} = module, {1} = identityNo.
     */
    String USER_KEY = "socket:module:{0}:userId:{1}";
}
MessageService：发送消息的接口
public interface MessageService {

    /**
     * Broadcasts a message to the session group of a module.
     *
     * @param messageNo identifier of the message
     * @param module    module whose group receives the message
     * @param content   message text
     */
    void sendToGroup(String messageNo, String module, String content);

    /**
     * Sends a message to the session of a single user within a module.
     *
     * @param messageNo  identifier of the message
     * @param module     module the user belongs to
     * @param identityNo unique identifier of the user
     * @param content    message text
     */
    void sendToUser(String messageNo, String module, String identityNo, String content);
}

public class MessageServiceImpl implements MessageService {

    /** Registry the sessions are looked up in; wildcard-typed instead of raw so no unchecked use occurs. */
    private final SessionRegistry<?, ?> sessionRegistry;

    /**
     * @param sessionRegistry registry holding the sessions this service can deliver to
     */
    public MessageServiceImpl(SessionRegistry<?, ?> sessionRegistry) {
        this.sessionRegistry = sessionRegistry;
    }

    /**
     * Sends {@code content} to the session group of {@code module}; does nothing if the registry
     * has no group for it. {@code messageNo} is currently unused.
     */
    @Override
    public void sendToGroup(String messageNo, String module, String content) {
        SessionGroup<?, ?> sessionGroup = sessionRegistry.retrieveGroup(module);
        if (sessionGroup != null) {
            sessionGroup.sendGroup(content);
        }
    }

    /**
     * Sends {@code content} to the session of {@code identityNo} in {@code module}; does nothing if
     * the registry has no such session. {@code messageNo} is currently unused.
     */
    @Override
    public void sendToUser(String messageNo, String module, String identityNo, String content) {
        WssSession<?> wssSession = sessionRegistry.retrieveSession(module, identityNo);
        if (wssSession != null) {
            wssSession.send(content);
        }
    }
}
SessionService

负责管理 session，并把用户所在的节点记录到 Redis 中。

/**
 * Session lifecycle operations. Implementations may do more than keep the session in memory,
 * e.g. record where the user is connected.
 *
 * @param <WS> session type
 * @param <C>  underlying channel type
 */
public interface SessionService<WS extends WssSession<C>, C> {

    /**
     * Registers a newly established session.
     *
     * @param session the session to add
     */
    void addSession(WS session);

    /**
     * Unregisters a session that is going away.
     *
     * @param session the session to remove
     */
    void removeSession(WS session);
}
/**
 * Base class for session services that delegate session storage to a {@link SessionRegistry}.
 *
 * @param <SR> registry type
 * @param <WS> session type
 * @param <C>  underlying channel type
 */
public abstract class AbstractSessionService<SR extends SessionRegistry<WS, C>, WS extends WssSession<C>, C>
        implements SessionService<WS, C> {

    /** Registry supplied at construction time; never reassigned. */
    private final SR sessionRegistry;

    public AbstractSessionService(SR sessionRegistry) {
        this.sessionRegistry = sessionRegistry;
    }

    /**
     * @return the registry passed to the constructor
     */
    public SR getSessionRegistry() {
        return sessionRegistry;
    }
}

public class SessionServiceImpl<SR extends SessionRegistry<WS, C>, WS extends WssSession<C>, C>
        extends AbstractSessionService<SR, WS, C> {

    private StringRedisTemplate redisTemplate;

    public SessionServiceImpl(SR sessionRegistry, StringRedisTemplate redisTemplate) {
        super(sessionRegistry);
        this.redisTemplate = redisTemplate;
    }

    @Override
    public void addSession(WS session) {
        getSessionRegistry().addSession(session);
        String key = MessageFormat.format(Constants.USER_KEY, session.getModule(), session.getIdentityNo());
        redisTemplate.opsForValue().set(key, NodeConfig.node);
    }

    @Override
    public void removeSession(WS session) {
        getSessionRegistry().removeSession(session);
        String key = MessageFormat.format(Constants.USER_KEY, session.getModule(), session.getIdentityNo());
        redisTemplate.delete(key);
    }

}

websocket 实现

定义 session 相关接口

/**
 * A client connection identified by module and user.
 *
 * @param <C> type of the underlying communication channel
 */
public interface WssSession<C> {

    /**
     * @return the module (business group) this session belongs to
     */
    String getModule();

    /**
     * @return the unique identifier of the connected user
     */
    String getIdentityNo();

    /**
     * @return the underlying communication channel
     */
    C getChannel();

    /**
     * Sends a message to the client.
     *
     * @param message message text
     */
    void send(String message);
}

/**
 * A set of sessions that can be broadcast to or addressed one user at a time.
 *
 * @param <T> session type
 * @param <C> underlying channel type
 */
public interface SessionGroup<T extends WssSession<C>, C> {

    /**
     * Adds a session to the group.
     *
     * @param session session to add
     */
    void addSession(T session);

    /**
     * Removes a session from the group.
     *
     * @param session session to remove
     */
    void removeSession(T session);

    /**
     * Sends a message to every session of the group.
     *
     * @param message message text
     */
    void sendGroup(String message);

    /**
     * Looks up a session by the user's unique identifier.
     *
     * @param identityNo unique identifier of the user
     * @return the matching session; implementations may return {@code null} when absent
     */
    T getSession(String identityNo);
}
/**
 * Registry of sessions organised into one {@link SessionGroup} per module.
 *
 * @param <T> session type
 * @param <C> underlying channel type
 */
public interface SessionRegistry<T extends WssSession<C>, C> {

    /**
     * Registers a session under its module.
     *
     * @param session session to add
     */
    void addSession(T session);

    /**
     * Unregisters a session from its module.
     *
     * @param session session to remove
     */
    void removeSession(T session);

    /**
     * Looks up the group of a module.
     *
     * @param module module name
     * @return the module's group; implementations may return {@code null} for unknown modules
     */
    SessionGroup<T, C> retrieveGroup(String module);

    /**
     * Looks up a single session.
     *
     * @param module     module name
     * @param identityNo unique identifier of the user
     * @return the session; implementations may return {@code null} when absent
     */
    T retrieveSession(String module, String identityNo);
}

/**
 * Base session that stores the module, user id and channel given at construction;
 * subclasses only implement {@link #send(String)}.
 *
 * @param <C> underlying channel type
 */
public abstract class AbstractSession<C> implements WssSession<C> {

    private final String module;
    private final String identityNo;
    private final C channel;

    public AbstractSession(String module, String identityNo, C channel) {
        this.module = module;
        this.identityNo = identityNo;
        this.channel = channel;
    }

    @Override
    public C getChannel() {
        return channel;
    }

    @Override
    public String getIdentityNo() {
        return identityNo;
    }

    @Override
    public String getModule() {
        return module;
    }
}

/**
 * Registry keeping one {@link SessionGroup} per module in a concurrent map; subclasses decide
 * which group implementation to create.
 *
 * @param <T> session type
 * @param <C> underlying channel type
 */
public abstract class AbstractSessionRegistry<T extends WssSession<C>, C> implements SessionRegistry<T, C> {

    /** module -> group of that module's sessions. */
    private final Map<String, SessionGroup<T, C>> map = new ConcurrentHashMap<>();

    @Override
    public void addSession(T session) {
        map.computeIfAbsent(session.getModule(), key -> newSessionGroup()).addSession(session);
    }

    /**
     * Creates an empty group; called the first time a session of a module is added.
     *
     * @return a new, empty group
     */
    protected abstract SessionGroup<T, C> newSessionGroup();

    @Override
    public void removeSession(T session) {
        SessionGroup<T, C> sessionGroup = map.get(session.getModule());
        // A module with no registered group would otherwise cause a NullPointerException here.
        if (sessionGroup != null) {
            sessionGroup.removeSession(session);
        }
    }

    @Override
    public SessionGroup<T, C> retrieveGroup(String module) {
        return map.get(module);
    }

    @Override
    public T retrieveSession(String module, String identityNo) {
        SessionGroup<T, C> sessionGroup = map.get(module);
        return sessionGroup == null ? null : sessionGroup.getSession(identityNo);
    }
}

基于 Netty 的 WebSocket 服务实现

@Slf4j
@Component
public class NettyServer {

    private NioEventLoopGroup boss;
    private NioEventLoopGroup worker;


    @Value("${namespace:/ns}")
    private String namespace;

    @Autowired
    private SessionService sessionService;

    @PostConstruct
    public void start() {
        try {
            boss = new NioEventLoopGroup(1);
            worker = new NioEventLoopGroup();

            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.group(boss, worker).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<SocketChannel>() {
                @Override
                protected void initChannel(SocketChannel ch) throws Exception {
                    ChannelPipeline pipeline = ch.pipeline();
                    pipeline.addLast(new IdleStateHandler(0, 0, 60));
                    pipeline.addLast(new HeartBeatInboundHandler());
                    pipeline.addLast(new HttpServerCodec());
                    pipeline.addLast(new HttpObjectAggregator(64 * 1024));
                    pipeline.addLast(new ChunkedWriteHandler());
                    pipeline.addLast(new HttpRequestInboundHandler(namespace));
                    pipeline.addLast(new WebSocketServerProtocolHandler(namespace, true));
                    pipeline.addLast(new WebSocketHandShakeHandler(sessionService));
                }
            });
            int port = 9999;
            serverBootstrap.bind(port).addListener((ChannelFutureListener) future -> {
                if (future.isSuccess()) {
                    log.info("server start at port successfully: {}", port);
                } else {
                    log.info("server start at port error: {}", port);
                }
            }).sync();
        } catch (InterruptedException e) {
            log.error("start error", e);
            close();
        }
    }

    @PreDestroy
    public void destroy() {
        close();
    }

    private void close() {
        log.info("websocket server close..");
        if (boss != null) {
            boss.shutdownGracefully();
        }
        if (worker != null) {
            worker.shutdownGracefully();
        }
    }


}

/**
 * Netty session group: a {@link ChannelGroup} for broadcasts plus an identityNo -> session map for
 * per-user lookups.
 */
public class NettySessionGroup implements SessionGroup<NWssSession, Channel> {

    /** Channels of the group, used for broadcasts. */
    private final ChannelGroup group = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    /** identityNo -> session of that user (the most recently added one). */
    private final Map<String, NWssSession> map = new ConcurrentHashMap<>();

    @Override
    public void addSession(NWssSession session) {
        group.add(session.getChannel());
        map.put(session.getIdentityNo(), session);
    }

    /**
     * Removes the session's channel and, only if the user's map entry still belongs to that same
     * channel, the map entry; a newer connection of the same user is therefore not evicted when an
     * older channel goes away.
     */
    @Override
    public void removeSession(NWssSession session) {
        Channel channel = session.getChannel();
        group.remove(channel);
        // Compare channels, not session objects: NettyUtil.getSession builds a new NWssSession per call.
        map.computeIfPresent(session.getIdentityNo(),
                (identityNo, current) -> current.getChannel() == channel ? null : current);
    }

    @Override
    public void sendGroup(String message) {
        group.writeAndFlush(new TextWebSocketFrame(message));
    }

    @Override
    public NWssSession getSession(String identityNo) {
        return map.get(identityNo);
    }
}


/**
 * Session registry for Netty channels; every module gets its own {@link NettySessionGroup}.
 */
public class NettySessionRegistry extends AbstractSessionRegistry<NWssSession, Channel> {

    @Override
    protected SessionGroup<NWssSession, Channel> newSessionGroup() {
        SessionGroup<NWssSession, Channel> created = new NettySessionGroup();
        return created;
    }
}

/**
 * Session backed by a Netty {@link Channel}.
 */
public class NWssSession extends AbstractSession<Channel> {

    public NWssSession(String module, String identityNo, Channel channel) {
        super(module, identityNo, channel);
    }

    /**
     * Writes the message to the channel as a text WebSocket frame and flushes it.
     *
     * @param message message text
     */
    @Override
    public void send(String message) {
        Channel target = getChannel();
        TextWebSocketFrame frame = new TextWebSocketFrame(message);
        target.writeAndFlush(frame);
    }
}

public class NettyUtil {

    //参数-module<->user-code
    public static AttributeKey<String> G_U = AttributeKey.valueOf("GU");
    //参数-uri
    public static AttributeKey<String> P = AttributeKey.valueOf("P");

    /**
     * 设置上下文参数
     *
     * @param channel
     * @param attributeKey
     * @param data
     * @param <T>
     */
    public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {
        Attribute<T> attr = channel.attr(attributeKey);
        if (attr != null) {
            attr.set(data);
        }
    }

    /**
     * 获取上下文参数
     *
     * @param channel
     * @param attributeKey
     * @param <T>
     * @return
     */
    public static <T> T getAttr(Channel channel, AttributeKey<T> attributeKey) {
        return channel.attr(attributeKey).get();
    }

    /**
     * 根据 渠道获取 session
     *
     * @param channel
     * @return
     */
    public static NWssSession getSession(Channel channel) {
        String attr = channel.attr(G_U).get();
        if (StrUtil.isNotBlank(attr)) {
            String[] split = attr.split(",");
            String groupId = split[0];
            String username = split[1];
            return new NWssSession(groupId, username, channel);
        }
        return null;
    }


    public static void writeForbiddenRepose(ChannelHandlerContext ctx) {
        String res = "FORBIDDEN";
        FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.wrappedBuffer(res.getBytes(StandardCharsets.UTF_8)));
        response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain");
        response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
        ctx.writeAndFlush(response);
        ctx.close();
    }


}

/**
 * Callback notified of the outcome of a WebSocket handshake.
 */
public interface WebSocketListener {

    /**
     * Called after a successful handshake.
     *
     * @param ctx context of the connection
     * @param uri original request URI, including the query string
     */
    void handShakeSuccessful(ChannelHandlerContext ctx, String uri);

    /**
     * Called after a failed handshake.
     *
     * @param ctx context of the connection
     * @param uri original request URI
     */
    void handShakeFailed(ChannelHandlerContext ctx, String uri);
}

// Parses the "module" and "userCode" query parameters of the request URI.
@Slf4j
public class DefaultWebSocketListener implements WebSocketListener {
    private static final String G = "module";
    private static final String U = "userCode";

    /**
     * Stores "module,userCode" in the {@code NettyUtil.G_U} channel attribute, or answers 403 when
     * either parameter is missing or blank.
     */
    @Override
    public void handShakeSuccessful(ChannelHandlerContext ctx, String uri) {
        Map<String, List<String>> params = new QueryStringDecoder(uri).parameters();
        String groupId = getParameter(G, params);
        String userCode = getParameter(U, params);
        boolean complete = StrUtil.isNotBlank(groupId) && StrUtil.isNotBlank(userCode);
        if (complete) {
            // Hand the identity to later handlers via a channel attribute.
            NettyUtil.setAttr(ctx.channel(), NettyUtil.G_U, groupId + "," + userCode);
            return;
        }
        log.info("module or userCode is null: {}", uri);
        NettyUtil.writeForbiddenRepose(ctx);
    }

    /** Closes the channel when the handshake failed. */
    @Override
    public void handShakeFailed(ChannelHandlerContext ctx, String uri) {
        log.info("handShakeFailed failed,close channel");
        ctx.close();
    }

    /** Returns the first value of {@code key}, or {@code null} when absent. */
    private String getParameter(String key, Map<String, List<String>> params) {
        List<String> values = CollectionUtils.isEmpty(params) ? null : params.get(key);
        return CollectionUtils.isEmpty(values) ? null : values.get(0);
    }
}

netty handler
// Heartbeat: closes connections that have gone idle.
@Slf4j
public class HeartBeatInboundHandler extends ChannelInboundHandlerAdapter {

    /**
     * Closes the channel on an {@code ALL_IDLE} {@link IdleStateEvent}; every other event is passed on.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        boolean allIdle = evt instanceof IdleStateEvent idleEvent && idleEvent.state() == IdleState.ALL_IDLE;
        if (!allIdle) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        // Neither reads nor writes within the idle period: drop the connection.
        log.info("HeartBeatInboundHandler heart beat close");
        ctx.channel().close();
    }
}

/**
 * @Date: 2024/7/17 13:06
 * Handles the initial HTTP (upgrade) request: checks its path against the namespace, stores the
 * original request URI in a channel attribute and rewrites the URI to the bare namespace path.
 */
@Slf4j
public class HttpRequestInboundHandler extends ChannelInboundHandlerAdapter {
    private final String namespace;

    public HttpRequestInboundHandler(String namespace) {
        this.namespace = namespace;
    }

    /**
     * Processes the first {@link FullHttpRequest}, then removes itself from the pipeline.
     * Requests outside the namespace get a 403; any other message is passed through unchanged.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof FullHttpRequest request)) {
            // Not an aggregated HTTP request: forward it instead of silently dropping (and leaking) it.
            ctx.fireChannelRead(msg);
            return;
        }
        // e.g. ws://localhost:8080/n/ws?groupId=xx&username=tom
        String requestUri = request.uri();
        log.info("raw request url: {}", decodeForLog(requestUri));
        if (!isUnderNamespace(requestUri)) {
            // The request is consumed here and never forwarded, so its buffer must be released.
            request.release();
            NettyUtil.writeForbiddenRepose(ctx);
            return;
        }
        // TODO: 2024/7/17 validate the token, e.g. read it from a request header.

        // A custom handshake (shakeHandsIfNecessary below) could replace Netty's WebSocketServerProtocolHandler.
        //shakeHandsIfNecessary(ctx, request, requestUri);
        // Pass the original URI on via a channel attribute, then strip the query parameters
        // (=> ws://localhost:8080/n/ws) so the path matches the WebSocket handler's path.
        NettyUtil.setAttr(ctx.channel(), NettyUtil.P, requestUri);
        request.setUri(namespace);
        ctx.pipeline().remove(this);
        ctx.fireChannelRead(request);
    }

    /** Returns true when the URI's path starts with the namespace; malformed URIs are rejected. */
    private boolean isUnderNamespace(String requestUri) {
        try {
            String path = new URI(requestUri).getPath();
            return path != null && path.startsWith(namespace);
        } catch (java.net.URISyntaxException e) {
            return false;
        }
    }

    /** URL-decodes the URI for logging, falling back to the raw value if it is not validly encoded. */
    private static String decodeForLog(String requestUri) {
        try {
            return URLDecoder.decode(requestUri, StandardCharsets.UTF_8);
        } catch (IllegalArgumentException e) {
            return requestUri;
        }
    }
/*
    Alternative manual handshake, kept for reference only. It is not compiled and refers to
    `listener` and `prefix`, which this class does not define.

    private void shakeHandsIfNecessary(ChannelHandlerContext ctx, FullHttpRequest request, String requestUri) {
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(request), null, true);
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(request);
        if (handshaker == null) {
            // WebSocket version not supported: send the unsupported-version response
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        } else {
            ChannelPipeline pipeline = ctx.channel().pipeline();
            handshaker.handshake(ctx.channel(), request).addListener((ChannelFutureListener) future -> {
                if (future.isSuccess()) {
                    // handshake succeeded: WebSocketListener listener
                    listener.handShakeSuccessful(ctx, requestUri);
                } else {
                    // handshake failed
                    listener.handShakeFailed(ctx, requestUri);
                }

            });
        }
    }

    private String getWebSocketLocation(FullHttpRequest req) {
        return "ws://" + req.headers().get(HttpHeaderNames.HOST) + prefix;
    }

*/
}

@Slf4j
public class WebSocketBizHandler extends SimpleChannelInboundHandler<WebSocketFrame> {

    private SessionService sessionService;

    public WebSocketBizHandler(SessionService sessionService){
        this.sessionService = sessionService;
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        log.info("handlerAdded");
        NWssSession session = NettyUtil.getSession(ctx.channel());
        if (session == null) {
            log.info("session is null: {}", ctx.channel().id());
            NettyUtil.writeForbiddenRepose(ctx);
            return;
        }
        sessionService.addSession(session);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        log.info("handlerRemoved");
        NWssSession session = NettyUtil.getSession(ctx.channel());
        if (session == null) {
            log.info("session is null: {}", ctx.channel().id());
            return;
        }
        sessionService.removeSession(session);
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {
        if (msg instanceof TextWebSocketFrame) {

        } else if (msg instanceof BinaryWebSocketFrame) {

        } else if (msg instanceof PingWebSocketFrame) {

        } else if (msg instanceof PongWebSocketFrame) {

        } else if (msg instanceof CloseWebSocketFrame) {
            if (ctx.channel().isActive()) {
                ctx.close();
            }
        }
        ctx.writeAndFlush(new TextWebSocketFrame("默认回复"));
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        //处理最后的业务异常
        log.info("WebSocketBizHandler error: ", cause);
    }
}

//处理websocket协议握手
@Slf4j
public class WebSocketHandShakeHandler extends ChannelInboundHandlerAdapter {

    private SessionService sessionService;
    private WebSocketListener webSocketListener = new DefaultWebSocketListener();

    public WebSocketHandShakeHandler(SessionService sessionService) {
        this.sessionService = sessionService;
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
            log.info("WebSocketHandShakeHandler shake-hands success");
            // 在此处获取URL、Headers等信息并做校验,通过throw异常来中断链接。
            String uri = NettyUtil.getAttr(ctx.channel(), NettyUtil.P);
            if (StrUtil.isBlank(uri)) {
                log.info("request uri is null");
                NettyUtil.writeForbiddenRepose(ctx);
                return;
            }
            webSocketListener.handShakeSuccessful(ctx, uri);
            ChannelPipeline pipeline = ctx.channel().pipeline();
            pipeline.addLast(new WebSocketBizHandler(sessionService));
            pipeline.remove(this);
            return;
        }
        super.userEventTriggered(ctx, evt);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof WebSocketHandshakeException) {
            //只处理 websocket 握手相关异常
            FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.BAD_REQUEST,
                    Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
            ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
            return;
        }
        super.exceptionCaught(ctx,cause);
    }

}


配置

/**
 * Resolves this node's local IP at startup. The value is written to the per-user Redis routing
 * keys and compared against them when pushing.
 */
@Component
public class NodeConfig {

    /** Local IP of this node; assigned once by {@link #init()}. */
    public static String node;

    /** Looks up the local IP and throws if it could not be determined. */
    @PostConstruct
    public void init() {
        NodeConfig.node = NetUtil.getLocalhostStr();
        Assert.notNull(NodeConfig.node, "local ip is null");
    }
}

@Slf4j
@Configuration
public class RedisPublishConfig {

    /**
     * Subscribes {@code messageListener} to the per-user and per-module push topics. The topic names
     * contain no glob characters, so each pattern matches exactly one channel.
     */
    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory, MessageListener messageListener) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);
        List<PatternTopic> topicList = List.of(
                new PatternTopic(Constants.TOPIC_USER),
                new PatternTopic(Constants.TOPIC_MODULE));
        container.addMessageListener(messageListener, topicList);
        // Log every subscribed topic, not just TOPIC_USER.
        log.info("RedisMessageListenerContainer listen topic: {}, {}", Constants.TOPIC_USER, Constants.TOPIC_MODULE);
        return container;
    }
}

/**
 * Receives push messages from Redis pub/sub and hands them to {@link RedisPublisherConsumer}.
 */
@Slf4j
@Component
public class RedisPublisherListener implements MessageListener {

    @Autowired
    private RedisPublisherConsumer messageService;

    /**
     * Decodes the channel name and body as UTF-8 and forwards them to the consumer.
     * The topic is taken from the message's channel rather than {@code pattern}: with the current
     * literal pattern subscriptions both are equal, but {@code pattern} is null for channel
     * subscriptions. Decoding with an explicit charset also removes the dead
     * UnsupportedEncodingException handling.
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        String topic = new String(message.getChannel(), java.nio.charset.StandardCharsets.UTF_8);
        String msg = new String(message.getBody(), java.nio.charset.StandardCharsets.UTF_8);
        log.info("recv topic:{}, msg: {}", topic, msg);
        messageService.consume(topic, msg);
    }
}

@Configuration
public class WebSocketConfig {

    /** Local session registry; the other bean methods obtain it through the @Configuration proxy. */
    @Bean
    public NettySessionRegistry sessionRegistry() {
        return new NettySessionRegistry();
    }

    /** Session service backed by the local registry that also records user routing in Redis. */
    @Bean
    public SessionService<NWssSession, Channel> sessionService(StringRedisTemplate redisTemplate) {
        NettySessionRegistry registry = sessionRegistry();
        return new SessionServiceImpl<>(registry, redisTemplate);
    }

    /** Message service that delivers to sessions held in the local registry. */
    @Bean
    public MessageService messageService() {
        NettySessionRegistry registry = sessionRegistry();
        return new MessageServiceImpl(registry);
    }
}

good luck!


原文地址:https://blog.csdn.net/u013887008/article/details/140618184

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!