diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index 54917f24..b6b63aa2 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -62,6 +62,7 @@ type ChatModel struct {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
+ KeyId int `json:"key_id"` // 绑定 API KEY
}
type ApiError struct {
diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go
index 5566b0c0..7935d0ba 100644
--- a/api/handler/admin/api_key_handler.go
+++ b/api/handler/admin/api_key_handler.go
@@ -66,9 +66,20 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
+ status := h.GetBool(c, "status")
+ t := h.GetTrim(c, "type")
+
+ session := h.DB.Session(&gorm.Session{})
+ if status {
+ session = session.Where("enabled", true)
+ }
+ if t != "" {
+ session = session.Where("type", t)
+ }
+
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
- res := h.DB.Find(&items)
+ res := session.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go
index 4f6ee23e..9e546ac5 100644
--- a/api/handler/admin/chat_model_handler.go
+++ b/api/handler/admin/chat_model_handler.go
@@ -8,8 +8,6 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
- "time"
-
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -35,6 +33,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
+ KeyId int `json:"key_id"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -52,12 +51,15 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
+ KeyId: data.KeyId,
Power: data.Power}
- item.Id = data.Id
- if item.Id > 0 {
- item.CreatedAt = time.Unix(data.CreatedAt, 0)
+ var res *gorm.DB
+ if data.Id > 0 {
+ item.Id = data.Id
+ res = h.DB.Select("*").Omit("created_at").Updates(&item)
+ } else {
+ res = h.DB.Create(&item)
}
- res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -84,18 +86,33 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var cms = make([]vo.ChatModel, 0)
res := session.Order("sort_num ASC").Find(&items)
- if res.Error == nil {
- for _, item := range items {
- var cm vo.ChatModel
- err := utils.CopyObject(item, &cm)
- if err == nil {
- cm.Id = item.Id
- cm.CreatedAt = item.CreatedAt.Unix()
- cm.UpdatedAt = item.UpdatedAt.Unix()
- cms = append(cms, cm)
- } else {
- logger.Error(err)
- }
+ if res.Error != nil {
+ resp.SUCCESS(c, cms)
+ return
+ }
+
+ // initialize key name
+ keyIds := make([]int, 0)
+ for _, v := range items {
+ keyIds = append(keyIds, v.KeyId)
+ }
+ var keys []model.ApiKey
+ keyMap := make(map[uint]string)
+ h.DB.Where("id IN ?", keyIds).Find(&keys)
+ for _, v := range keys {
+ keyMap[v.Id] = v.Name
+ }
+ for _, item := range items {
+ var cm vo.ChatModel
+ err := utils.CopyObject(item, &cm)
+ if err == nil {
+ cm.Id = item.Id
+ cm.CreatedAt = item.CreatedAt.Unix()
+ cm.UpdatedAt = item.UpdatedAt.Unix()
+ cm.KeyName = keyMap[uint(item.KeyId)]
+ cms = append(cms, cm)
+ } else {
+ logger.Error(err)
}
}
resp.SUCCESS(c, cms)
diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go
index a040aae6..11b3b69a 100644
--- a/api/handler/chatimpl/azure_handler.go
+++ b/api/handler/chatimpl/azure_handler.go
@@ -30,7 +30,7 @@ func (h *ChatHandler) sendAzureMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
- response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go
index e39ae455..08809dfe 100644
--- a/api/handler/chatimpl/baidu_handler.go
+++ b/api/handler/chatimpl/baidu_handler.go
@@ -47,7 +47,7 @@ func (h *ChatHandler) sendBaiduMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
- response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go
index 5da9af69..08785752 100644
--- a/api/handler/chatimpl/chat_handler.go
+++ b/api/handler/chatimpl/chat_handler.go
@@ -122,6 +122,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
+ KeyId: chatModel.KeyId,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
@@ -463,13 +464,21 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
-func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
- res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
+func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
+ // if the chat model bind a KEY, use it directly
+ var res *gorm.DB
+ if session.Model.KeyId > 0 {
+ res = h.DB.Where("id", session.Model.KeyId).Find(apiKey)
+ }
+ // use the last unused key
+ if res.Error != nil {
+ res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
+ }
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
var apiURL string
- switch platform {
+ switch session.Model.Platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
@@ -492,7 +501,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
- if platform == types.Baidu {
+ if session.Model.Platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
@@ -527,8 +536,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
- logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
- switch platform {
+ logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, proxyURL, req.Model)
+ switch session.Model.Platform {
case types.Azure:
request.Header.Set("api-key", apiKey.Value)
break
diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go
index 678f481d..5f391b3f 100644
--- a/api/handler/chatimpl/chatglm_handler.go
+++ b/api/handler/chatimpl/chatglm_handler.go
@@ -31,7 +31,7 @@ func (h *ChatHandler) sendChatGLMMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
- response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go
index c4a29338..c991f670 100644
--- a/api/handler/chatimpl/openai_handler.go
+++ b/api/handler/chatimpl/openai_handler.go
@@ -31,7 +31,7 @@ func (h *ChatHandler) sendOpenAiMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
- response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go
index 13b0156d..4484e57b 100644
--- a/api/handler/chatimpl/qwen_handler.go
+++ b/api/handler/chatimpl/qwen_handler.go
@@ -45,7 +45,7 @@ func (h *ChatHandler) sendQWenMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
- response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go
index adb646dc..36a5b785 100644
--- a/api/handler/chatimpl/xunfei_handler.go
+++ b/api/handler/chatimpl/xunfei_handler.go
@@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
+ "gorm.io/gorm"
"html/template"
"io"
"net/http"
@@ -69,7 +70,15 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
- res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
+ var res *gorm.DB
+ // use the bind key
+ if session.Model.KeyId > 0 {
+ res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
+ }
+ // use the last unused key
+ if res.Error != nil {
+ res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
+ }
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
return nil
diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go
index fa8762c9..e0e0f020 100644
--- a/api/handler/mj_handler.go
+++ b/api/handler/mj_handler.go
@@ -125,7 +125,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
params += fmt.Sprintf(" --c %d", data.Chaos)
}
if len(data.ImgArr) > 0 && data.Iw > 0 {
- params += fmt.Sprintf(" --iw %f", data.Iw)
+ params += fmt.Sprintf(" --iw %.2f", data.Iw)
}
if data.Raw {
params += " --style raw"
diff --git a/api/store/model/chat_model.go b/api/store/model/chat_model.go
index 8ddff961..134655f3 100644
--- a/api/store/model/chat_model.go
+++ b/api/store/model/chat_model.go
@@ -12,4 +12,5 @@ type ChatModel struct {
MaxTokens int // 最大响应长度
MaxContext int // 最大上下文长度
Temperature float32 // 模型温度
+ KeyId int // 绑定 API KEY ID
}
diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go
index 81fc18ca..4fb21051 100644
--- a/api/store/vo/chat_model.go
+++ b/api/store/vo/chat_model.go
@@ -12,4 +12,6 @@ type ChatModel struct {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
+ KeyId int `json:"key_id"`
+ KeyName string `json:"key_name"`
}
diff --git a/database/update-v4.0.3.sql b/database/update-v4.0.3.sql
index fb22e6dd..219c4187 100644
--- a/database/update-v4.0.3.sql
+++ b/database/update-v4.0.3.sql
@@ -1 +1,2 @@
-ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`;
\ No newline at end of file
+ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`;
+ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`;
\ No newline at end of file
diff --git a/web/.env.development b/web/.env.development
index 330da87b..8474f044 100644
--- a/web/.env.development
+++ b/web/.env.development
@@ -6,4 +6,4 @@ VUE_APP_ADMIN_USER=admin
VUE_APP_ADMIN_PASS=admin123
VUE_APP_KEY_PREFIX=ChatPLUS_DEV_
VUE_APP_TITLE="Geek-AI 创作系统"
-VUE_APP_VERSION=v4.0.2
+VUE_APP_VERSION=v4.0.3
diff --git a/web/.env.production b/web/.env.production
index c6581695..e1a98fa3 100644
--- a/web/.env.production
+++ b/web/.env.production
@@ -2,4 +2,4 @@ VUE_APP_API_HOST=
VUE_APP_WS_HOST=
VUE_APP_KEY_PREFIX=ChatPLUS_
VUE_APP_TITLE="Geek-AI 创作系统"
-VUE_APP_VERSION=v4.0.2
+VUE_APP_VERSION=v4.0.3
diff --git a/web/src/components/admin/AdminSidebar.vue b/web/src/components/admin/AdminSidebar.vue
index 0393c984..6b52f544 100644
--- a/web/src/components/admin/AdminSidebar.vue
+++ b/web/src/components/admin/AdminSidebar.vue
@@ -63,7 +63,8 @@ const logo = ref('/images/logo.png')
// 加载系统配置
httpGet('/api/admin/config/get?key=system').then(res => {
- title.value = res.data['admin_title'];
+ title.value = res.data['admin_title']
+ logo.value = res.data['logo']
}).catch(e => {
ElMessage.error("加载系统配置失败: " + e.message)
})
@@ -191,9 +192,9 @@ setMenuItems(items)
padding 6px 15px;
.el-image {
- width 30px;
- height 30px;
- padding-top 8px;
+ width 36px;
+ height 36px;
+ padding-top 5px;
border-radius 100%
.el-image__inner {
diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue
index 1da5b107..f4d85f56 100644
--- a/web/src/views/ChatPlus.vue
+++ b/web/src/views/ChatPlus.vue
@@ -377,16 +377,7 @@ const initData = () => {
httpGet(`/api/role/list`).then((res) => {
roles.value = res.data;
roleId.value = roles.value[0]['id'];
-
- const chatId = localStorage.getItem("chat_id")
- const chat = getChatById(chatId)
- if (chat === null) {
- // 创建新的对话
- newChat();
- } else {
- // 加载对话
- loadChat(chat)
- }
+ newChat();
}).catch((e) => {
ElMessage.error('获取聊天角色失败: ' + e.messages)
})
diff --git a/web/src/views/Home.vue b/web/src/views/Home.vue
index e842ae90..e320098a 100644
--- a/web/src/views/Home.vue
+++ b/web/src/views/Home.vue
@@ -2,7 +2,7 @@
@@ -75,6 +75,7 @@ onMounted(() => {
display flex
flex-flow column
align-items center
+ cursor pointer
.el-image {
width 50px
diff --git a/web/src/views/Index.vue b/web/src/views/Index.vue
index 3fc9bf6b..014cedf4 100644
--- a/web/src/views/Index.vue
+++ b/web/src/views/Index.vue
@@ -1,7 +1,7 @@
-
{{title}}
+
欢迎使用 {{ title }}
{{slogan}}
@@ -20,15 +20,22 @@ import * as THREE from 'three';
import {onMounted, ref} from "vue";
import {useRouter} from "vue-router";
import FooterBar from "@/components/FooterBar.vue";
+import {httpGet} from "@/utils/http";
+import {ElMessage} from "element-plus";
const router = useRouter()
-const title = ref("欢迎使用 Geek-AI 创作系统")
+const title = ref("Geek-AI 创作系统")
const slogan = ref("我辈之人,先干为敬,陪您先把 AI 用起来")
const size = window.innerHeight * 0.8
const winHeight = window.innerHeight - 150
onMounted(() => {
+ httpGet("/api/config/get?key=system").then(res => {
+ title.value = res.data['title']
+ }).catch(e => {
+ ElMessage.error("获取系统配置失败:" + e.message)
+ })
init()
})
@@ -77,7 +84,7 @@ const init = () => {
requestAnimationFrame(animate);
// 使地球自转和公转
- earth.rotation.y += 0.002;
+ earth.rotation.y += 0.001;
renderer.render(scene, camera);
};
diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue
index 5a4c614f..421fc4bd 100644
--- a/web/src/views/admin/ChatModel.vue
+++ b/web/src/views/admin/ChatModel.vue
@@ -1,5 +1,5 @@
-
+
新增
@@ -13,7 +13,14 @@
-
+
+
+ {{ scope.row.value }}
+
+
+
+
+
@@ -29,12 +36,12 @@
-
-
- {{ dateFormat(scope.row['created_at']) }}
-
-
-
+
+
+
+
+
+
编辑
@@ -75,7 +82,7 @@