diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go index c5efe78d..f30991a5 100644 --- a/api/handler/admin/chat_model_handler.go +++ b/api/handler/admin/chat_model_handler.go @@ -15,6 +15,7 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index 303632a0..4987e58e 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -16,6 +16,7 @@ import ( "geekai/store/model" "geekai/utils" "geekai/utils/resp" + "github.com/gin-gonic/gin" "github.com/shirou/gopsutil/host" "gorm.io/gorm" @@ -128,3 +129,12 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) { license := h.licenseService.GetLicense() resp.SUCCESS(c, license) } + +// GetDrawingConfig 获取AI绘画配置 +func (h *ConfigHandler) GetDrawingConfig(c *gin.Context) { + resp.SUCCESS(c, gin.H{ + "mj_plus": h.App.Config.MjPlusConfigs, + "mj_proxy": h.App.Config.MjProxyConfigs, + "sd": h.App.Config.SdConfigs, + }) +} diff --git a/api/handler/chat_model_handler.go b/api/handler/chat_model_handler.go index baf3e471..555de7c8 100644 --- a/api/handler/chat_model_handler.go +++ b/api/handler/chat_model_handler.go @@ -13,6 +13,7 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -32,7 +33,7 @@ func (h *ChatModelHandler) List(c *gin.Context) { var res *gorm.DB // 如果用户没有登录,则加载所有开放模型 if !h.IsLogin(c) { - res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items) + res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items) } else { user, _ := h.GetLoginUser(c) var models []int @@ -43,7 +44,7 @@ func (h *ChatModelHandler) List(c *gin.Context) { } // 查询用户有权限访问的模型以及所有开放的模型 res = h.DB.Where("enabled = ?", true).Where( - h.DB.Where("id IN ?", models).Or("open =?", true), + h.DB.Where("id IN ?", models).Or("open", true), ).Order("sort_num ASC").Find(&items) } diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index e74f5036..bf67ab7b 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -148,7 +148,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode contentType := response.Header.Get("Content-Type") if strings.Contains(contentType, "text/event-stream") { // 循环读取 Chunk 消息 - var message = types.Message{} scanner := bufio.NewScanner(response.Body) var isNew = true for scanner.Scan() { @@ -159,26 +158,26 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode var responseBody = types.ApiResponse{} err = json.Unmarshal([]byte(line[6:]), &responseBody) - if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 - return fmt.Errorf("error with decode data: %v", err) + if err != nil { // 数据解析出错 + return fmt.Errorf("error with decode data: %v", line) } - // 初始化 role - if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { - message.Role = responseBody.Choices[0].Delta.Role + if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 continue - } else if responseBody.Choices[0].FinishReason != "" { - break // 输出完成或者输出中断了 - } else { - if isNew { - utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) - isNew = false - } - utils.ReplyChunkMessage(client, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), - }) } + + if responseBody.Choices[0].FinishReason == "stop" { + break + } + + if isNew { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) + isNew = false + } + utils.ReplyChunkMessage(client, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) } // end for utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) diff --git a/api/main.go b/api/main.go index ef46926f..a19ed49e 100644 --- a/api/main.go +++ b/api/main.go @@ -302,6 +302,7 @@ func main() { group.GET("config/get", h.Get) group.POST("active", h.Active) group.GET("config/get/license", h.GetLicense) + group.GET("config/get/draw", h.GetDrawingConfig) }), fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { group := s.Engine.Group("/api/admin/") diff --git a/web/src/views/MarkMap.vue b/web/src/views/MarkMap.vue index 9faf91b4..37382a0c 100644 --- a/web/src/views/MarkMap.vue +++ b/web/src/views/MarkMap.vue @@ -81,7 +81,7 @@