feat: allow user to set custom api keys for different platforms

Repository: https://github.com/yangjian102621/geekai.git
Commit: 32774d23c7
Parent: 7ecd7eeba1
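
At a glance, the change gives each user a per-platform API key map (UserChatConfig.ApiKeys): a request first tries the user's own key, falls back to the shared key pool only when that key is empty, and only pool-backed requests consume the user's call quota. Below is a minimal, self-contained sketch of that flow; the struct mirrors the diff that follows, while doRequest and the billing line are illustrative stand-ins, not the project's real handlers.

// Sketch only: per-platform user keys with fallback to a shared pool.
package main

import "fmt"

type Platform string

// UserChatConfig mirrors the struct introduced in this commit.
type UserChatConfig struct {
    ApiKeys map[Platform]string `json:"api_keys"`
}

// doRequest stands in for ChatHandler.doRequest: an empty *apiKey is replaced
// with a pooled key and written back through the pointer.
func doRequest(platform Platform, apiKey *string) {
    if *apiKey == "" {
        *apiKey = "pooled-key-for-" + string(platform) // placeholder for the DB lookup
    }
    fmt.Printf("sending %s request with key %q\n", platform, *apiKey)
}

func main() {
    cfg := UserChatConfig{ApiKeys: map[Platform]string{"OpenAI": "sk-user-own-key"}}

    for _, p := range []Platform{"OpenAI", "Azure"} {
        apiKey := cfg.ApiKeys[p] // "" when the user has no key for this platform
        doRequest(p, &apiKey)
        if cfg.ApiKeys[p] == "" {
            fmt.Println("billing: calls - 1") // only pool-backed requests consume quota
        }
    }
}
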
@@ -113,11 +113,11 @@ type Platform string

 const OpenAI = Platform("OpenAI")
 const Azure = Platform("Azure")
-const ChatGML = Platform("ChatGML")
+const ChatGLM = Platform("ChatGLM")

 // UserChatConfig 用户的聊天配置
 type UserChatConfig struct {
-    ApiKeys map[Platform]string
+    ApiKeys map[Platform]string `json:"api_keys"`
 }

 type ModelAPIConfig struct {
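
For reference, a small sketch of what the new `json:"api_keys"` tag yields when a UserChatConfig is serialized (key values are placeholders); without the tag the field would be emitted as "ApiKeys", which is why the frontend reads form.chat_config['api_keys'].

package main

import (
    "encoding/json"
    "fmt"
)

type Platform string

type UserChatConfig struct {
    ApiKeys map[Platform]string `json:"api_keys"`
}

func main() {
    cfg := UserChatConfig{ApiKeys: map[Platform]string{
        "OpenAI":  "sk-xxx", // placeholder
        "Azure":   "",
        "ChatGLM": "",
    }}
    data, _ := json.Marshal(cfg)
    fmt.Println(string(data))
    // Output: {"api_keys":{"Azure":"","ChatGLM":"","OpenAI":"sk-xxx"}}
}
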
@@ -109,7 +109,7 @@ func (h *UserHandler) Save(c *gin.Context) {
         ApiKeys: map[types.Platform]string{
             types.OpenAI: "",
             types.Azure: "",
-            types.ChatGML: "",
+            types.ChatGLM: "",
         },
     }),
     Calls: h.App.SysConfig.UserInitCalls,
@@ -28,7 +28,7 @@ func (h *ChatHandler) sendAzureMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey string
+    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -174,7 +174,9 @@ func (h *ChatHandler) sendAzureMessage(
     // 消息发送成功
     if len(contents) > 0 {
         // 更新用户的对话次数
+        if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
         h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+        }

         if message.Role == "" {
             message.Role = "assistant"
@@ -183,14 +185,14 @@ func (h *ChatHandler) sendAzureMessage(
         useMsg := types.Message{Role: "user", Content: prompt}

         // 更新上下文消息,如果是调用函数则不需要更新上下文
-        if userVo.ChatConfig.EnableContext && functionCall == false {
+        if h.App.ChatConfig.EnableContext && functionCall == false {
             chatCtx = append(chatCtx, useMsg)  // 提问消息
             chatCtx = append(chatCtx, message) // 回复消息
             h.App.ChatContexts.Put(session.ChatId, chatCtx)
         }

         // 追加聊天记录
-        if userVo.ChatConfig.EnableHistory {
+        if h.App.ChatConfig.EnableHistory {
             useContext := true
             if functionCall {
                 useContext = false
@@ -254,8 +256,6 @@ func (h *ChatHandler) sendAzureMessage(
         } else {
             totalTokens = replyToken + getTotalTokens(req)
         }
-        h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
-            UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
         h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
             UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
     }
@@ -169,7 +169,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
         return nil
     }

-    if userVo.Calls <= 0 {
+    if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
         utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
         utils.ReplyMessage(ws, "")
         return nil
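
The new condition above only blocks a chat when the user is both out of calls and has no key of their own for the target platform. A hedged sketch of that rule as a standalone predicate (canChat is a hypothetical helper, not part of the codebase):

package main

import "fmt"

// canChat expresses the check above: users with their own API key may keep
// chatting even when their call quota is exhausted.
func canChat(calls int, ownKey string) bool {
    return calls > 0 || ownKey != ""
}

func main() {
    fmt.Println(canChat(0, ""))       // false: no calls left, no own key
    fmt.Println(canChat(0, "sk-own")) // true: own key bypasses the quota
    fmt.Println(canChat(3, ""))       // true: quota still available
}
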
@@ -189,7 +189,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
         req.Temperature = h.App.ChatConfig.Azure.Temperature
         req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
         break
-    case types.ChatGML:
+    case types.ChatGLM:
         req.Temperature = h.App.ChatConfig.ChatGML.Temperature
         req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
         break
@@ -208,7 +208,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio

     // 加载聊天上下文
     var chatCtx []interface{}
-    if userVo.ChatConfig.EnableContext {
+    if h.App.ChatConfig.EnableContext {
         if h.App.ChatContexts.Has(session.ChatId) {
             chatCtx = h.App.ChatContexts.Get(session.ChatId)
         } else {
@@ -269,11 +269,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
         return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
     case types.OpenAI:
         return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
-    case types.ChatGML:
+    case types.ChatGLM:
         return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
     }
-
-    return nil
+    return fmt.Errorf("not supported platform: %s", session.Model.Platform)
 }

 // Tokens 统计 token 数量
@@ -336,7 +335,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
         md := strings.Replace(req.Model, ".", "", 1)
         apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
         break
-    case types.ChatGML:
+    case types.ChatGLM:
         apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
         req.Prompt = req.Messages
         req.Messages = nil
@@ -368,6 +367,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
     } else {
         client = http.DefaultClient
     }
+    if *apiKey == "" {
     var key model.ApiKey
     res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
     if res.Error != nil {
@@ -375,14 +375,16 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
     }
     // 更新 API KEY 的最后使用时间
     h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
+        *apiKey = key.Value
+    }

-    logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, key.Value, proxyURL, req.Model)
+    logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
     switch platform {
     case types.Azure:
-        request.Header.Set("api-key", key.Value)
+        request.Header.Set("api-key", *apiKey)
         break
-    case types.ChatGML:
-        token, err := h.getChatGLMToken(key.Value)
+    case types.ChatGLM:
+        token, err := h.getChatGLMToken(*apiKey)
         if err != nil {
             return nil, err
         }
@@ -390,8 +392,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
         request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
         break
     default:
-        request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
+        request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
     }
-    *apiKey = key.Value
     return client.Do(request)
 }
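
The doRequest changes above are the heart of the fallback: the caller passes the user's key by pointer, and only when it is empty does the handler borrow the least-recently-used pooled key (ordered by last_used_at) and stamp its usage time. A minimal sketch of that rotation, with an in-memory slice standing in for the GORM query:

package main

import (
    "fmt"
    "sort"
    "time"
)

// ApiKey loosely mirrors model.ApiKey: a pooled key with its last-used timestamp.
type ApiKey struct {
    Value      string
    LastUsedAt int64
}

// pickKey fills *apiKey only if it is empty, choosing the pooled key with the
// oldest last_used_at and updating its timestamp -- the same rotation the
// handler performs via Order("last_used_at ASC").First(&key).
func pickKey(pool []ApiKey, apiKey *string) {
    if *apiKey != "" {
        return // the user supplied their own key; leave it untouched
    }
    if len(pool) == 0 {
        return // the real handler returns an error here
    }
    sort.Slice(pool, func(i, j int) bool { return pool[i].LastUsedAt < pool[j].LastUsedAt })
    pool[0].LastUsedAt = time.Now().Unix()
    *apiKey = pool[0].Value
}

func main() {
    pool := []ApiKey{{"sk-pool-a", 200}, {"sk-pool-b", 100}}

    userKey := "" // no personal key: the pool is rotated
    pickKey(pool, &userKey)
    fmt.Println(userKey) // sk-pool-b (oldest last_used_at)

    ownKey := "sk-own"
    pickKey(pool, &ownKey)
    fmt.Println(ownKey) // sk-own (unchanged)
}
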
@@ -29,7 +29,7 @@ func (h *ChatHandler) sendChatGLMMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey string
+    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -103,7 +103,9 @@ func (h *ChatHandler) sendChatGLMMessage(
     // 消息发送成功
     if len(contents) > 0 {
         // 更新用户的对话次数
+        if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
         h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+        }

         if message.Role == "" {
             message.Role = "assistant"
@@ -112,14 +114,14 @@ func (h *ChatHandler) sendChatGLMMessage(
         useMsg := types.Message{Role: "user", Content: prompt}

         // 更新上下文消息,如果是调用函数则不需要更新上下文
-        if userVo.ChatConfig.EnableContext {
+        if h.App.ChatConfig.EnableContext {
             chatCtx = append(chatCtx, useMsg)  // 提问消息
             chatCtx = append(chatCtx, message) // 回复消息
             h.App.ChatContexts.Put(session.ChatId, chatCtx)
         }

         // 追加聊天记录
-        if userVo.ChatConfig.EnableHistory {
+        if h.App.ChatConfig.EnableHistory {
             // for prompt
             promptToken, err := utils.CalcTokens(prompt, req.Model)
             if err != nil {
@@ -167,8 +169,6 @@ func (h *ChatHandler) sendChatGLMMessage(
         // 计算本次对话消耗的总 token 数量
         var totalTokens = 0
         totalTokens = replyToken + getTotalTokens(req)
-        h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
-            UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
         h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
             UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
     }
@@ -205,7 +205,7 @@ func (h *ChatHandler) sendChatGLMMessage(
             return fmt.Errorf("error with decode response: %v", err)
         }
         if !res.Success {
-            utils.ReplyMessage(ws, "请求 ChatGML 失败:"+res.Msg)
+            utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
         }
     }

@@ -28,7 +28,7 @@ func (h *ChatHandler) sendOpenAiMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey string
+    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -174,7 +174,9 @@ func (h *ChatHandler) sendOpenAiMessage(
     // 消息发送成功
     if len(contents) > 0 {
         // 更新用户的对话次数
+        if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
         h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+        }

         if message.Role == "" {
             message.Role = "assistant"
@@ -183,14 +185,14 @@ func (h *ChatHandler) sendOpenAiMessage(
         useMsg := types.Message{Role: "user", Content: prompt}

         // 更新上下文消息,如果是调用函数则不需要更新上下文
-        if userVo.ChatConfig.EnableContext && functionCall == false {
+        if h.App.ChatConfig.EnableContext && functionCall == false {
             chatCtx = append(chatCtx, useMsg)  // 提问消息
             chatCtx = append(chatCtx, message) // 回复消息
             h.App.ChatContexts.Put(session.ChatId, chatCtx)
         }

         // 追加聊天记录
-        if userVo.ChatConfig.EnableHistory {
+        if h.App.ChatConfig.EnableHistory {
             useContext := true
             if functionCall {
                 useContext = false
@@ -254,8 +256,6 @@ func (h *ChatHandler) sendOpenAiMessage(
         } else {
             totalTokens = replyToken + getTotalTokens(req)
         }
-        h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
-            UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
         h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
             UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
     }
@@ -100,7 +100,7 @@ func (h *UserHandler) Register(c *gin.Context) {
         ApiKeys: map[types.Platform]string{
             types.OpenAI: "",
             types.Azure: "",
-            types.ChatGML: "",
+            types.ChatGLM: "",
         },
     }),
     Calls: h.App.SysConfig.UserInitCalls,
@@ -10,7 +10,7 @@ type User struct {
     TotalTokens int64 `json:"total_tokens"` // 总消耗tokens
     Calls int `json:"calls"` // 剩余对话次数
     ImgCalls int `json:"img_calls"`
-    ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置
+    ChatConfig types.UserChatConfig `json:"chat_config"` // 聊天配置
     ChatRoles []string `json:"chat_roles"` // 聊天角色集合
     ExpiredTime int64 `json:"expired_time"` // 账户到期时间
     Status bool `json:"status"` // 当前状态
@@ -8,7 +8,7 @@
     title="用户设置"
   >
     <div class="user-info" id="user-info">
-      <el-form v-if="form.id" :model="form" label-width="120px">
+      <el-form v-if="form.id" :model="form" label-width="150px">
         <el-form-item label="账户">
           <span>{{ form.mobile }}</span>
         </el-form-item>
@@ -34,8 +34,14 @@
         <el-form-item label="累计消耗 Tokens">
           <el-tag type="info">{{ form['total_tokens'] }}</el-tag>
         </el-form-item>
-        <el-form-item label="API KEY">
-          <el-input v-model="form['chat_config']['api_key']"/>
+        <el-form-item label="OpenAI API KEY">
+          <el-input v-model="form.chat_config['api_keys']['OpenAI']"/>
         </el-form-item>
+        <el-form-item label="Azure API KEY">
+          <el-input v-model="form['chat_config']['api_keys']['Azure']"/>
+        </el-form-item>
+        <el-form-item label="ChatGLM API KEY">
+          <el-input v-model="form['chat_config']['api_keys']['ChatGLM']"/>
+        </el-form-item>
       </el-form>
     </div>
@@ -77,15 +83,16 @@ const form = ref({
   mobile: '',
   calls: 0,
   tokens: 0,
-  chat_configs: {}
+  chat_config: {api_keys: {OpenAI: "", Azure: "", ChatGLM: ""}}
 })

 onMounted(() => {
   // 获取最新用户信息
   httpGet('/api/user/profile').then(res => {
     form.value = res.data
-  }).catch(() => {
-    ElMessage.error("获取用户信息失败")
+    form.value.chat_config.api_keys = res.data.chat_config.api_keys ?? {OpenAI: "", Azure: "", ChatGLM: ""}
+  }).catch(e => {
+    ElMessage.error("获取用户信息失败:" + e.message)
   });
 })

@@ -82,7 +82,7 @@ const rules = reactive({
 const loading = ref(true)
 const formRef = ref(null)
 const title = ref("")
-const platforms = ref(["Azure", "OpenAI", "ChatGML"])
+const platforms = ref(["Azure", "OpenAI", "ChatGLM"])

 // 获取数据
 httpGet('/api/admin/apikey/list').then((res) => {

@@ -9,7 +9,7 @@
     <el-table :data="items" :row-key="row => row.id" table-layout="auto">
       <el-table-column prop="platform" label="所属平台">
         <template #default="scope">
-          <span class="sort" :data-id="scope.row.id">{{scope.row.platform}}</span>
+          <span class="sort" :data-id="scope.row.id">{{ scope.row.platform }}</span>
         </template>
       </el-table-column>
       <el-table-column prop="name" label="模型名称"/>
@@ -47,7 +47,7 @@
     <el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
       <el-form-item label="所属平台:" prop="platform">
         <el-select v-model="item.platform" placeholder="请选择平台">
-          <el-option v-for="item in platforms" :value="item" :key="item">{{item}}</el-option>
+          <el-option v-for="item in platforms" :value="item" :key="item">{{ item }}</el-option>
         </el-select>
       </el-form-item>

@@ -94,7 +94,7 @@ const rules = reactive({
 })
 const loading = ref(true)
 const formRef = ref(null)
-const platforms = ref(["Azure","OpenAI","ChatGML"])
+const platforms = ref(["Azure", "OpenAI", "ChatGLM"])

 // 获取数据
 httpGet('/api/admin/model/list').then((res) => {
@@ -127,13 +127,13 @@ onMounted(() => {
       const sortedData = Array.from(from.children).map(row => row.querySelector('.sort').getAttribute('data-id'));
       const ids = []
       const sorts = []
-      sortedData.forEach((id,index) => {
+      sortedData.forEach((id, index) => {
         ids.push(parseInt(id))
         sorts.push(index)
       })

-      httpPost("/api/admin/model/sort", {ids: ids, sorts:sorts}).catch(e => {
-        ElMessage.error("排序失败:"+e.message)
+      httpPost("/api/admin/model/sort", {ids: ids, sorts: sorts}).catch(e => {
+        ElMessage.error("排序失败:" + e.message)
       })
     }
   })
@@ -174,7 +174,7 @@ const enable = (row) => {
   httpPost('/api/admin/model/enable', {id: row.id, enabled: row.enabled}).then(() => {
     ElMessage.success("操作成功!")
   }).catch(e => {
-    ElMessage.error("操作失败:"+e.message)
+    ElMessage.error("操作失败:" + e.message)
   })
 }
