feat: allow user to set custom api keys for different platforms

This commit is contained in:
RockYang 2023-09-04 17:34:29 +08:00
parent 7ecd7eeba1
commit 32774d23c7
11 changed files with 76 additions and 68 deletions

View File

@ -113,11 +113,11 @@ type Platform string
const OpenAI = Platform("OpenAI")
const Azure = Platform("Azure")
const ChatGML = Platform("ChatGML")
const ChatGLM = Platform("ChatGLM")
// UserChatConfig 用户的聊天配置
type UserChatConfig struct {
ApiKeys map[Platform]string
ApiKeys map[Platform]string `json:"api_keys"`
}
type ModelAPIConfig struct {

View File

@ -109,7 +109,7 @@ func (h *UserHandler) Save(c *gin.Context) {
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGML: "",
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.UserInitCalls,

View File

@ -28,7 +28,7 @@ func (h *ChatHandler) sendAzureMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey string
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@ -174,7 +174,9 @@ func (h *ChatHandler) sendAzureMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}
if message.Role == "" {
message.Role = "assistant"
@ -183,14 +185,14 @@ func (h *ChatHandler) sendAzureMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false {
if h.App.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if userVo.ChatConfig.EnableHistory {
if h.App.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
@ -254,8 +256,6 @@ func (h *ChatHandler) sendAzureMessage(
} else {
totalTokens = replyToken + getTotalTokens(req)
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}

View File

@ -169,7 +169,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Calls <= 0 {
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽请联系管理员或者点击左下角菜单加入众筹获得100次对话")
utils.ReplyMessage(ws, "![](/images/wx.png)")
return nil
@ -189,7 +189,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.Temperature = h.App.ChatConfig.Azure.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
break
case types.ChatGML:
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
@ -208,7 +208,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
// 加载聊天上下文
var chatCtx []interface{}
if userVo.ChatConfig.EnableContext {
if h.App.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
} else {
@ -269,11 +269,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGML:
case types.ChatGLM:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
return fmt.Errorf("not supported platform: %s", session.Model.Platform)
}
// Tokens 统计 token 数量
@ -336,7 +335,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
break
case types.ChatGML:
case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages
req.Messages = nil
@ -368,6 +367,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
@ -375,14 +375,16 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, key.Value, proxyURL, req.Model)
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
switch platform {
case types.Azure:
request.Header.Set("api-key", key.Value)
request.Header.Set("api-key", *apiKey)
break
case types.ChatGML:
token, err := h.getChatGLMToken(key.Value)
case types.ChatGLM:
token, err := h.getChatGLMToken(*apiKey)
if err != nil {
return nil, err
}
@ -390,8 +392,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
default:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
}
*apiKey = key.Value
return client.Do(request)
}

View File

@ -29,7 +29,7 @@ func (h *ChatHandler) sendChatGLMMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey string
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@ -103,7 +103,9 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}
if message.Role == "" {
message.Role = "assistant"
@ -112,14 +114,14 @@ func (h *ChatHandler) sendChatGLMMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext {
if h.App.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if userVo.ChatConfig.EnableHistory {
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
@ -167,8 +169,6 @@ func (h *ChatHandler) sendChatGLMMessage(
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
totalTokens = replyToken + getTotalTokens(req)
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}
@ -205,7 +205,7 @@ func (h *ChatHandler) sendChatGLMMessage(
return fmt.Errorf("error with decode response: %v", err)
}
if !res.Success {
utils.ReplyMessage(ws, "请求 ChatGML 失败:"+res.Msg)
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
}
}

View File

@ -28,7 +28,7 @@ func (h *ChatHandler) sendOpenAiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey string
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@ -174,7 +174,9 @@ func (h *ChatHandler) sendOpenAiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}
if message.Role == "" {
message.Role = "assistant"
@ -183,14 +185,14 @@ func (h *ChatHandler) sendOpenAiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false {
if h.App.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if userVo.ChatConfig.EnableHistory {
if h.App.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
@ -254,8 +256,6 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
totalTokens = replyToken + getTotalTokens(req)
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}

View File

@ -100,7 +100,7 @@ func (h *UserHandler) Register(c *gin.Context) {
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGML: "",
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.UserInitCalls,

View File

@ -10,7 +10,7 @@ type User struct {
TotalTokens int64 `json:"total_tokens"` // 总消耗tokens
Calls int `json:"calls"` // 剩余对话次数
ImgCalls int `json:"img_calls"`
ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置
ChatConfig types.UserChatConfig `json:"chat_config"` // 聊天配置
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
ExpiredTime int64 `json:"expired_time"` // 账户到期时间
Status bool `json:"status"` // 当前状态

View File

@ -8,7 +8,7 @@
title="用户设置"
>
<div class="user-info" id="user-info">
<el-form v-if="form.id" :model="form" label-width="120px">
<el-form v-if="form.id" :model="form" label-width="150px">
<el-form-item label="账户">
<span>{{ form.mobile }}</span>
</el-form-item>
@ -34,8 +34,14 @@
<el-form-item label="累计消耗 Tokens">
<el-tag type="info">{{ form['total_tokens'] }}</el-tag>
</el-form-item>
<el-form-item label="API KEY">
<el-input v-model="form['chat_config']['api_key']"/>
<el-form-item label="OpenAI API KEY">
<el-input v-model="form.chat_config['api_keys']['OpenAI']"/>
</el-form-item>
<el-form-item label="Azure API KEY">
<el-input v-model="form['chat_config']['api_keys']['Azure']"/>
</el-form-item>
<el-form-item label="ChatGLM API KEY">
<el-input v-model="form['chat_config']['api_keys']['ChatGLM']"/>
</el-form-item>
</el-form>
</div>
@ -77,15 +83,16 @@ const form = ref({
mobile: '',
calls: 0,
tokens: 0,
chat_configs: {}
chat_config: {api_keys: {OpenAI: "", Azure: "", ChatGLM: ""}}
})
onMounted(() => {
//
httpGet('/api/user/profile').then(res => {
form.value = res.data
}).catch(() => {
ElMessage.error("获取用户信息失败")
form.value.chat_config.api_keys = res.data.chat_config.api_keys ?? {OpenAI: "", Azure: "", ChatGLM: ""}
}).catch(e => {
ElMessage.error("获取用户信息失败:" + e.message)
});
})

View File

@ -82,7 +82,7 @@ const rules = reactive({
const loading = ref(true)
const formRef = ref(null)
const title = ref("")
const platforms = ref(["Azure", "OpenAI", "ChatGML"])
const platforms = ref(["Azure", "OpenAI", "ChatGLM"])
//
httpGet('/api/admin/apikey/list').then((res) => {

View File

@ -9,7 +9,7 @@
<el-table :data="items" :row-key="row => row.id" table-layout="auto">
<el-table-column prop="platform" label="所属平台">
<template #default="scope">
<span class="sort" :data-id="scope.row.id">{{scope.row.platform}}</span>
<span class="sort" :data-id="scope.row.id">{{ scope.row.platform }}</span>
</template>
</el-table-column>
<el-table-column prop="name" label="模型名称"/>
@ -47,7 +47,7 @@
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
<el-form-item label="所属平台:" prop="platform">
<el-select v-model="item.platform" placeholder="请选择平台">
<el-option v-for="item in platforms" :value="item" :key="item">{{item}}</el-option>
<el-option v-for="item in platforms" :value="item" :key="item">{{ item }}</el-option>
</el-select>
</el-form-item>
@ -94,7 +94,7 @@ const rules = reactive({
})
const loading = ref(true)
const formRef = ref(null)
const platforms = ref(["Azure","OpenAI","ChatGML"])
const platforms = ref(["Azure", "OpenAI", "ChatGLM"])
//
httpGet('/api/admin/model/list').then((res) => {
@ -127,13 +127,13 @@ onMounted(() => {
const sortedData = Array.from(from.children).map(row => row.querySelector('.sort').getAttribute('data-id'));
const ids = []
const sorts = []
sortedData.forEach((id,index) => {
sortedData.forEach((id, index) => {
ids.push(parseInt(id))
sorts.push(index)
})
httpPost("/api/admin/model/sort", {ids: ids, sorts:sorts}).catch(e => {
ElMessage.error("排序失败:"+e.message)
httpPost("/api/admin/model/sort", {ids: ids, sorts: sorts}).catch(e => {
ElMessage.error("排序失败:" + e.message)
})
}
})
@ -174,7 +174,7 @@ const enable = (row) => {
httpPost('/api/admin/model/enable', {id: row.id, enabled: row.enabled}).then(() => {
ElMessage.success("操作成功!")
}).catch(e => {
ElMessage.error("操作失败:"+e.message)
ElMessage.error("操作失败:" + e.message)
})
}