diff --git a/README.md b/README.md index 8b85243d..d421442c 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 绘画函数插件。 ## 最新版本一键部署脚本 - +目前仅支持 Ubuntu 和 Centos 系统。 ```shell bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.3-8b588904ef.sh)" ``` diff --git a/api/core/types/chat.go b/api/core/types/chat.go index bfc9ecec..255ddcba 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -62,7 +62,6 @@ type ApiError struct { const PromptMsg = "prompt" // prompt message const ReplyMsg = "reply" // reply message -const MjMsg = "mj" var ModelToTokens = map[string]int{ "gpt-3.5-turbo": 4096, @@ -75,4 +74,12 @@ var ModelToTokens = map[string]int{ "ernie_bot_turbo": 8192, // 文心一言 "general": 8192, // 科大讯飞 "general2": 8192, + "general3": 8192, +} + +func GetModelMaxToken(model string) int { + if token, ok := ModelToTokens[model]; ok { + return token + } + return 4096 } diff --git a/api/core/types/config.go b/api/core/types/config.go index 2872e73a..3bea5a71 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -141,7 +141,6 @@ type InviteReward struct { } type ModelAPIConfig struct { - ApiURL string `json:"api_url,omitempty"` Temperature float32 `json:"temperature"` MaxTokens int `json:"max_tokens"` ApiKey string `json:"api_key"` diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go index a1a7c328..f4f75e6b 100644 --- a/api/handler/chatimpl/azure_handler.go +++ b/api/handler/chatimpl/azure_handler.go @@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] + var apiKey = model.ApiKey{} response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go index bc92b8a2..62fcbf1a 100644 --- a/api/handler/chatimpl/baidu_handler.go +++ b/api/handler/chatimpl/baidu_handler.go @@ -46,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] + var apiKey = model.ApiKey{} response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 6be5dc49..95cb50af 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -275,7 +275,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio if err == nil { for _, v := range messages { tks, _ := utils.CalcTokens(v.Content, req.Model) - if tokens+tks >= types.ModelToTokens[req.Model] { + if tokens+tks >= types.GetModelMaxToken(req.Model) { break } tokens += tks @@ -290,7 +290,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio if res.Error == nil { for i := len(historyMessages) - 1; i >= 0; i-- { msg := historyMessages[i] - if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] { + if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) { break } tokens += msg.Tokens @@ -401,39 +401,33 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) { // 发送请求到 OpenAI 服务器 // useOwnApiKey: 是否使用了用户自己的 API KEY -func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) { - +func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) { + res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } var apiURL string switch platform { case types.Azure: md := strings.Replace(req.Model, ".", "", 1) - apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1) + apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) break case types.ChatGLM: - apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1) + apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 req.Messages = nil break case types.Baidu: - apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1) + apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) break default: - apiURL = h.App.ChatConfig.OpenAI.ApiURL + apiURL = apiKey.ApiURL } - if *apiKey == "" { - var key model.ApiKey - res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key) - if res.Error != nil { - return nil, errors.New("no available key, please import key") - } - // 更新 API KEY 的最后使用时间 - h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) - *apiKey = key.Value - } - + // 更新 API KEY 的最后使用时间 + h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // 百度文心,需要串接 access_token if platform == types.Baidu { - token, err := h.getBaiduToken(*apiKey) + token, err := h.getBaiduToken(apiKey.Value) if err != nil { return nil, err } @@ -465,13 +459,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } else { client = http.DefaultClient } - logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model) + logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model) switch platform { case types.Azure: - request.Header.Set("api-key", *apiKey) + request.Header.Set("api-key", apiKey.Value) break case types.ChatGLM: - token, err := h.getChatGLMToken(*apiKey) + token, err := h.getChatGLMToken(apiKey.Value) if err != nil { return nil, err } @@ -480,7 +474,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf case types.Baidu: request.RequestURI = "" case types.OpenAI: - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) } return client.Do(request) } diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go index 0a9f74df..5fbacad7 100644 --- a/api/handler/chatimpl/chatglm_handler.go +++ b/api/handler/chatimpl/chatglm_handler.go @@ -30,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] + var apiKey = model.ApiKey{} response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index d4fa1aee..84a8bcc8 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -29,7 +29,7 @@ func (h *ChatHandler) sendOpenAiMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] + var apiKey = model.ApiKey{} response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index d4ccd664..b1b33d16 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -67,29 +67,25 @@ func (h *ChatHandler) sendXunFeiMessage( prompt string, ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 - var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] - if apiKey == "" { - var key model.ApiKey - res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key) - if res.Error != nil { - utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - return nil - } - // 更新 API KEY 的最后使用时间 - h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) - apiKey = key.Value + var apiKey model.ApiKey + res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) + if res.Error != nil { + utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + return nil } + // 更新 API KEY 的最后使用时间 + h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - key := strings.Split(apiKey, "|") + key := strings.Split(apiKey.Value, "|") if len(key) != 3 { utils.ReplyMessage(ws, "非法的 API KEY!") return nil } - apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1) + apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) //握手并建立websocket 连接 conn, resp, err := d.Dial(wsURL, nil) diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index a3bf3fec..ca78589a 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -208,7 +208,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { prompt := utils.InterfaceToString(params["prompt"]) // get image generation API KEY var apiKey model.ApiKey - tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey) + tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) if tx.Error != nil { resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) return @@ -231,7 +231,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { // translate prompt const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" - pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL) + pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL) if err == nil { prompt = pt } diff --git a/api/handler/prompt_handler.go b/api/handler/prompt_handler.go index c4e021d1..c05d28af 100644 --- a/api/handler/prompt_handler.go +++ b/api/handler/prompt_handler.go @@ -66,10 +66,10 @@ func (h *PromptHandler) Translate(c *gin.Context) { func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) { // 获取 OpenAI 的 API KEY var apiKey model.ApiKey - res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey) + res := h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey) if res.Error != nil { return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) } - return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey.Value, h.App.Config.ProxyURL, h.App.ChatConfig.OpenAI.ApiURL) + return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey, h.App.Config.ProxyURL) } diff --git a/api/test/test.go b/api/test/test.go index fe7f767d..2fb3b157 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,5 +1,10 @@ package main +import ( + "chatplus/utils" + "fmt" +) + func main() { - + fmt.Println(utils.RandString(64)) } diff --git a/api/utils/net.go b/api/utils/net.go index 972d7c74..d8cc8614 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -3,6 +3,7 @@ package utils import ( "chatplus/core/types" logger2 "chatplus/logger" + "chatplus/store/model" "encoding/json" "fmt" "github.com/imroc/req/v3" @@ -85,7 +86,7 @@ type apiErrRes struct { } `json:"error"` } -func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (string, error) { +func OpenAIRequest(prompt string, apiKey model.ApiKey, proxy string) (string, error) { messages := make([]interface{}, 1) messages[0] = types.Message{ Role: "user", @@ -94,8 +95,12 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s var response apiRes var errRes apiErrRes - r, err := req.C().SetProxyURL(proxy).R().SetHeader("Content-Type", "application/json"). - SetHeader("Authorization", "Bearer "+apiKey). + client := req.C() + if apiKey.UseProxy && proxy != "" { + client.SetProxyURL(proxy) + } + r, err := client.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). SetBody(types.ApiRequest{ Model: "gpt-3.5-turbo", Temperature: 0.9, @@ -104,7 +109,7 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s Messages: messages, }). SetErrorResult(&errRes). - SetSuccessResult(&response).Post(apiURL) + SetSuccessResult(&response).Post(apiKey.ApiURL) if err != nil || r.IsErrorState() { return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) }