mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-11 19:53:50 +08:00
feat: refactor LLM api request code, get API URL from ApiKey object
This commit is contained in:
@@ -275,7 +275,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
if err == nil {
|
||||
for _, v := range messages {
|
||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
||||
if tokens+tks >= types.ModelToTokens[req.Model] {
|
||||
if tokens+tks >= types.GetModelMaxToken(req.Model) {
|
||||
break
|
||||
}
|
||||
tokens += tks
|
||||
@@ -290,7 +290,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
if res.Error == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
|
||||
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
|
||||
break
|
||||
}
|
||||
tokens += msg.Tokens
|
||||
@@ -401,39 +401,33 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
|
||||
// 发送请求到 OpenAI 服务器
|
||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
|
||||
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
var apiURL string
|
||||
switch platform {
|
||||
case types.Azure:
|
||||
md := strings.Replace(req.Model, ".", "", 1)
|
||||
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
|
||||
break
|
||||
case types.ChatGLM:
|
||||
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
|
||||
req.Messages = nil
|
||||
break
|
||||
case types.Baidu:
|
||||
apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||
break
|
||||
default:
|
||||
apiURL = h.App.ChatConfig.OpenAI.ApiURL
|
||||
apiURL = apiKey.ApiURL
|
||||
}
|
||||
if *apiKey == "" {
|
||||
var key model.ApiKey
|
||||
res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key)
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
*apiKey = key.Value
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// 百度文心,需要串接 access_token
|
||||
if platform == types.Baidu {
|
||||
token, err := h.getBaiduToken(*apiKey)
|
||||
token, err := h.getBaiduToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -465,13 +459,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model)
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
|
||||
switch platform {
|
||||
case types.Azure:
|
||||
request.Header.Set("api-key", *apiKey)
|
||||
request.Header.Set("api-key", apiKey.Value)
|
||||
break
|
||||
case types.ChatGLM:
|
||||
token, err := h.getChatGLMToken(*apiKey)
|
||||
token, err := h.getChatGLMToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -480,7 +474,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
case types.Baidu:
|
||||
request.RequestURI = ""
|
||||
case types.OpenAI:
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
}
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user