diff --git a/api/core/types/config.go b/api/core/types/config.go index 027ea14a..8b1341c9 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -155,45 +155,6 @@ func (c RedisConfig) Url() string { return fmt.Sprintf("%s:%d", c.Host, c.Port) } -type Platform struct { - Name string `json:"name"` - Value string `json:"value"` - ChatURL string `json:"chat_url"` - ImgURL string `json:"img_url"` -} - -var OpenAI = Platform{ - Name: "OpenAI - GPT", - Value: "OpenAI", - ChatURL: "https://api.chat-plus.net/v1/chat/completions", - ImgURL: "https://api.chat-plus.net/v1/images/generations", -} -var Azure = Platform{ - Name: "微软 - Azure", - Value: "Azure", - ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15", -} -var ChatGLM = Platform{ - Name: "智谱 - ChatGLM", - Value: "ChatGLM", - ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke", -} -var Baidu = Platform{ - Name: "百度 - 文心大模型", - Value: "Baidu", - ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}", -} -var XunFei = Platform{ - Name: "讯飞 - 星火大模型", - Value: "XunFei", - ChatURL: "wss://spark-api.xf-yun.com/{version}/chat", -} -var QWen = Platform{ - Name: "阿里 - 通义千问", - Value: "QWen", - ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", -} - type SystemConfig struct { Title string `json:"title,omitempty"` // 网站标题 Slogan string `json:"slogan,omitempty"` // 网站 slogan diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index 66ee2a27..c50c8d6f 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -150,10 +150,9 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) { // GetAppConfig 获取内置配置 func (h *ConfigHandler) GetAppConfig(c *gin.Context) { resp.SUCCESS(c, gin.H{ - "mj_plus": h.App.Config.MjPlusConfigs, - "mj_proxy": h.App.Config.MjProxyConfigs, - "sd": h.App.Config.SdConfigs, - "platforms": Platforms, + "mj_plus": h.App.Config.MjPlusConfigs, + "mj_proxy": h.App.Config.MjProxyConfigs, + "sd": h.App.Config.SdConfigs, }) } diff --git a/api/handler/admin/types.go b/api/handler/admin/types.go deleted file mode 100644 index c06139ba..00000000 --- a/api/handler/admin/types.go +++ /dev/null @@ -1,12 +0,0 @@ -package admin - -import "geekai/core/types" - -var Platforms = []types.Platform{ - types.OpenAI, - types.QWen, - types.XunFei, - types.ChatGLM, - types.Baidu, - types.Azure, -} diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go deleted file mode 100644 index bd28d720..00000000 --- a/api/handler/chatimpl/azure_handler.go +++ /dev/null @@ -1,111 +0,0 @@ -package chatimpl - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "geekai/core/types" - "geekai/store/model" - "geekai/store/vo" - "geekai/utils" - "io" - "strings" - "time" -) - -// 微软 Azure 模型消息发送实现 - -func (h *ChatHandler) sendAzureMessage( - chatCtx []types.Message, - req types.ApiRequest, - userVo vo.User, - ctx context.Context, - session *types.ChatSession, - role model.ChatRole, - prompt string, - ws *types.WsClient) error { - promptCreatedAt := time.Now() // 记录提问时间 - start := time.Now() - var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session, &apiKey) - logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) - if err != nil { - if strings.Contains(err.Error(), "context canceled") { - return fmt.Errorf("用户取消了请求:%s", prompt) - } else if strings.Contains(err.Error(), "no available key") { - return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - } - return err - } else { - defer response.Body.Close() - } - - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{} - var contents = make([]string, 0) - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - line := scanner.Text() - if !strings.Contains(line, "data:") || len(line) < 30 { - continue - } - - var responseBody = types.ApiResponse{} - err = json.Unmarshal([]byte(line[6:]), &responseBody) - if err != nil { // 数据解析出错 - return errors.New(line) - } - - if len(responseBody.Choices) == 0 { - continue - } - - // 初始化 role - if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { - message.Role = responseBody.Choices[0].Delta.Role - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - continue - } else if responseBody.Choices[0].FinishReason != "" { - break // 输出完成或者输出中断了 - } else { - content := responseBody.Choices[0].Delta.Content - contents = append(contents, utils.InterfaceToString(content)) - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), - }) - } - } // end for - - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - } else { - logger.Error("信息读取出错:", err) - } - } - - // 消息发送成功 - if len(contents) > 0 { - h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) - } - - } else { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求大模型 API 失败:%s", body) - } - - return nil -} diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go deleted file mode 100644 index 783ac3e9..00000000 --- a/api/handler/chatimpl/baidu_handler.go +++ /dev/null @@ -1,185 +0,0 @@ -package chatimpl - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "geekai/core/types" - "geekai/store/model" - "geekai/store/vo" - "geekai/utils" - "io" - "net/http" - "strings" - "time" -) - -type baiduResp struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` - IsTruncated bool `json:"is_truncated"` - Result string `json:"result"` - NeedClearHistory bool `json:"need_clear_history"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// 百度文心一言消息发送实现 - -func (h *ChatHandler) sendBaiduMessage( - chatCtx []types.Message, - req types.ApiRequest, - userVo vo.User, - ctx context.Context, - session *types.ChatSession, - role model.ChatRole, - prompt string, - ws *types.WsClient) error { - promptCreatedAt := time.Now() // 记录提问时间 - start := time.Now() - var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session, &apiKey) - logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) - if err != nil { - logger.Error(err) - if strings.Contains(err.Error(), "context canceled") { - return fmt.Errorf("用户取消了请求:%s", prompt) - } else if strings.Contains(err.Error(), "no available key") { - return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - } - return err - } else { - defer response.Body.Close() - } - - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{} - var contents = make([]string, 0) - var content string - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - line := scanner.Text() - if len(line) < 5 || strings.HasPrefix(line, "id:") { - continue - } - - if strings.HasPrefix(line, "data:") { - content = line[5:] - } - - // 处理代码换行 - if len(content) == 0 { - content = "\n" - } - - var resp baiduResp - err := utils.JsonDecode(content, &resp) - if err != nil { - logger.Error("error with parse data line: ", err) - utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) - break - } - - if len(contents) == 0 { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - } - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(resp.Result), - }) - contents = append(contents, resp.Result) - - if resp.IsTruncated { - utils.ReplyMessage(ws, "AI 输出异常中断") - break - } - - if resp.IsEnd { - break - } - - } // end for - - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - } else { - logger.Error("信息读取出错:", err) - } - } - - // 消息发送成功 - if len(contents) > 0 { - h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) - } - } else { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求大模型 API 失败:%s", body) - } - - return nil -} - -func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) { - ctx := context.Background() - tokenString, err := h.redis.Get(ctx, apiKey).Result() - if err == nil { - return tokenString, nil - } - - expr := time.Hour * 24 * 20 // access_token 有效期 - key := strings.Split(apiKey, "|") - if len(key) != 2 { - return "", fmt.Errorf("invalid api key: %s", apiKey) - } - url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1]) - client := &http.Client{} - req, err := http.NewRequest("POST", url, nil) - if err != nil { - return "", err - } - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - - res, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("error with send request: %w", err) - } - defer res.Body.Close() - - body, err := io.ReadAll(res.Body) - if err != nil { - return "", fmt.Errorf("error with read response: %w", err) - } - var r map[string]interface{} - err = json.Unmarshal(body, &r) - if err != nil { - return "", fmt.Errorf("error with parse response: %w", err) - } - - if r["error"] != nil { - return "", fmt.Errorf("error with api response: %s", r["error_description"]) - } - - tokenString = fmt.Sprintf("%s", r["access_token"]) - h.redis.Set(ctx, apiKey, tokenString, expr) - return tokenString, nil -} diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 1aeb58ee..43cae8d2 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -208,21 +208,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio Model: session.Model.Value, Stream: true, } - switch session.Model.Platform { - case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value: - req.Temperature = session.Model.Temperature - req.MaxTokens = session.Model.MaxTokens - break - case types.OpenAI.Value: - req.Temperature = session.Model.Temperature - req.MaxTokens = session.Model.MaxTokens - // OpenAI 支持函数功能 - var items []model.Function - res := h.DB.Where("enabled", true).Find(&items) - if res.Error != nil { - break - } - + req.Temperature = session.Model.Temperature + req.MaxTokens = session.Model.MaxTokens + // OpenAI 支持函数功能 + var items []model.Function + res = h.DB.Where("enabled", true).Find(&items) + if res.Error == nil { var tools = make([]types.Tool, 0) for _, v := range items { var parameters map[string]interface{} @@ -248,15 +239,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio req.Tools = tools req.ToolChoice = "auto" } - case types.QWen.Value: - req.Parameters = map[string]interface{}{ - "max_tokens": session.Model.MaxTokens, - "temperature": session.Model.Temperature, - } - break - - default: - return fmt.Errorf("不支持的平台:%s", session.Model.Platform) } // 加载聊天上下文 @@ -344,65 +326,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } logger.Debug("最终Prompt:", fullPrompt) - if session.Model.Platform == types.QWen.Value { - req.Input = make(map[string]interface{}) - reqMgs = append(reqMgs, types.Message{ - Role: "user", - Content: fullPrompt, - }) - req.Input["messages"] = reqMgs - } else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model - imgURLs := utils.ExtractImgURLs(prompt) - logger.Debugf("detected IMG: %+v", imgURLs) - var content interface{} - if len(imgURLs) > 0 { - data := make([]interface{}, 0) - for _, v := range imgURLs { - text = strings.Replace(text, v, "", 1) - data = append(data, gin.H{ - "type": "image_url", - "image_url": gin.H{ - "url": v, - }, - }) - } + // extract images from prompt + imgURLs := utils.ExtractImgURLs(prompt) + logger.Debugf("detected IMG: %+v", imgURLs) + var content interface{} + if len(imgURLs) > 0 { + data := make([]interface{}, 0) + for _, v := range imgURLs { + text = strings.Replace(text, v, "", 1) data = append(data, gin.H{ - "type": "text", - "text": strings.TrimSpace(text), + "type": "image_url", + "image_url": gin.H{ + "url": v, + }, }) - content = data - } else { - content = fullPrompt } - req.Messages = append(reqMgs, map[string]interface{}{ - "role": "user", - "content": content, + data = append(data, gin.H{ + "type": "text", + "text": strings.TrimSpace(text), }) + content = data } else { - req.Messages = append(reqMgs, map[string]interface{}{ - "role": "user", - "content": fullPrompt, - }) + content = fullPrompt } + req.Messages = append(reqMgs, map[string]interface{}{ + "role": "user", + "content": content, + }) logger.Debugf("%+v", req.Messages) - switch session.Model.Platform { - case types.Azure.Value: - return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.OpenAI.Value: - return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.ChatGLM.Value: - return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.Baidu.Value: - return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.XunFei.Value: - return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.QWen.Value: - return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - } - - return nil + return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) } // Tokens 统计 token 数量 @@ -485,48 +439,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi } // ONLY allow apiURL in blank list - if session.Model.Platform == types.OpenAI.Value { - err := h.licenseService.IsValidApiURL(apiKey.ApiURL) - if err != nil { - return nil, err - } + err := h.licenseService.IsValidApiURL(apiKey.ApiURL) + if err != nil { + return nil, err } - - var apiURL string - switch session.Model.Platform { - case types.Azure.Value: - md := strings.Replace(req.Model, ".", "", 1) - apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) - break - case types.ChatGLM.Value: - apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) - req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 - req.Messages = nil - break - case types.Baidu.Value: - apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) - break - case types.QWen.Value: - apiURL = apiKey.ApiURL - req.Messages = nil - break - default: - apiURL = apiKey.ApiURL - } - // 更新 API KEY 的最后使用时间 - h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) - // 百度文心,需要串接 access_token - if session.Model.Platform == types.Baidu.Value { - token, err := h.getBaiduToken(apiKey.Value) - if err != nil { - return nil, err - } - logger.Info("百度文心 Access_Token:", token) - apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token) - } - logger.Debugf(utils.JsonEncode(req)) + apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) // 创建 HttpClient 请求对象 var client *http.Client requestBody, err := json.Marshal(req) @@ -550,28 +469,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi } else { client = http.DefaultClient } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) - switch session.Model.Platform { - case types.Azure.Value: - request.Header.Set("api-key", apiKey.Value) - break - case types.ChatGLM.Value: - token, err := h.getChatGLMToken(apiKey.Value) - if err != nil { - return nil, err - } - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - break - case types.Baidu.Value: - request.RequestURI = "" - case types.OpenAI.Value: - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) - break - case types.QWen.Value: - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) - request.Header.Set("X-DashScope-SSE", "enable") - break - } + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) return client.Do(request) } diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go deleted file mode 100644 index 0192abc8..00000000 --- a/api/handler/chatimpl/chatglm_handler.go +++ /dev/null @@ -1,142 +0,0 @@ -package chatimpl - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "bufio" - "context" - "errors" - "fmt" - "geekai/core/types" - "geekai/store/model" - "geekai/store/vo" - "geekai/utils" - "github.com/golang-jwt/jwt/v5" - "io" - "strings" - "time" -) - -// 清华大学 ChatGML 消息发送实现 - -func (h *ChatHandler) sendChatGLMMessage( - chatCtx []types.Message, - req types.ApiRequest, - userVo vo.User, - ctx context.Context, - session *types.ChatSession, - role model.ChatRole, - prompt string, - ws *types.WsClient) error { - promptCreatedAt := time.Now() // 记录提问时间 - start := time.Now() - var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session, &apiKey) - logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) - if err != nil { - if strings.Contains(err.Error(), "context canceled") { - return fmt.Errorf("用户取消了请求:%s", prompt) - } else if strings.Contains(err.Error(), "no available key") { - return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - } - return err - } else { - defer response.Body.Close() - } - - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{} - var contents = make([]string, 0) - var event, content string - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - line := scanner.Text() - if len(line) < 5 || strings.HasPrefix(line, "id:") { - continue - } - if strings.HasPrefix(line, "event:") { - event = line[6:] - continue - } - - if strings.HasPrefix(line, "data:") { - content = line[5:] - } - // 处理代码换行 - if len(content) == 0 { - content = "\n" - } - switch event { - case "add": - if len(contents) == 0 { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - } - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(content), - }) - contents = append(contents, content) - case "finish": - break - case "error": - utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content)) - break - case "interrupted": - utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**") - } - - } // end for - - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - } else { - logger.Error("信息读取出错:", err) - } - } - - // 消息发送成功 - if len(contents) > 0 { - h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) - } - } else { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求大模型 API 失败:%s", body) - } - - return nil -} - -func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) { - ctx := context.Background() - tokenString, err := h.redis.Get(ctx, apiKey).Result() - if err == nil { - return tokenString, nil - } - - expr := time.Hour * 2 - key := strings.Split(apiKey, ".") - if len(key) != 2 { - return "", fmt.Errorf("invalid api key: %s", apiKey) - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "api_key": key[0], - "timestamp": time.Now().Unix(), - "exp": time.Now().Add(expr).Add(time.Second * 10).Unix(), - }) - token.Header["alg"] = "HS256" - token.Header["sign_type"] = "SIGN" - delete(token.Header, "typ") - // Sign and get the complete encoded token as a string using the secret - tokenString, err = token.SignedString([]byte(key[1])) - h.redis.Set(ctx, apiKey, tokenString, expr) - return tokenString, err -} diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go deleted file mode 100644 index 28bf66bb..00000000 --- a/api/handler/chatimpl/qwen_handler.go +++ /dev/null @@ -1,150 +0,0 @@ -package chatimpl - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "bufio" - "context" - "fmt" - "geekai/core/types" - "geekai/store/model" - "geekai/store/vo" - "geekai/utils" - "github.com/syndtr/goleveldb/leveldb/errors" - "io" - "strings" - "time" -) - -type qWenResp struct { - Output struct { - FinishReason string `json:"finish_reason"` - Text string `json:"text"` - } `json:"output,omitempty"` - Usage struct { - TotalTokens int `json:"total_tokens"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"usage,omitempty"` - RequestID string `json:"request_id"` - - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -// 通义千问消息发送实现 -func (h *ChatHandler) sendQWenMessage( - chatCtx []types.Message, - req types.ApiRequest, - userVo vo.User, - ctx context.Context, - session *types.ChatSession, - role model.ChatRole, - prompt string, - ws *types.WsClient) error { - promptCreatedAt := time.Now() // 记录提问时间 - start := time.Now() - var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session, &apiKey) - logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) - if err != nil { - if strings.Contains(err.Error(), "context canceled") { - return fmt.Errorf("用户取消了请求:%s", prompt) - } else if strings.Contains(err.Error(), "no available key") { - return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - } - return err - } else { - defer response.Body.Close() - } - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{} - var contents = make([]string, 0) - scanner := bufio.NewScanner(response.Body) - - var content, lastText, newText string - var outPutStart = false - - for scanner.Scan() { - line := scanner.Text() - if len(line) < 5 || strings.HasPrefix(line, "id:") || - strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { - continue - } - - if !strings.HasPrefix(line, "data:") { - continue - } - - content = line[5:] - var resp qWenResp - if len(contents) == 0 { // 发送消息头 - if !outPutStart { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - outPutStart = true - continue - } else { - // 处理代码换行 - content = "\n" - } - } else { - err := utils.JsonDecode(content, &resp) - if err != nil { - logger.Error("error with parse data line: ", content) - utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) - break - } - if resp.Message != "" { - utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message)) - break - } - } - - //通过比较 lastText(上一次的文本)和 currentText(当前的文本), - //提取出新添加的文本部分。然后只将这部分新文本发送到客户端。 - //每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。 - currentText := resp.Output.Text - if currentText != lastText { - // 提取新增文本 - newText = strings.Replace(currentText, lastText, "", 1) - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(newText), - }) - lastText = currentText // 更新 lastText - } - contents = append(contents, newText) - - if resp.Output.FinishReason == "stop" { - break - } - - } //end for - - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - } else { - logger.Error("信息读取出错:", err) - } - } - - // 消息发送成功 - if len(contents) > 0 { - h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) - } - } else { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求大模型 API 失败:%s", body) - } - - return nil -} diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go deleted file mode 100644 index 7ebea8d5..00000000 --- a/api/handler/chatimpl/xunfei_handler.go +++ /dev/null @@ -1,255 +0,0 @@ -package chatimpl - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "geekai/core/types" - "geekai/store/model" - "geekai/store/vo" - "geekai/utils" - "github.com/gorilla/websocket" - "gorm.io/gorm" - "io" - "net/http" - "net/url" - "strings" - "time" -) - -type xunFeiResp struct { - Header struct { - Code int `json:"code"` - Message string `json:"message"` - Sid string `json:"sid"` - Status int `json:"status"` - } `json:"header"` - Payload struct { - Choices struct { - Status int `json:"status"` - Seq int `json:"seq"` - Text []struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` - } `json:"text"` - } `json:"choices"` - Usage struct { - Text struct { - QuestionTokens int `json:"question_tokens"` - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"text"` - } `json:"usage"` - } `json:"payload"` -} - -var Model2URL = map[string]string{ - "general": "v1.1", - "generalv2": "v2.1", - "generalv3": "v3.1", - "generalv3.5": "v3.5", -} - -// 科大讯飞消息发送实现 - -func (h *ChatHandler) sendXunFeiMessage( - chatCtx []types.Message, - req types.ApiRequest, - userVo vo.User, - ctx context.Context, - session *types.ChatSession, - role model.ChatRole, - prompt string, - ws *types.WsClient) error { - promptCreatedAt := time.Now() // 记录提问时间 - var apiKey model.ApiKey - var res *gorm.DB - // use the bind key - if session.Model.KeyId > 0 { - res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey) - } - // use the last unused key - if apiKey.Id == 0 { - res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey) - } - if res.Error != nil { - return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - } - // 更新 API KEY 的最后使用时间 - h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) - - d := websocket.Dialer{ - HandshakeTimeout: 5 * time.Second, - } - key := strings.Split(apiKey.Value, "|") - if len(key) != 3 { - utils.ReplyMessage(ws, "非法的 API KEY!") - return nil - } - - apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) - wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) - //握手并建立websocket 连接 - conn, resp, err := d.Dial(wsURL, nil) - if err != nil { - logger.Error(readResp(resp) + err.Error()) - utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) - return nil - } else if resp.StatusCode != 101 { - utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) - return nil - } - - data := buildRequest(key[0], req) - fmt.Printf("%+v", data) - fmt.Println(apiURL) - err = conn.WriteJSON(data) - if err != nil { - utils.ReplyMessage(ws, "发送消息失败:"+err.Error()) - return nil - } - - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{} - var contents = make([]string, 0) - var content string - for { - _, msg, err := conn.ReadMessage() - if err != nil { - logger.Error("error with read message:", err) - utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err)) - break - } - - // 解析数据 - var result xunFeiResp - err = json.Unmarshal(msg, &result) - if err != nil { - logger.Error("error with parsing JSON:", err) - utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) - return nil - } - - if result.Header.Code != 0 { - utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message)) - return nil - } - - content = result.Payload.Choices.Text[0].Content - // 处理代码换行 - if len(content) == 0 { - content = "\n" - } - contents = append(contents, content) - // 第一个结果 - if result.Payload.Choices.Status == 0 { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - } - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(content), - }) - - if result.Payload.Choices.Status == 2 { // 最终结果 - _ = conn.Close() // 关闭连接 - break - } - - select { - case <-ctx.Done(): - utils.ReplyMessage(ws, "**用户取消了生成指令!**") - return nil - default: - continue - } - - } - // 消息发送成功 - if len(contents) > 0 { - h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) - } - return nil -} - -// 构建 websocket 请求实体 -func buildRequest(appid string, req types.ApiRequest) map[string]interface{} { - return map[string]interface{}{ - "header": map[string]interface{}{ - "app_id": appid, - }, - "parameter": map[string]interface{}{ - "chat": map[string]interface{}{ - "domain": req.Model, - "temperature": req.Temperature, - "top_k": int64(6), - "max_tokens": int64(req.MaxTokens), - "auditing": "default", - }, - }, - "payload": map[string]interface{}{ - "message": map[string]interface{}{ - "text": req.Messages, - }, - }, - } -} - -// 创建鉴权 URL -func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) { - ul, err := url.Parse(hostURL) - if err != nil { - return "", err - } - - date := time.Now().UTC().Format(time.RFC1123) - signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} - //拼接签名字符串 - signStr := strings.Join(signString, "\n") - sha := hmacWithSha256(signStr, apiSecret) - - authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, - "hmac-sha256", "host date request-line", sha) - //将请求参数使用base64编码 - authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) - v := url.Values{} - v.Add("host", ul.Host) - v.Add("date", date) - v.Add("authorization", authorization) - //将编码后的字符串url encode后添加到url后面 - return hostURL + "?" + v.Encode(), nil -} - -// 使用 sha256 签名 -func hmacWithSha256(data, key string) string { - mac := hmac.New(sha256.New, []byte(key)) - mac.Write([]byte(data)) - encodeData := mac.Sum(nil) - return base64.StdEncoding.EncodeToString(encodeData) -} - -// 读取响应 -func readResp(resp *http.Response) string { - if resp == nil { - return "" - } - b, err := io.ReadAll(resp.Body) - if err != nil { - panic(err) - } - return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) -} diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index 368d12f7..e6926f07 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -212,21 +212,21 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode } func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { + + session := h.DB.Session(&gorm.Session{}) // if the chat model bind a KEY, use it directly - var res *gorm.DB if chatModel.KeyId > 0 { - res = h.DB.Where("id", chatModel.KeyId).Find(apiKey) - } - // use the last unused key - if apiKey.Id == 0 { - res = h.DB.Where("platform", types.OpenAI.Value). - Where("type", "chat"). - Where("enabled", true).Order("last_used_at ASC").First(apiKey) + session = session.Where("id", chatModel.KeyId) + } else { // use the last unused key + session = session.Where("type", "chat"). + Where("enabled", true).Order("last_used_at ASC") } + + res := session.First(apiKey) if res.Error != nil { return nil, errors.New("no available key, please import key") } - apiURL := apiKey.ApiURL + apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) // 更新 API KEY 的最后使用时间 h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 9a915da5..4225182f 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -145,7 +145,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // get image generation API KEY var apiKey model.ApiKey - tx = s.db.Where("type", "img"). + tx = s.db.Where("type", "dalle"). Where("enabled", true). Order("last_used_at ASC").First(&apiKey) if tx.Error != nil { @@ -157,6 +157,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { if len(apiKey.ProxyURL) > 5 { s.httpClient.SetProxyURL(apiKey.ProxyURL).R() } + apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL) reqBody := imgReq{ Model: "dall-e-3", Prompt: prompt, @@ -165,14 +166,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { Style: task.Style, Quality: task.Quality, } - logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody) - request := s.httpClient.R().SetHeader("Content-Type", "application/json") - if apiKey.Platform == types.Azure.Value { - request = request.SetHeader("api-key", apiKey.Value) - } else { - request = request.SetHeader("Authorization", "Bearer "+apiKey.Value) - } - r, err := request.SetBody(reqBody).SetErrorResult(&errRes).SetSuccessResult(&res).Post(apiKey.ApiURL) + logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiURL, apiKey.Value, reqBody) + r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(reqBody). + SetErrorResult(&errRes). + SetSuccessResult(&res). + Post(apiURL) if err != nil { return "", fmt.Errorf("error with send request: %v", err) } diff --git a/api/utils/openai.go b/api/utils/openai.go index 9ee01a35..5a3a83c6 100644 --- a/api/utils/openai.go +++ b/api/utils/openai.go @@ -54,7 +54,7 @@ type apiErrRes struct { func OpenAIRequest(db *gorm.DB, prompt string) (string, error) { var apiKey model.ApiKey - res := db.Where("platform", types.OpenAI.Value).Where("type", "chat").Where("enabled", true).First(&apiKey) + res := db.Where("type", "chat").Where("enabled", true).First(&apiKey) if res.Error != nil { return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) } diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index 08bdb08c..5d7cd6ad 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -2,19 +2,11 @@
注意:如果是百度文心一言平台,API-KEY 为 APIKey|SecretKey,中间用竖线(|)连接
-注意:如果是讯飞星火大模型,API-KEY 为 AppId|APIKey|APISecret,中间用竖线(|)连接
-