feat: refactor LLM api request code, get API URL from ApiKey object

This commit is contained in:
RockYang 2024-01-04 14:51:33 +08:00
parent 21f2622a4b
commit 4b1c4f7ccc
13 changed files with 59 additions and 53 deletions

View File

@ -14,7 +14,7 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
绘画函数插件。
## 最新版本一键部署脚本
目前仅支持 Ubuntu 和 Centos 系统。
```shell
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.3-8b588904ef.sh)"
```

View File

@ -62,7 +62,6 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message
const MjMsg = "mj"
var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096,
@ -75,4 +74,12 @@ var ModelToTokens = map[string]int{
"ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞
"general2": 8192,
"general3": 8192,
}
func GetModelMaxToken(model string) int {
if token, ok := ModelToTokens[model]; ok {
return token
}
return 4096
}

View File

@ -141,7 +141,6 @@ type InviteReward struct {
}
type ModelAPIConfig struct {
ApiURL string `json:"api_url,omitempty"`
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
ApiKey string `json:"api_key"`

View File

@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {

View File

@ -46,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {

View File

@ -275,7 +275,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.ModelToTokens[req.Model] {
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
tokens += tks
@ -290,7 +290,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break
}
tokens += msg.Tokens
@ -401,39 +401,33 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
var apiURL string
switch platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
default:
apiURL = h.App.ChatConfig.OpenAI.ApiURL
apiURL = apiKey.ApiURL
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 更新 API KEY 的最后使用时间
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(*apiKey)
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
@ -465,13 +459,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model)
logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform {
case types.Azure:
request.Header.Set("api-key", *apiKey)
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM:
token, err := h.getChatGLMToken(*apiKey)
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
@ -480,7 +474,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
}
return client.Do(request)
}

View File

@ -30,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {

View File

@ -29,7 +29,7 @@ func (h *ChatHandler) sendOpenAiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {

View File

@ -67,29 +67,25 @@ func (h *ChatHandler) sendXunFeiMessage(
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
if apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
apiKey = key.Value
var apiKey model.ApiKey
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey, "|")
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1)
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)

View File

@ -208,7 +208,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
return
@ -231,7 +231,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL)
if err == nil {
prompt = pt
}

View File

@ -66,10 +66,10 @@ func (h *PromptHandler) Translate(c *gin.Context) {
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
// 获取 OpenAI 的 API KEY
var apiKey model.ApiKey
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
res := h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
if res.Error != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
}
return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey.Value, h.App.Config.ProxyURL, h.App.ChatConfig.OpenAI.ApiURL)
return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey, h.App.Config.ProxyURL)
}

View File

@ -1,5 +1,10 @@
package main
func main() {
import (
"chatplus/utils"
"fmt"
)
func main() {
fmt.Println(utils.RandString(64))
}

View File

@ -3,6 +3,7 @@ package utils
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store/model"
"encoding/json"
"fmt"
"github.com/imroc/req/v3"
@ -85,7 +86,7 @@ type apiErrRes struct {
} `json:"error"`
}
func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (string, error) {
func OpenAIRequest(prompt string, apiKey model.ApiKey, proxy string) (string, error) {
messages := make([]interface{}, 1)
messages[0] = types.Message{
Role: "user",
@ -94,8 +95,12 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
var response apiRes
var errRes apiErrRes
r, err := req.C().SetProxyURL(proxy).R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey).
client := req.C()
if apiKey.UseProxy && proxy != "" {
client.SetProxyURL(proxy)
}
r, err := client.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(types.ApiRequest{
Model: "gpt-3.5-turbo",
Temperature: 0.9,
@ -104,7 +109,7 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
Messages: messages,
}).
SetErrorResult(&errRes).
SetSuccessResult(&response).Post(apiURL)
SetSuccessResult(&response).Post(apiKey.ApiURL)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
}