diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 95f46c05..88f6233d 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -67,4 +67,8 @@ var ModelToTokens = map[string]int{ "gpt-3.5-turbo-16k": 16384, "gpt-4": 8192, "gpt-4-32k": 32768, + "chatglm_pro": 32768, + "chatglm_std": 16384, + "chatglm_lite": 4096, + "ernie_bot_turbo": 8192, // 文心一言 } diff --git a/api/core/types/config.go b/api/core/types/config.go index 9e4d957a..32adbe52 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -79,6 +79,7 @@ type ChatConfig struct { OpenAI ModelAPIConfig `json:"open_ai"` Azure ModelAPIConfig `json:"azure"` ChatGML ModelAPIConfig `json:"chat_gml"` + Baidu ModelAPIConfig `json:"baidu"` EnableContext bool `json:"enable_context"` // 是否开启聊天上下文 EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录 diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go index b4614f6c..50e24cdb 100644 --- a/api/handler/admin/api_key_handler.go +++ b/api/handler/admin/api_key_handler.go @@ -36,11 +36,11 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { apiKey := model.ApiKey{} if data.Id > 0 { - h.db.Find(&apiKey) + h.db.Find(&apiKey, data.Id) } apiKey.Platform = data.Platform apiKey.Value = data.Value - res := h.db.Save(&apiKey) + res := h.db.Debug().Save(&apiKey) if res.Error != nil { resp.ERROR(c, "更新数据库失败!") return diff --git a/api/handler/baidu_handler.go b/api/handler/baidu_handler.go index ba376745..b802951f 100644 --- a/api/handler/baidu_handler.go +++ b/api/handler/baidu_handler.go @@ -9,14 +9,30 @@ import ( "context" "encoding/json" "fmt" - "github.com/golang-jwt/jwt/v5" "gorm.io/gorm" "io" + "net/http" "strings" "time" "unicode/utf8" ) +type baiduResp struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` + IsTruncated bool `json:"is_truncated"` + Result string `json:"result"` + NeedClearHistory bool `json:"need_clear_history"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + // 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端 func (h *ChatHandler) sendBaiduMessage( chatCtx []interface{}, @@ -56,38 +72,42 @@ func (h *ChatHandler) sendBaiduMessage( // 循环读取 Chunk 消息 var message = types.Message{} var contents = make([]string, 0) - var event, content string + var content string scanner := bufio.NewScanner(response.Body) for scanner.Scan() { line := scanner.Text() if len(line) < 5 || strings.HasPrefix(line, "id:") { continue } - if strings.HasPrefix(line, "event:") { - event = line[6:] - continue - } if strings.HasPrefix(line, "data:") { content = line[5:] } - switch event { - case "add": - if len(contents) == 0 { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - } - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(content), - }) - contents = append(contents, content) - case "finish": + + var resp baiduResp + err := utils.JsonDecode(content, &resp) + if err != nil { + logger.Error("error with parse data line: ", err) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) break - case "error": - utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content)) + } + + if len(contents) == 0 { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + } + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(resp.Result), + }) + contents = append(contents, resp.Result) + + if resp.IsTruncated { + utils.ReplyMessage(ws, "AI 输出异常中断") + break + } + + if resp.IsEnd { break - case "interrupted": - utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**") } } // end for @@ -192,17 +212,14 @@ func (h *ChatHandler) sendBaiduMessage( } var res struct { - Code int `json:"code"` - Success bool `json:"success"` - Msg string `json:"msg"` + Code int `json:"error_code"` + Msg string `json:"error_msg"` } err = json.Unmarshal(body, &res) if err != nil { return fmt.Errorf("error with decode response: %v", err) } - if !res.Success { - utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg) - } + utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg) } return nil @@ -215,21 +232,41 @@ func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) { return tokenString, nil } - expr := time.Hour * 2 - key := strings.Split(apiKey, ".") + expr := time.Hour * 24 * 20 // access_token 有效期 + key := strings.Split(apiKey, "|") if len(key) != 2 { return "", fmt.Errorf("invalid api key: %s", apiKey) } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "api_key": key[0], - "timestamp": time.Now().Unix(), - "exp": time.Now().Add(expr).Add(time.Second * 10).Unix(), - }) - token.Header["alg"] = "HS256" - token.Header["sign_type"] = "SIGN" - delete(token.Header, "typ") - // Sign and get the complete encoded token as a string using the secret - tokenString, err = token.SignedString([]byte(key[1])) + url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1]) + client := &http.Client{} + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return "", err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + res, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error with send request: %w", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return "", fmt.Errorf("error with read response: %w", err) + } + var r map[string]interface{} + err = json.Unmarshal(body, &r) + if err != nil { + return "", fmt.Errorf("error with parse response: %w", err) + } + + if r["error"] != nil { + return "", fmt.Errorf("error with api response: %s", r["error_description"]) + } + + tokenString = fmt.Sprintf("%s", r["access_token"]) h.redis.Set(ctx, apiKey, tokenString, expr) - return tokenString, err + return tokenString, nil } diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index c29cfe67..0a7fafac 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -202,6 +202,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio case types.OpenAI: req.Temperature = h.App.ChatConfig.OpenAI.Temperature req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens + // OpenAI 支持函数功能 var functions = make([]types.Function, 0) for _, f := range types.InnerFunctions { if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney { @@ -281,6 +282,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) case types.ChatGLM: return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.Baidu: + return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + } utils.ReplyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, @@ -364,12 +368,36 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf break case types.ChatGLM: apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1) - req.Prompt = req.Messages + req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 req.Messages = nil break + case types.Baidu: + apiURL = h.App.ChatConfig.Baidu.ApiURL + break default: apiURL = h.App.ChatConfig.OpenAI.ApiURL } + if *apiKey == "" { + var key model.ApiKey + res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key) + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } + // 更新 API KEY 的最后使用时间 + h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) + *apiKey = key.Value + } + + // 百度文心,需要串接 access_token + if platform == types.Baidu { + token, err := h.getBaiduToken(*apiKey) + if err != nil { + return nil, err + } + logger.Info("百度文心 Access_Token:", token) + apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token) + } + // 创建 HttpClient 请求对象 var client *http.Client requestBody, err := json.Marshal(req) @@ -394,17 +422,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } else { client = http.DefaultClient } - if *apiKey == "" { - var key model.ApiKey - res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key) - if res.Error != nil { - return nil, errors.New("no available key, please import key") - } - // 更新 API KEY 的最后使用时间 - h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) - *apiKey = key.Value - } - logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model) switch platform { case types.Azure: @@ -418,7 +435,9 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf logger.Info(token) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) break - default: + case types.Baidu: + request.RequestURI = "" + case types.OpenAI: request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) } return client.Do(request) diff --git a/api/test/test.go b/api/test/test.go index 62c0e4de..c0e85921 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,14 +1,55 @@ package main import ( + "encoding/json" "fmt" - "os" + "io" + "log" + "net/http" ) func main() { - bytes, err := os.ReadFile("res/text2img.json") + apiKey := "qjvqGdqpTY7qQaGBMenM7XgQ" + apiSecret := "3G1RzBGXywZv4VbYRTyAfNns1vIOAG8t" + token, err := getBaiduToken(apiKey, apiSecret) if err != nil { - panic(err) + log.Fatal(err) } - fmt.Println(string(bytes)) + + fmt.Println(token) + +} + +func getBaiduToken(apiKey string, apiSecret string) (string, error) { + + url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", apiKey, apiSecret) + client := &http.Client{} + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return "", err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + res, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error with send request: %w", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return "", fmt.Errorf("error with read response: %w", err) + } + var r map[string]interface{} + err = json.Unmarshal(body, &r) + if err != nil { + return "", fmt.Errorf("error with parse response: %w", err) + } + + if r["error"] != nil { + return "", fmt.Errorf("error with api response: %s", r["error_description"]) + } + + return fmt.Sprintf("%s", r["access_token"]), nil } diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index 1ad7cd62..65e9c94f 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -91,7 +91,6 @@ const platforms = ref([ {name: "【百度】文心一言", value: "Baidu"}, {name: "【微软】Azure", value: "Azure"}, {name: "【OpenAI】ChatGPT", value: "OpenAI"}, - ]) // 获取数据 diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue index 3c0540b4..2152858f 100644 --- a/web/src/views/admin/ChatModel.vue +++ b/web/src/views/admin/ChatModel.vue @@ -47,7 +47,7 @@ - {{ item }} + {{ item.name }} @@ -94,7 +94,12 @@ const rules = reactive({ }) const loading = ref(true) const formRef = ref(null) -const platforms = ref(["Azure", "OpenAI", "ChatGLM"]) +const platforms = ref([ + {name: "【清华智普】ChatGLM", value: "ChatGLM"}, + {name: "【百度】文心一言", value: "Baidu"}, + {name: "【微软】Azure", value: "Azure"}, + {name: "【OpenAI】ChatGPT", value: "OpenAI"}, +]) // 获取数据 httpGet('/api/admin/model/list').then((res) => { diff --git a/web/src/views/admin/SysConfig.vue b/web/src/views/admin/SysConfig.vue index 663153ad..c4a9cda7 100644 --- a/web/src/views/admin/SysConfig.vue +++ b/web/src/views/admin/SysConfig.vue @@ -158,6 +158,9 @@ onMounted(() => { if (res.data.chat_gml) { chat.value.chat_gml = res.data.chat_gml } + if (res.data.baidu) { + chat.value.baidu = res.data.baidu + } chat.value.context_deep = res.data.context_deep chat.value.enable_context = res.data.enable_context chat.value.enable_history = res.data.enable_history