update 优化 将框架内的 sse ws 统一走认证处理器 不再自行排除编码处理

update 适配 snail-ai 的 spring-ai 版本过低问题 先降级spring-ai版本到m4等后续适配
This commit is contained in:
疯狂的狮子Li
2026-05-26 19:07:15 +08:00
parent 1e1b33764d
commit b5da5f30c9
10 changed files with 143 additions and 57 deletions
+1 -1
View File
@@ -56,7 +56,7 @@
<easy-es.version>3.0.2</easy-es.version>
<elasticsearch-client.version>7.17.28</elasticsearch-client.version>
<!-- Spring AI 2.0 预览版,正式版发布后仅需调整此版本号 -->
<spring-ai.version>2.0.0-M6</spring-ai.version>
<spring-ai.version>2.0.0-M4</spring-ai.version>
<!-- 插件版本 -->
<maven-jar-plugin.version>3.5.0</maven-jar-plugin.version>
@@ -0,0 +1,28 @@
package org.dromara.common.push.annotation;
import org.dromara.common.push.condition.MessageTransportCondition;
import org.springframework.context.annotation.Conditional;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 按消息推送传输方式启用组件。
*
* @author Lion Li
*/
@Documented
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Conditional(MessageTransportCondition.class)
public @interface ConditionalOnMessageTransport {
/**
* 传输方式:sse / websocket。
*/
String value();
}
@@ -0,0 +1,39 @@
package org.dromara.common.push.condition;
import org.dromara.common.push.annotation.ConditionalOnMessageTransport;
import org.dromara.common.push.enums.MessageTransportEnum;
import org.jspecify.annotations.NonNull;
import org.springframework.context.annotation.Condition;
import org.springframework.context.annotation.ConditionContext;
import org.springframework.core.type.AnnotatedTypeMetadata;
import java.util.Map;
/**
* 消息推送传输方式条件判断。
*
* @author Lion Li
*/
public class MessageTransportCondition implements Condition {
/**
* 判断当前消息推送配置是否匹配注解声明的传输方式。
*
* @param context 条件上下文
* @param metadata 注解元数据
* @return 是否匹配
*/
@Override
public boolean matches(@NonNull ConditionContext context, AnnotatedTypeMetadata metadata) {
Map<String, Object> attributes = metadata.getAnnotationAttributes(ConditionalOnMessageTransport.class.getName());
if (attributes == null) {
return true;
}
Boolean enabled = context.getEnvironment().getProperty("message.enabled", Boolean.class, true);
String transport = context.getEnvironment().getProperty("message.transport", MessageTransportEnum.SSE.getCode());
String expected = (String) attributes.get("value");
return enabled && expected.equalsIgnoreCase(transport);
}
}
@@ -1,11 +1,11 @@
package org.dromara.common.push.config;
import org.dromara.common.push.annotation.ConditionalOnMessageTransport;
import org.dromara.common.push.controller.SseController;
import org.dromara.common.push.core.SseEmitterSessionManager;
import org.dromara.common.push.listener.MessageTopicListener;
import org.dromara.common.push.properties.MessageProperties;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.context.annotation.Bean;
import java.util.concurrent.ScheduledExecutorService;
@@ -16,7 +16,7 @@ import java.util.concurrent.ScheduledExecutorService;
* @author Lion Li
*/
@AutoConfiguration(after = MessageAutoConfiguration.class)
@ConditionalOnExpression("'${message.enabled:true}'.equalsIgnoreCase('true') && '${message.transport:sse}'.equalsIgnoreCase('sse')")
@ConditionalOnMessageTransport("sse")
public class MessageSseConfiguration {
/**
@@ -1,12 +1,12 @@
package org.dromara.common.push.config;
import org.dromara.common.push.annotation.ConditionalOnMessageTransport;
import org.dromara.common.push.listener.MessageTopicListener;
import org.dromara.common.push.core.WebSocketSessionManager;
import org.dromara.common.push.handler.PlusWebSocketHandler;
import org.dromara.common.push.interceptor.PlusWebSocketInterceptor;
import org.dromara.common.push.properties.MessageProperties;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.context.annotation.Bean;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
@@ -22,7 +22,7 @@ import java.util.concurrent.ScheduledExecutorService;
*/
@EnableWebSocket
@AutoConfiguration(after = MessageAutoConfiguration.class)
@ConditionalOnExpression("'${message.enabled:true}'.equalsIgnoreCase('true') && '${message.transport:sse}'.equalsIgnoreCase('websocket')")
@ConditionalOnMessageTransport("websocket")
public class MessageWebSocketConfiguration {
/**
@@ -2,8 +2,10 @@ package org.dromara.common.push.controller;
import cn.dev33.satoken.annotation.SaIgnore;
import cn.dev33.satoken.stp.StpUtil;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import org.dromara.common.core.domain.R;
import org.dromara.common.push.annotation.ConditionalOnMessageTransport;
import org.dromara.common.push.core.SseEmitterSessionManager;
import org.dromara.common.satoken.utils.LoginHelper;
import org.springframework.beans.factory.DisposableBean;
@@ -18,6 +20,7 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
* @author Lion Li
*/
@RestController
@ConditionalOnMessageTransport("sse")
@RequiredArgsConstructor
public class SseController implements DisposableBean {
@@ -29,10 +32,8 @@ public class SseController implements DisposableBean {
* @return SSE 发射器
*/
@GetMapping(value = "${message.path:/resource/message}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter connect() {
if (!StpUtil.isLogin()) {
return null;
}
public SseEmitter connect(HttpServletResponse response) {
prepareSseResponse(response);
String tokenValue = StpUtil.getTokenValue();
Long userId = LoginHelper.getUserId();
return sessionManager.connect(userId, tokenValue);
@@ -52,6 +53,18 @@ public class SseController implements DisposableBean {
return R.ok();
}
/**
* 设置 SSE 响应头,覆盖统一鉴权成功路径中的默认 JSON 响应类型。
*
* @param response 当前响应
*/
private void prepareSseResponse(HttpServletResponse response) {
response.setContentType(MediaType.TEXT_EVENT_STREAM_VALUE);
response.setCharacterEncoding("UTF-8");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("X-Accel-Buffering", "no");
}
// 以下为demo仅供参考 禁止使用 请在业务逻辑中使用工具发送而不是用接口发送
// /**
// * 向特定用户发送消息
@@ -1,9 +1,6 @@
package org.dromara.common.push.interceptor;
import cn.dev33.satoken.exception.NotLoginException;
import cn.dev33.satoken.stp.StpUtil;
import lombok.extern.slf4j.Slf4j;
import org.dromara.common.core.utils.StringUtils;
import org.dromara.common.push.constant.MessageConstants;
import org.dromara.common.satoken.utils.LoginHelper;
import org.dromara.system.api.model.LoginUser;
@@ -11,7 +8,6 @@ import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;
import java.util.Map;
@@ -20,12 +16,10 @@ import java.util.Map;
*
* @author Lion Li
*/
@Slf4j
public class PlusWebSocketInterceptor implements HandshakeInterceptor {
/**
* 握手前拦截(核心认证逻辑)
* 校验登录状态、Token、客户端ID,认证通过才允许建立 WebSocket 连接
* 握手前提取统一鉴权后的用户信息。
*
* @param attributes 用于传递到 WebSocketSession 的属性集合
* @return 是否允许握手(true=允许,false=拒绝)
@@ -33,40 +27,11 @@ public class PlusWebSocketInterceptor implements HandshakeInterceptor {
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
try {
// 1. 获取当前登录用户与 Token
LoginUser loginUser = LoginHelper.getLoginUser();
String tokenValue = StpUtil.getTokenValue();
// 2. 未登录直接拒绝握手
if (loginUser == null || StringUtils.isBlank(tokenValue)) {
return false;
}
// 3. 校验客户端ID(防止多端冒用)
String headerCid = request.getHeaders().getFirst(LoginHelper.CLIENT_KEY);
String paramCid = UriComponentsBuilder.fromUri(request.getURI())
.build()
.getQueryParams()
.getFirst(LoginHelper.CLIENT_KEY);
Object clientExtra = StpUtil.getExtra(LoginHelper.CLIENT_KEY);
// 客户端ID必须与请求头/参数中的一致,否则拒绝连接
if (clientExtra == null || !StringUtils.equalsAny(clientExtra.toString(), headerCid, paramCid)) {
throw NotLoginException.newInstance(StpUtil.getLoginType(),
"-100", "客户端ID与Token不匹配",
StpUtil.getTokenValue());
}
// 4. 认证通过,将用户信息存入会话属性,供后续 WebSocketHandler 使用
attributes.put(MessageConstants.LOGIN_USER_KEY, loginUser);
attributes.put(MessageConstants.LOGIN_TOKEN_KEY, tokenValue);
return true;
} catch (NotLoginException e) {
// 认证失败,记录日志并拒绝连接
log.error("WebSocket 认证失败'{}',无法访问系统资源", e.getMessage());
return false;
}
LoginUser loginUser = LoginHelper.getLoginUser();
String tokenValue = StpUtil.getTokenValue();
attributes.put(MessageConstants.LOGIN_USER_KEY, loginUser);
attributes.put(MessageConstants.LOGIN_TOKEN_KEY, tokenValue);
return true;
}
/**
@@ -3,12 +3,14 @@ package org.dromara.common.security.config;
import cn.dev33.satoken.exception.NotLoginException;
import cn.dev33.satoken.exception.NotPermissionException;
import cn.dev33.satoken.filter.SaServletFilter;
import cn.dev33.satoken.filter.SaTokenContextFilterForJakartaServlet;
import cn.dev33.satoken.httpauth.basic.SaHttpBasicUtil;
import cn.dev33.satoken.interceptor.SaInterceptor;
import cn.dev33.satoken.router.SaRouter;
import cn.dev33.satoken.stp.StpUtil;
import cn.dev33.satoken.util.SaResult;
import cn.dev33.satoken.util.SaTokenConsts;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
@@ -21,13 +23,15 @@ import org.dromara.common.core.utils.StringUtils;
import org.dromara.common.satoken.utils.LoginHelper;
import org.dromara.common.security.config.properties.SecurityProperties;
import org.dromara.common.security.handler.AllUrlHandler;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.core.Ordered;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.util.EnumSet;
import java.util.List;
/**
@@ -45,8 +49,28 @@ public class SecurityConfig implements WebMvcConfigurer {
private static final String CLIENT_RULE_SEPARATOR_REGEX = "[,;\\r\\n]+";
private final SecurityProperties securityProperties;
@Value("${message.path:/resource/message}")
private String messagePath;
/**
* 重新注册 Sa-Token 上下文过滤器,使其覆盖 Servlet 异步分发。
* <p>
* SSE、WebSocket 握手等场景可能触发 ASYNC/ERROR dispatcher,如果上下文过滤器只处理普通 REQUEST,
* 后续统一鉴权或业务代码读取 SaHolder/StpUtil 时会出现 SaTokenContext 未初始化。
*
* @param filter Sa-Token 官方上下文过滤器
* @return 过滤器注册配置
*/
@Bean
public FilterRegistrationBean<SaTokenContextFilterForJakartaServlet> saTokenContextFilterRegistration(
SaTokenContextFilterForJakartaServlet filter) {
FilterRegistrationBean<SaTokenContextFilterForJakartaServlet> registration = new FilterRegistrationBean<>();
registration.setFilter(filter);
registration.setName("saTokenContextFilterForServlet");
registration.addUrlPatterns("/*");
registration.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST, DispatcherType.ASYNC, DispatcherType.ERROR));
registration.setAsyncSupported(true);
registration.setOrder(Ordered.HIGHEST_PRECEDENCE);
return registration;
}
/**
* 注册 Sa-Token 路由拦截器并配置鉴权规则。
@@ -91,8 +115,7 @@ public class SecurityConfig implements WebMvcConfigurer {
});
})).addPathPatterns("/**")
// 排除不需要拦截的路径
.excludePathPatterns(securityProperties.getExcludes())
.excludePathPatterns(messagePath);
.excludePathPatterns(securityProperties.getExcludes());
}
/**
@@ -19,6 +19,7 @@ import com.aizuda.snail.ai.openapi.client.core.api.OpenApiChatClient;
import com.aizuda.snail.ai.openapi.client.core.api.OpenApiConversationClient;
import com.aizuda.snail.ai.openapi.client.core.api.OpenApiUserClient;
import com.aizuda.snail.ai.openapi.client.core.listener.SseEventListener;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
@@ -190,7 +191,9 @@ public class SnailAiController extends BaseController {
@PostMapping(value = "/agent/{agentId}/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter chatStream(
@NotNull(message = "智能体ID不能为空") @PathVariable Long agentId,
@RequestBody OpenApiChatRequest request) {
@RequestBody OpenApiChatRequest request,
HttpServletResponse response) {
prepareSseResponse(response);
SseEmitter emitter = new SseEmitter(SSE_TIMEOUT);
AtomicBoolean closed = new AtomicBoolean(false);
emitter.onTimeout(() -> {
@@ -303,6 +306,18 @@ public class SnailAiController extends BaseController {
}
}
/**
* 设置 SSE 响应头,覆盖统一鉴权成功路径中的默认 JSON 响应类型。
*
* @param response 当前响应
*/
private void prepareSseResponse(HttpServletResponse response) {
response.setContentType(MediaType.TEXT_EVENT_STREAM_VALUE);
response.setCharacterEncoding("UTF-8");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("X-Accel-Buffering", "no");
}
/**
* 获取当前登录用户对应的 openId,不存在时会自动注册。
*/
+4 -1
View File
@@ -190,7 +190,10 @@ INSERT IGNORE INTO snail_ai_model_provider (provider_name, provider_key, descrip
VALUES ('OpenAI', 'openai', 'OpenAI官方模型 (GPT-4, GPT-3.5等)', 1),
('Claude', 'claude', 'Anthropic Claude模型', 1),
('Ollama', 'ollama', '本地开源模型 (Llama, Mistral等)', 1),
('Google Gemini', 'gemini', 'Google Gemini模型', 1);
('Google Gemini', 'gemini', 'Google Gemini模型', 1),
('阿里云百炼', 'qwen', '阿里云百炼 OpenAI 兼容模型 (Qwen等)', 1),
('DeepSeek', 'deepseek', 'DeepSeek OpenAI 兼容模型', 1),
('智谱AI', 'zhipu', '智谱AI OpenAI 兼容模型 (GLM等)', 1);
-- ============================================
-- 智能体相关表