feat: allow chat model bind a fixed api key

This commit is contained in:
RockYang 2024-04-12 17:09:22 +08:00
parent 3b292c2a12
commit 0a01b55713
22 changed files with 183 additions and 84 deletions

View File

@ -62,6 +62,7 @@ type ChatModel struct {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id"` // 绑定 API KEY
}
type ApiError struct {

View File

@ -66,9 +66,20 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
status := h.GetBool(c, "status")
t := h.GetTrim(c, "type")
session := h.DB.Session(&gorm.Session{})
if status {
session = session.Where("enabled", true)
}
if t != "" {
session = session.Where("type", t)
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := h.DB.Find(&items)
res := session.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey

View File

@ -8,8 +8,6 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@ -35,6 +33,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@ -52,12 +51,15 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
KeyId: data.KeyId,
Power: data.Power}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
var res *gorm.DB
if data.Id > 0 {
item.Id = data.Id
res = h.DB.Select("*").Omit("created_at").Updates(&item)
} else {
res = h.DB.Create(&item)
}
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@ -84,18 +86,33 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var cms = make([]vo.ChatModel, 0)
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel
err := utils.CopyObject(item, &cm)
if err == nil {
cm.Id = item.Id
cm.CreatedAt = item.CreatedAt.Unix()
cm.UpdatedAt = item.UpdatedAt.Unix()
cms = append(cms, cm)
} else {
logger.Error(err)
}
if res.Error != nil {
resp.SUCCESS(c, cms)
return
}
// initialize key name
keyIds := make([]int, 0)
for _, v := range items {
keyIds = append(keyIds, v.KeyId)
}
var keys []model.ApiKey
keyMap := make(map[uint]string)
h.DB.Where("id IN ?", keyIds).Find(&keys)
for _, v := range keys {
keyMap[v.Id] = v.Name
}
for _, item := range items {
var cm vo.ChatModel
err := utils.CopyObject(item, &cm)
if err == nil {
cm.Id = item.Id
cm.CreatedAt = item.CreatedAt.Unix()
cm.UpdatedAt = item.UpdatedAt.Unix()
cm.KeyName = keyMap[uint(item.KeyId)]
cms = append(cms, cm)
} else {
logger.Error(err)
}
}
resp.SUCCESS(c, cms)

View File

@ -30,7 +30,7 @@ func (h *ChatHandler) sendAzureMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@ -47,7 +47,7 @@ func (h *ChatHandler) sendBaiduMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@ -122,6 +122,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
@ -463,13 +464,21 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Find(apiKey)
}
// use the last unused key
if res.Error != nil {
res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
}
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
var apiURL string
switch platform {
switch session.Model.Platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
@ -492,7 +501,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
if session.Model.Platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
@ -527,8 +536,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform {
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch session.Model.Platform {
case types.Azure:
request.Header.Set("api-key", apiKey.Value)
break

View File

@ -31,7 +31,7 @@ func (h *ChatHandler) sendChatGLMMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@ -31,7 +31,7 @@ func (h *ChatHandler) sendOpenAiMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@ -45,7 +45,7 @@ func (h *ChatHandler) sendQWenMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"html/template"
"io"
"net/http"
@ -69,7 +70,15 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
}
// use the last unused key
if res.Error != nil {
res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil

View File

@ -125,7 +125,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
params += fmt.Sprintf(" --c %d", data.Chaos)
}
if len(data.ImgArr) > 0 && data.Iw > 0 {
params += fmt.Sprintf(" --iw %f", data.Iw)
params += fmt.Sprintf(" --iw %.2f", data.Iw)
}
if data.Raw {
params += " --style raw"

View File

@ -12,4 +12,5 @@ type ChatModel struct {
MaxTokens int // 最大响应长度
MaxContext int // 最大上下文长度
Temperature float32 // 模型温度
KeyId int // 绑定 API KEY ID
}

View File

@ -12,4 +12,6 @@ type ChatModel struct {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id"`
KeyName string `json:"key_name"`
}

View File

@ -1 +1,2 @@
ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`;
ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`;
ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`;

View File

@ -6,4 +6,4 @@ VUE_APP_ADMIN_USER=admin
VUE_APP_ADMIN_PASS=admin123
VUE_APP_KEY_PREFIX=ChatPLUS_DEV_
VUE_APP_TITLE="Geek-AI 创作系统"
VUE_APP_VERSION=v4.0.2
VUE_APP_VERSION=v4.0.3

View File

@ -2,4 +2,4 @@ VUE_APP_API_HOST=
VUE_APP_WS_HOST=
VUE_APP_KEY_PREFIX=ChatPLUS_
VUE_APP_TITLE="Geek-AI 创作系统"
VUE_APP_VERSION=v4.0.2
VUE_APP_VERSION=v4.0.3

View File

@ -63,7 +63,8 @@ const logo = ref('/images/logo.png')
//
httpGet('/api/admin/config/get?key=system').then(res => {
title.value = res.data['admin_title'];
title.value = res.data['admin_title']
logo.value = res.data['logo']
}).catch(e => {
ElMessage.error("加载系统配置失败: " + e.message)
})
@ -191,9 +192,9 @@ setMenuItems(items)
padding 6px 15px;
.el-image {
width 30px;
height 30px;
padding-top 8px;
width 36px;
height 36px;
padding-top 5px;
border-radius 100%
.el-image__inner {

View File

@ -377,16 +377,7 @@ const initData = () => {
httpGet(`/api/role/list`).then((res) => {
roles.value = res.data;
roleId.value = roles.value[0]['id'];
const chatId = localStorage.getItem("chat_id")
const chat = getChatById(chatId)
if (chat === null) {
//
newChat();
} else {
//
loadChat(chat)
}
newChat();
}).catch((e) => {
ElMessage.error('获取聊天角色失败: ' + e.messages)
})

View File

@ -2,7 +2,7 @@
<div class="home">
<div class="navigator">
<div class="logo">
<el-image :src="logo"/>
<el-image :src="logo" @click="router.push('/')"/>
<div class="divider"></div>
</div>
<ul class="nav-items">
@ -75,6 +75,7 @@ onMounted(() => {
display flex
flex-flow column
align-items center
cursor pointer
.el-image {
width 50px

View File

@ -1,7 +1,7 @@
<template>
<div class="index-page" :style="{height: winHeight+'px'}">
<div class="content">
<h1>{{title}}</h1>
<h1>欢迎使用 {{ title }}</h1>
<p>{{slogan}}</p>
<button class="btn" @click="router.push('/chat')">立即使用</button>
@ -20,15 +20,22 @@ import * as THREE from 'three';
import {onMounted, ref} from "vue";
import {useRouter} from "vue-router";
import FooterBar from "@/components/FooterBar.vue";
import {httpGet} from "@/utils/http";
import {ElMessage} from "element-plus";
const router = useRouter()
const title = ref("欢迎使用 Geek-AI 创作系统")
const title = ref("Geek-AI 创作系统")
const slogan = ref("我辈之人,先干为敬,陪您先把 AI 用起来")
const size = window.innerHeight * 0.8
const winHeight = window.innerHeight - 150
onMounted(() => {
httpGet("/api/config/get?key=system").then(res => {
title.value = res.data['title']
}).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message)
})
init()
})
@ -77,7 +84,7 @@ const init = () => {
requestAnimationFrame(animate);
// 使
earth.rotation.y += 0.002;
earth.rotation.y += 0.001;
renderer.render(scene, camera);
};

View File

@ -1,5 +1,5 @@
<template>
<div class="container list" v-loading="loading">
<div class="container model-list" v-loading="loading">
<div class="handle-box">
<el-button type="primary" :icon="Plus" @click="add">新增</el-button>
@ -13,7 +13,14 @@
</template>
</el-table-column>
<el-table-column prop="name" label="模型名称"/>
<el-table-column prop="value" label="模型值"/>
<el-table-column prop="value" label="模型值">
<template #default="scope">
<span>{{ scope.row.value }}</span>
<el-icon class="copy-model" :data-clipboard-text="scope.row.value">
<DocumentCopy/>
</el-icon>
</template>
</el-table-column>
<el-table-column prop="power" label="费率"/>
<el-table-column prop="max_tokens" label="最大响应长度"/>
<el-table-column prop="max_context" label="最大上下文长度"/>
@ -29,12 +36,12 @@
</template>
</el-table-column>
<el-table-column label="创建时间">
<template #default="scope">
<span>{{ dateFormat(scope.row['created_at']) }}</span>
</template>
</el-table-column>
<!-- <el-table-column label="创建时间">-->
<!-- <template #default="scope">-->
<!-- <span>{{ dateFormat(scope.row['created_at']) }}</span>-->
<!-- </template>-->
<!-- </el-table-column>-->
<el-table-column prop="key_name" label="绑定API-KEY"/>
<el-table-column label="操作" width="180">
<template #default="scope">
<el-button size="small" type="primary" @click="edit(scope.row)">编辑</el-button>
@ -75,7 +82,7 @@
<el-form-item label="费率:" prop="weight">
<template #default>
<div class="tip-input">
<el-input-number :min="1" v-model="item.power" autocomplete="off"/>
<el-input-number :min="0" v-model="item.power" autocomplete="off"/>
<div class="info">
<el-tooltip
class="box-item"
@ -144,6 +151,15 @@
</div>
</el-form-item>
<el-form-item label="绑定API-KEY" prop="apikey">
<el-select v-model="item.key_id" placeholder="请选择 API KEY">
<el-option v-for="v in apiKeys" :value="v.id" :label="v.name" :key="v.id">
{{ v.name }}
<el-text type="info" size="small">{{ substr(v.api_url, 50) }}</el-text>
</el-option>
</el-select>
</el-form-item>
<el-form-item label="启用状态:" prop="enable">
<el-switch v-model="item.enabled"/>
</el-form-item>
@ -178,12 +194,13 @@
</template>
<script setup>
import {onMounted, reactive, ref} from "vue";
import {onMounted, onUnmounted, reactive, ref} from "vue";
import {httpGet, httpPost} from "@/utils/http";
import {ElMessage} from "element-plus";
import {dateFormat, removeArrayItem} from "@/utils/libs";
import {InfoFilled, Plus} from "@element-plus/icons-vue";
import {dateFormat, removeArrayItem, substr} from "@/utils/libs";
import {DocumentCopy, InfoFilled, Plus} from "@element-plus/icons-vue";
import {Sortable} from "sortablejs";
import ClipboardJS from "clipboard";
//
const items = ref([])
@ -207,23 +224,34 @@ const platforms = ref([
])
//
httpGet('/api/admin/model/list').then((res) => {
if (res.data) {
//
const arr = res.data;
for (let i = 0; i < arr.length; i++) {
arr[i].last_used_at = dateFormat(arr[i].last_used_at)
}
items.value = arr
}
loading.value = false
}).catch(() => {
ElMessage.error("获取数据失败");
// API KEY
const apiKeys = ref([])
httpGet('/api/admin/apikey/list?status=true&type=chat').then(res => {
apiKeys.value = res.data
}).catch(e => {
ElMessage.error("获取 API KEY 失败:" + e.message)
})
//
const fetchData = () => {
httpGet('/api/admin/model/list').then((res) => {
if (res.data) {
//
const arr = res.data;
for (let i = 0; i < arr.length; i++) {
arr[i].last_used_at = dateFormat(arr[i].last_used_at)
}
items.value = arr
}
loading.value = false
}).catch(() => {
ElMessage.error("获取数据失败");
})
}
const clipboard = ref(null)
onMounted(() => {
fetchData()
const drawBodyWrapper = document.querySelector('.el-table__body tbody')
//
@ -250,6 +278,19 @@ onMounted(() => {
})
}
})
clipboard.value = new ClipboardJS('.copy-model');
clipboard.value.on('success', () => {
ElMessage.success('复制成功!');
})
clipboard.value.on('error', () => {
ElMessage.error('复制失败!');
})
})
onUnmounted(() => {
clipboard.value.destroy()
})
const add = function () {
@ -267,14 +308,14 @@ const edit = function (row) {
const save = function () {
formRef.value.validate((valid) => {
item.value.temperature = parseFloat(item.value.temperature)
if (!item.value.sort_num) {
item.value.sort_num = items.value.length
}
if (valid) {
showDialog.value = false
httpPost('/api/admin/model/save', item.value).then((res) => {
ElMessage.success('操作成功!')
if (!item.value['id']) {
const newItem = res.data
items.value.push(newItem)
}
fetchData()
}).catch((e) => {
ElMessage.error('操作失败,' + e.message)
})
@ -306,7 +347,7 @@ const remove = function (row) {
<style lang="stylus" scoped>
@import "@/assets/css/admin/form.styl";
.list {
.model-list {
.opt-box {
padding-bottom: 10px;
@ -318,6 +359,13 @@ const remove = function (row) {
}
}
.cell {
.copy-model {
margin-left 6px
cursor pointer
}
}
.el-select {
width: 100%
}

View File

@ -19,7 +19,7 @@ import {ref} from "vue";
import ImageMj from "@/views/mobile/ImageMj.vue";
import ImageSd from "@/views/mobile/ImageSd.vue";
const activeName = ref("sd")
const activeName = ref("mj")
</script>
<style lang="stylus">