feat: allow user to set custom api keys for different platforms

This commit is contained in:
RockYang
2023-09-04 17:34:29 +08:00
parent f7a427d2c0
commit 2820adad53
11 changed files with 76 additions and 68 deletions

View File

@@ -169,7 +169,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Calls <= 0 {
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽请联系管理员或者点击左下角菜单加入众筹获得100次对话")
utils.ReplyMessage(ws, "![](/images/wx.png)")
return nil
@@ -189,7 +189,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.Temperature = h.App.ChatConfig.Azure.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
break
case types.ChatGML:
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
@@ -208,7 +208,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
// 加载聊天上下文
var chatCtx []interface{}
if userVo.ChatConfig.EnableContext {
if h.App.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
} else {
@@ -269,11 +269,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGML:
case types.ChatGLM:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
return fmt.Errorf("not supported platform: %s", session.Model.Platform)
}
// Tokens 统计 token 数量
@@ -336,7 +335,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
break
case types.ChatGML:
case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages
req.Messages = nil
@@ -368,21 +367,24 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, key.Value, proxyURL, req.Model)
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
switch platform {
case types.Azure:
request.Header.Set("api-key", key.Value)
request.Header.Set("api-key", *apiKey)
break
case types.ChatGML:
token, err := h.getChatGLMToken(key.Value)
case types.ChatGLM:
token, err := h.getChatGLMToken(*apiKey)
if err != nil {
return nil, err
}
@@ -390,8 +392,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
default:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
}
*apiKey = key.Value
return client.Do(request)
}