WebSocketConfig.java 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. package com.zsElectric.boot.config;
  2. import cn.hutool.core.util.StrUtil;
  3. import com.zsElectric.boot.security.model.SysUserDetails;
  4. import com.zsElectric.boot.security.token.TokenManager;
  5. import com.zsElectric.boot.system.service.WebSocketService;
  6. import lombok.extern.slf4j.Slf4j;
  7. import org.jetbrains.annotations.NotNull;
  8. import org.springframework.context.annotation.Configuration;
  9. import org.springframework.context.annotation.Lazy;
  10. import org.springframework.http.HttpHeaders;
  11. import org.springframework.messaging.Message;
  12. import org.springframework.messaging.MessageChannel;
  13. import org.springframework.messaging.MessagingException;
  14. import org.springframework.messaging.simp.config.ChannelRegistration;
  15. import org.springframework.messaging.simp.config.MessageBrokerRegistry;
  16. import org.springframework.messaging.simp.stomp.StompCommand;
  17. import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
  18. import org.springframework.messaging.support.ChannelInterceptor;
  19. import org.springframework.messaging.support.MessageHeaderAccessor;
  20. import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
  21. import org.springframework.security.authentication.BadCredentialsException;
  22. import org.springframework.security.core.Authentication;
  23. import org.springframework.security.core.AuthenticationException;
  24. import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
  25. import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
  26. import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
  27. /**
  28. * WebSocket 配置类
  29. *
  30. * 核心功能:
  31. * - 配置 WebSocket 端点
  32. * - 配置消息代理
  33. * - 实现连接认证与授权
  34. * - 管理用户会话生命周期
  35. *
  36. * @author Ray.Hao
  37. * @since 3.0.0
  38. */
  39. @EnableWebSocketMessageBroker
  40. @Configuration
  41. @Slf4j
  42. public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
  43. private static final String WS_ENDPOINT = "/ws";
  44. private static final String APP_DESTINATION_PREFIX = "/app";
  45. private static final String USER_DESTINATION_PREFIX = "/user";
  46. private static final String[] BROKER_DESTINATIONS = {"/topic", "/queue"};
  47. private final TokenManager tokenManager;
  48. private final WebSocketService webSocketService;
  49. public WebSocketConfig(TokenManager tokenManager, @Lazy WebSocketService webSocketService) {
  50. this.tokenManager = tokenManager;
  51. this.webSocketService = webSocketService;
  52. log.info("✓ WebSocket 配置已加载");
  53. }
  54. /**
  55. * 注册 STOMP 端点
  56. *
  57. * 客户端通过该端点建立 WebSocket 连接
  58. */
  59. @Override
  60. public void registerStompEndpoints(StompEndpointRegistry registry) {
  61. registry
  62. .addEndpoint(WS_ENDPOINT)
  63. .setAllowedOriginPatterns("*"); // 允许跨域(生产环境建议配置具体域名)
  64. log.info("✓ STOMP 端点已注册: {}", WS_ENDPOINT);
  65. }
  66. /**
  67. * 配置消息代理
  68. *
  69. * - /app 前缀:客户端发送消息到服务端的前缀
  70. * - /topic 前缀:用于广播消息
  71. * - /queue 前缀:用于点对点消息
  72. * - /user 前缀:服务端发送给特定用户的消息前缀
  73. */
  74. @Override
  75. public void configureMessageBroker(MessageBrokerRegistry registry) {
  76. // 客户端发送消息的请求前缀
  77. registry.setApplicationDestinationPrefixes(APP_DESTINATION_PREFIX);
  78. // 启用简单消息代理,处理 /topic 和 /queue 前缀的消息
  79. registry.enableSimpleBroker(BROKER_DESTINATIONS);
  80. // 服务端通知客户端的前缀
  81. registry.setUserDestinationPrefix(USER_DESTINATION_PREFIX);
  82. log.info("✓ 消息代理已配置: app={}, broker={}, user={}",
  83. APP_DESTINATION_PREFIX, BROKER_DESTINATIONS, USER_DESTINATION_PREFIX);
  84. }
  85. /**
  86. * 配置客户端入站通道拦截器
  87. *
  88. * 核心功能:
  89. * 1. 连接建立时:解析 JWT Token 并绑定用户身份
  90. * 2. 连接关闭时:触发用户下线通知
  91. * 3. 安全防护:拦截无效连接请求
  92. */
  93. @Override
  94. public void configureClientInboundChannel(ChannelRegistration registration) {
  95. registration.interceptors(new ChannelInterceptor() {
  96. @Override
  97. public Message<?> preSend(@NotNull Message<?> message, @NotNull MessageChannel channel) {
  98. StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
  99. // 防御性检查:确保 accessor 不为空
  100. if (accessor == null) {
  101. log.warn("⚠ 收到异常消息:无法获取 StompHeaderAccessor");
  102. return ChannelInterceptor.super.preSend(message, channel);
  103. }
  104. StompCommand command = accessor.getCommand();
  105. if (command == null) {
  106. return ChannelInterceptor.super.preSend(message, channel);
  107. }
  108. try {
  109. switch (command) {
  110. case CONNECT:
  111. handleConnect(accessor);
  112. break;
  113. case DISCONNECT:
  114. handleDisconnect(accessor);
  115. break;
  116. case SUBSCRIBE:
  117. handleSubscribe(accessor);
  118. break;
  119. default:
  120. // 其他命令不需要特殊处理
  121. break;
  122. }
  123. } catch (AuthenticationException ex) {
  124. // 认证失败时强制关闭连接
  125. log.error("❌ 连接认证失败: {}", ex.getMessage());
  126. throw ex;
  127. } catch (Exception ex) {
  128. // 捕获其他未知异常
  129. log.error("❌ WebSocket 消息处理异常", ex);
  130. throw new MessagingException("消息处理失败: " + ex.getMessage());
  131. }
  132. return ChannelInterceptor.super.preSend(message, channel);
  133. }
  134. });
  135. log.info("✓ 客户端入站通道拦截器已配置");
  136. }
  137. /**
  138. * 处理客户端连接请求
  139. *
  140. * 安全校验流程:
  141. * 1. 提取 Authorization 头
  142. * 2. 验证 Bearer Token 格式
  143. * 3. 解析并验证 JWT 有效性
  144. * 4. 绑定用户身份到当前会话
  145. * 5. 记录用户上线状态
  146. */
  147. private void handleConnect(StompHeaderAccessor accessor) {
  148. String authorization = accessor.getFirstNativeHeader(HttpHeaders.AUTHORIZATION);
  149. // 安全检查:确保 Authorization 头存在且格式正确
  150. if (StrUtil.isBlank(authorization)) {
  151. log.warn("⚠ 非法连接请求:缺少 Authorization 头");
  152. throw new AuthenticationCredentialsNotFoundException("缺少 Authorization 头");
  153. }
  154. if (!authorization.startsWith("Bearer ")) {
  155. log.warn("⚠ 非法连接请求:Authorization 头格式错误");
  156. throw new BadCredentialsException("Authorization 头格式错误");
  157. }
  158. // 提取 JWT Token(移除 "Bearer " 前缀)
  159. String token = authorization.substring(7);
  160. if (StrUtil.isBlank(token)) {
  161. log.warn("⚠ 非法连接请求:Token 为空");
  162. throw new BadCredentialsException("Token 为空");
  163. }
  164. // 解析并验证 Token
  165. Authentication authentication;
  166. try {
  167. authentication = tokenManager.parseToken(token);
  168. } catch (Exception ex) {
  169. log.error("❌ Token 解析失败", ex);
  170. throw new BadCredentialsException("Token 无效: " + ex.getMessage());
  171. }
  172. // 验证解析结果
  173. if (authentication == null || !authentication.isAuthenticated()) {
  174. log.warn("⚠ Token 解析失败:认证对象无效");
  175. throw new BadCredentialsException("Token 解析失败");
  176. }
  177. // 获取用户详细信息
  178. Object principal = authentication.getPrincipal();
  179. if (!(principal instanceof SysUserDetails)) {
  180. log.error("❌ 无效的用户凭证类型: {}", principal.getClass().getName());
  181. throw new BadCredentialsException("用户凭证类型错误");
  182. }
  183. SysUserDetails userDetails = (SysUserDetails) principal;
  184. String username = userDetails.getUsername();
  185. if (StrUtil.isBlank(username)) {
  186. log.warn("⚠ 用户名为空");
  187. throw new BadCredentialsException("用户名为空");
  188. }
  189. // 绑定用户身份到当前会话(重要:用于 @SendToUser 等注解)
  190. accessor.setUser(authentication);
  191. // 获取会话 ID
  192. String sessionId = accessor.getSessionId();
  193. if (sessionId == null) {
  194. log.warn("⚠ 会话 ID 为空,使用临时 ID");
  195. sessionId = "temp-" + System.nanoTime();
  196. }
  197. // 记录用户上线状态
  198. try {
  199. webSocketService.userConnected(username, sessionId);
  200. log.info("✓ WebSocket 连接建立成功: 用户[{}], 会话[{}]", username, sessionId);
  201. } catch (Exception ex) {
  202. log.error("❌ 记录用户上线状态失败: 用户[{}], 会话[{}]", username, sessionId, ex);
  203. // 不抛出异常,允许连接继续
  204. }
  205. }
  206. /**
  207. * 处理客户端断开连接事件
  208. *
  209. * 注意:
  210. * - 只有成功建立过认证的连接才会触发下线事件
  211. * - 防止未认证成功的连接产生脏数据
  212. */
  213. private void handleDisconnect(StompHeaderAccessor accessor) {
  214. Authentication authentication = (Authentication) accessor.getUser();
  215. // 防御性检查:只处理已认证的连接
  216. if (authentication == null || !authentication.isAuthenticated()) {
  217. log.debug("未认证的连接断开,跳过处理");
  218. return;
  219. }
  220. Object principal = authentication.getPrincipal();
  221. if (!(principal instanceof SysUserDetails)) {
  222. log.warn("⚠ 断开连接时用户凭证类型异常");
  223. return;
  224. }
  225. SysUserDetails userDetails = (SysUserDetails) principal;
  226. String username = userDetails.getUsername();
  227. if (StrUtil.isNotBlank(username)) {
  228. try {
  229. webSocketService.userDisconnected(username);
  230. log.info("✓ WebSocket 连接断开: 用户[{}]", username);
  231. } catch (Exception ex) {
  232. log.error("❌ 记录用户下线状态失败: 用户[{}]", username, ex);
  233. }
  234. }
  235. }
  236. /**
  237. * 处理客户端订阅事件(可选)
  238. *
  239. * 用于记录订阅信息或实施订阅级别的权限控制
  240. */
  241. private void handleSubscribe(StompHeaderAccessor accessor) {
  242. Authentication authentication = (Authentication) accessor.getUser();
  243. if (authentication != null && authentication.isAuthenticated()) {
  244. String destination = accessor.getDestination();
  245. String username = authentication.getName();
  246. log.debug("用户[{}]订阅主题: {}", username, destination);
  247. // TODO: 这里可以实现订阅级别的权限控制
  248. // 例如:检查用户是否有权限订阅某个主题
  249. }
  250. }
  251. }