feat: refactor LLM api request code, get API URL from ApiKey object

This commit is contained in:
RockYang 2024-01-04 14:51:33 +08:00
parent 21f2622a4b
commit 4b1c4f7ccc
13 changed files with 59 additions and 53 deletions

View File

@ -14,7 +14,7 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
绘画函数插件。 绘画函数插件。
## 最新版本一键部署脚本 ## 最新版本一键部署脚本
目前仅支持 Ubuntu 和 Centos 系统。
```shell ```shell
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.3-8b588904ef.sh)" bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.3-8b588904ef.sh)"
``` ```

View File

@ -62,7 +62,6 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message const ReplyMsg = "reply" // reply message
const MjMsg = "mj"
var ModelToTokens = map[string]int{ var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096, "gpt-3.5-turbo": 4096,
@ -75,4 +74,12 @@ var ModelToTokens = map[string]int{
"ernie_bot_turbo": 8192, // 文心一言 "ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞 "general": 8192, // 科大讯飞
"general2": 8192, "general2": 8192,
"general3": 8192,
}
func GetModelMaxToken(model string) int {
if token, ok := ModelToTokens[model]; ok {
return token
}
return 4096
} }

View File

@ -141,7 +141,6 @@ type InviteReward struct {
} }
type ModelAPIConfig struct { type ModelAPIConfig struct {
ApiURL string `json:"api_url,omitempty"`
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
ApiKey string `json:"api_key"` ApiKey string `json:"api_key"`

View File

@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage(
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
start := time.Now() start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start)) logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil { if err != nil {

View File

@ -46,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage(
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
start := time.Now() start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start)) logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil { if err != nil {

View File

@ -275,7 +275,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
if err == nil { if err == nil {
for _, v := range messages { for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model) tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.ModelToTokens[req.Model] { if tokens+tks >= types.GetModelMaxToken(req.Model) {
break break
} }
tokens += tks tokens += tks
@ -290,7 +290,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
if res.Error == nil { if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- { for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i] msg := historyMessages[i]
if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] { if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break break
} }
tokens += msg.Tokens tokens += msg.Tokens
@ -401,39 +401,33 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器 // 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY // useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) { func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
var apiURL string var apiURL string
switch platform { switch platform {
case types.Azure: case types.Azure:
md := strings.Replace(req.Model, ".", "", 1) md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1) apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break break
case types.ChatGLM: case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1) apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil req.Messages = nil
break break
case types.Baidu: case types.Baidu:
apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1) apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break break
default: default:
apiURL = h.App.ChatConfig.OpenAI.ApiURL apiURL = apiKey.ApiURL
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
} }
// 更新 API KEY 的最后使用时间 // 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 百度文心,需要串接 access_token // 百度文心,需要串接 access_token
if platform == types.Baidu { if platform == types.Baidu {
token, err := h.getBaiduToken(*apiKey) token, err := h.getBaiduToken(apiKey.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -465,13 +459,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else { } else {
client = http.DefaultClient client = http.DefaultClient
} }
logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model) logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform { switch platform {
case types.Azure: case types.Azure:
request.Header.Set("api-key", *apiKey) request.Header.Set("api-key", apiKey.Value)
break break
case types.ChatGLM: case types.ChatGLM:
token, err := h.getChatGLMToken(*apiKey) token, err := h.getChatGLMToken(apiKey.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -480,7 +474,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
case types.Baidu: case types.Baidu:
request.RequestURI = "" request.RequestURI = ""
case types.OpenAI: case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
} }
return client.Do(request) return client.Do(request)
} }

View File

@ -30,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage(
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
start := time.Now() start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start)) logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil { if err != nil {

View File

@ -29,7 +29,7 @@ func (h *ChatHandler) sendOpenAiMessage(
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
start := time.Now() start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start)) logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil { if err != nil {

View File

@ -67,29 +67,25 @@ func (h *ChatHandler) sendXunFeiMessage(
prompt string, prompt string,
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] var apiKey model.ApiKey
if apiKey == "" { res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
var key model.ApiKey
res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key)
if res.Error != nil { if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员") utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil return nil
} }
// 更新 API KEY 的最后使用时间 // 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
apiKey = key.Value
}
d := websocket.Dialer{ d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
} }
key := strings.Split(apiKey, "|") key := strings.Split(apiKey.Value, "|")
if len(key) != 3 { if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY") utils.ReplyMessage(ws, "非法的 API KEY")
return nil return nil
} }
apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1) apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接 //握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil) conn, resp, err := d.Dial(wsURL, nil)

View File

@ -208,7 +208,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY // get image generation API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey) tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil { if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
return return
@ -231,7 +231,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
// translate prompt // translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL) pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL)
if err == nil { if err == nil {
prompt = pt prompt = pt
} }

View File

@ -66,10 +66,10 @@ func (h *PromptHandler) Translate(c *gin.Context) {
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) { func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
// 获取 OpenAI 的 API KEY // 获取 OpenAI 的 API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey) res := h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
if res.Error != nil { if res.Error != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error) return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
} }
return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey.Value, h.App.Config.ProxyURL, h.App.ChatConfig.OpenAI.ApiURL) return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey, h.App.Config.ProxyURL)
} }

View File

@ -1,5 +1,10 @@
package main package main
func main() { import (
"chatplus/utils"
"fmt"
)
func main() {
fmt.Println(utils.RandString(64))
} }

View File

@ -3,6 +3,7 @@ package utils
import ( import (
"chatplus/core/types" "chatplus/core/types"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/store/model"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
@ -85,7 +86,7 @@ type apiErrRes struct {
} `json:"error"` } `json:"error"`
} }
func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (string, error) { func OpenAIRequest(prompt string, apiKey model.ApiKey, proxy string) (string, error) {
messages := make([]interface{}, 1) messages := make([]interface{}, 1)
messages[0] = types.Message{ messages[0] = types.Message{
Role: "user", Role: "user",
@ -94,8 +95,12 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
var response apiRes var response apiRes
var errRes apiErrRes var errRes apiErrRes
r, err := req.C().SetProxyURL(proxy).R().SetHeader("Content-Type", "application/json"). client := req.C()
SetHeader("Authorization", "Bearer "+apiKey). if apiKey.UseProxy && proxy != "" {
client.SetProxyURL(proxy)
}
r, err := client.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(types.ApiRequest{ SetBody(types.ApiRequest{
Model: "gpt-3.5-turbo", Model: "gpt-3.5-turbo",
Temperature: 0.9, Temperature: 0.9,
@ -104,7 +109,7 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
Messages: messages, Messages: messages,
}). }).
SetErrorResult(&errRes). SetErrorResult(&errRes).
SetSuccessResult(&response).Post(apiURL) SetSuccessResult(&response).Post(apiKey.ApiURL)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
} }