重构主体工作完成

This commit is contained in:
RockYang
2024-03-15 18:35:10 +08:00
61 changed files with 1043 additions and 1097 deletions

View File

@@ -30,8 +30,7 @@ type AppServer struct {
Engine *gin.Engine
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache
SysConfig *types.SystemConfig // system config cache
SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
@@ -69,23 +68,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
}
func (s *AppServer) Run(db *gorm.DB) error {
// load chat config from database
var chatConfig model.Config
res := db.Where("marker", "chat").First(&chatConfig)
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
if err != nil {
return err
}
// load system configs
var sysConfig model.Config
res = db.Where("marker", "system").First(&sysConfig)
res := db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return res.Error
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return err
}

View File

@@ -54,10 +54,13 @@ type ChatSession struct {
}
type ChatModel struct {
Id uint `json:"id"`
Platform Platform `json:"platform"`
Value string `json:"value"`
Power int `json:"power"`
Id uint `json:"id"`
Platform Platform `json:"platform"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
}
type ApiError struct {
@@ -72,27 +75,6 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message
var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"chatglm_pro": 32768, // 清华智普
"chatglm_std": 16384,
"chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞
"general2": 8192,
"general3": 8192,
}
func GetModelMaxToken(model string) int {
if token, ok := ModelToTokens[model]; ok {
return token
}
return 4096
}
// PowerType 算力日志类型
type PowerType int

View File

@@ -121,20 +121,6 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
// ChatConfig 系统默认的聊天配置
type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
XunFei ModelAPIConfig `json:"xun_fei"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
ContextDeep int `json:"context_deep"` // 上下文深度
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
}
type Platform string
const OpenAI = Platform("OpenAI")
@@ -144,16 +130,6 @@ const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
const QWen = Platform("QWen")
// UserChatConfig 用户的聊天配置
type UserChatConfig struct {
ApiKeys map[Platform]string `json:"api_keys"`
}
type ModelAPIConfig struct {
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
}
type SystemConfig struct {
Title string `json:"title"`
AdminTitle string `json:"admin_title"`
@@ -178,4 +154,7 @@ type SystemConfig struct {
DallPower int `json:"dall_power"` // DALLE3 绘图消耗算力
WechatCardURL string `json:"wechat_card_url"` // 微信客服地址
EnableContext bool `json:"enable_context"`
ContextDeep int `json:"context_deep"`
}

View File

@@ -32,7 +32,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
Value string `json:"value"`
ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"`
UseProxy bool `json:"use_proxy"`
ProxyURL string `json:"proxy_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -48,7 +48,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled
apiKey.UseProxy = data.UseProxy
apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name
res := h.db.Save(&apiKey)
if res.Error != nil {

View File

@@ -24,14 +24,15 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
}
type chatItemVo struct {
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Role vo.ChatRole `json:"role"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
}
func (h *ChatHandler) List(c *gin.Context) {
@@ -78,18 +79,23 @@ func (h *ChatHandler) List(c *gin.Context) {
if res.Error == nil {
userIds := make([]uint, 0)
chatIds := make([]string, 0)
roleIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
chatIds = append(chatIds, item.ChatId)
roleIds = append(roleIds, item.RoleId)
}
var messages []model.ChatMessage
var users []model.User
var roles []model.ChatRole
h.db.Where("chat_id IN ?", chatIds).Find(&messages)
h.db.Where("id IN ?", userIds).Find(&users)
h.db.Where("id IN ?", roleIds).Find(&roles)
tokenMap := make(map[string]int)
userMap := make(map[uint]string)
msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole)
for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1
@@ -97,6 +103,14 @@ func (h *ChatHandler) List(c *gin.Context) {
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, r := range roles {
var roleVo vo.ChatRole
err := utils.CopyObject(r, &roleVo)
if err != nil {
continue
}
roleMap[r.Id] = roleVo
}
for _, item := range items {
list = append(list, chatItemVo{
UserId: item.UserId,
@@ -106,6 +120,7 @@ func (h *ChatHandler) List(c *gin.Context) {
Model: item.Model,
Token: tokenMap[item.ChatId],
MsgNum: msgMap[item.ChatId],
Role: roleMap[item.RoleId],
CreatedAt: item.CreatedAt.Unix(),
})
}

View File

@@ -26,15 +26,18 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Weight int `json:"weight"`
CreatedAt int64 `json:"created_at"`
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature string `json:"temperature"` // 模型温度
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -42,13 +45,16 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
}
item := model.ChatModel{
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
Power: data.Weight}
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: float32(utils.Str2Float(data.Temperature)),
Power: data.Power}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
@@ -145,19 +151,16 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
}
func (h *ChatModelHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.db.Where("id = ?", data.Id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -56,8 +56,6 @@ func (h *ConfigHandler) Update(c *gin.Context) {
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
} else if data.Key == "chat" {
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
}
if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error())

View File

@@ -107,13 +107,6 @@ func (h *UserHandler) Save(c *gin.Context) {
ChatRoles: utils.JsonEncode(data.ChatRoles),
ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
}
res = h.db.Create(&u)
_ = utils.CopyObject(u, &userVo)

View File

@@ -111,66 +111,64 @@ func (h *ChatHandler) sendAzureMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId

View File

@@ -135,65 +135,63 @@ func (h *ChatHandler) sendBaiduMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId

View File

@@ -57,8 +57,6 @@ func (h *ChatHandler) Init() {
}
}
var chatConfig types.ChatConfig
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
@@ -109,10 +107,13 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Value: chatModel.Value,
Power: chatModel.Power,
Platform: types.Platform(chatModel.Platform)}
Id: chatModel.Id,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
@@ -122,15 +123,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
// 初始化聊天配置
var config model.Config
h.db.Where("marker", "chat").First(&config)
err = utils.JsonDecode(config.Config, &chatConfig)
if err != nil {
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
c.Abort()
return
}
h.Init()
// 保存会话连接
@@ -213,7 +205,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Power <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
if userVo.Power <= 0 {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
utils.ReplyMessage(ws, ErrImg)
return nil
@@ -227,7 +219,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > types.GetModelMaxToken(session.Model.Value) {
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
@@ -237,21 +229,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Stream: true,
}
switch session.Model.Platform {
case types.Azure:
req.Temperature = h.App.ChatConfig.Azure.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
break
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
case types.Baidu:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
// TODO 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.db.Where("enabled", true).Find(&items)
@@ -283,15 +267,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.Tools = tools
req.ToolChoice = "auto"
}
case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
break
case types.QWen:
req.Input = map[string]interface{}{"messages": []map[string]string{{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}}}
req.Parameters = map[string]interface{}{}
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
@@ -301,14 +283,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
// 加载聊天上下文
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
_ = utils.JsonDecode(role.Context, &messages)
if chatConfig.ContextDeep > 0 {
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
@@ -331,12 +313,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= types.GetModelMaxToken(req.Model) {
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
@@ -351,10 +333,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
if session.Model.Platform == types.QWen {
req.Input = map[string]interface{}{"prompt": prompt}
if len(reqMgs) > 0 {
req.Input["messages"] = reqMgs
}
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
switch session.Model.Platform {
case types.Azure:
@@ -497,9 +486,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
var proxyURL string
if h.App.Config.ProxyURL != "" && apiKey.UseProxy { // 使用代理
proxyURL = h.App.Config.ProxyURL
proxy, _ := url.Parse(proxyURL)
if apiKey.ProxyURL != "" { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
@@ -542,7 +530,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
res := h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
h.db.Debug().Create(&model.PowerLog{
h.db.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,

View File

@@ -126,7 +126,7 @@ func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Debug().Where("chat_id = ?", chatId).Find(&items)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return

View File

@@ -114,66 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId

View File

@@ -180,79 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && toolCall == false {
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId

View File

@@ -20,13 +20,16 @@ type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
@@ -70,6 +73,7 @@ func (h *ChatHandler) sendQWenMessage(
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
@@ -77,24 +81,32 @@ func (h *ChatHandler) sendQWenMessage(
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
var resp qWenResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 { // 发送消息头
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
@@ -135,66 +147,64 @@ func (h *ChatHandler) sendQWenMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId

View File

@@ -50,9 +50,10 @@ type xunFeiResp struct {
}
var Model2URL = map[string]string{
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
@@ -86,6 +87,7 @@ func (h *ChatHandler) sendXunFeiMessage(
}
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
@@ -173,66 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -260,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": float64(req.Temperature),
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",

View File

@@ -22,7 +22,6 @@ type FunctionHandler struct {
db *gorm.DB
config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager
proxyURL string
}
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
@@ -33,7 +32,6 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo
db: db,
config: config.ApiConfig,
uploadManager: manager,
proxyURL: config.ProxyURL,
}
}
@@ -213,47 +211,28 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return
}
// get image generation api URL
var conf model.Config
var chatConfig types.ChatConfig
tx = h.db.Where("marker", "chat").First(&conf)
if tx.Error != nil {
resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
return
}
err := utils.JsonDecode(conf.Config, &chatConfig)
if err != nil {
resp.ERROR(c, "error with decode chat config: "+err.Error())
return
}
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]), h.App.Config.ProxyURL)
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt
}
imgNum := chatConfig.DallImgNum
if imgNum <= 0 {
imgNum = 1
}
var res imgRes
var errRes ErrRes
var request *req.Request
if apiKey.UseProxy && h.proxyURL != "" {
request = req.C().SetProxyURL(h.proxyURL).R()
if apiKey.ProxyURL != "" {
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else {
request = req.C().R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, h.proxyURL)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: imgNum,
N: 1,
Size: "1024x1024",
}).
SetErrorResult(&errRes).

View File

@@ -35,7 +35,7 @@ func (h *PromptHandler) Rewrite(c *gin.Context) {
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -53,7 +53,7 @@ func (h *PromptHandler) Translate(c *gin.Context) {
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return

View File

@@ -95,14 +95,7 @@ func (h *UserHandler) Register(c *gin.Context) {
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Power: h.App.SysConfig.InitPower,
Power: h.App.SysConfig.InitPower,
}
res = h.db.Create(&user)
@@ -245,14 +238,13 @@ func (h *UserHandler) Session(c *gin.Context) {
}
type userProfile struct {
Id uint `json:"id"`
Nickname string `json:"nickname"`
Username string `json:"username"`
Avatar string `json:"avatar"`
ChatConfig types.UserChatConfig `json:"chat_config"`
Power int `json:"power"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
Id uint `json:"id"`
Nickname string `json:"nickname"`
Username string `json:"username"`
Avatar string `json:"avatar"`
Power int `json:"power"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
}
func (h *UserHandler) Profile(c *gin.Context) {

View File

@@ -315,7 +315,7 @@ func main() {
group.GET("list", h.List)
group.POST("set", h.Set)
group.POST("sort", h.Sort)
group.POST("remove", h.Remove)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/")

View File

@@ -9,6 +9,6 @@ type ApiKey struct {
Value string // API Key 的值
ApiURL string // 当前 KEY 的 API 地址
Enabled bool // 是否启用
UseProxy bool // 是否使用代理访问 API URL
ProxyURL string // 代理地址
LastUsedAt int64 // 最后使用时间
}

View File

@@ -2,11 +2,14 @@ package model
type ChatModel struct {
BaseModel
Platform string
Name string
Value string // API Key 的值
SortNum int
Enabled bool
Power int // 每次对话消耗算力
Open bool // 是否开放模型给所有人使用
Platform string
Name string
Value string // API Key 的值
SortNum int
Enabled bool
Power int // 每次对话消耗算力
Open bool // 是否开放模型给所有人使用
MaxTokens int // 最大响应长度
MaxContext int // 最大上下文长度
Temperature float32 // 模型温度
}

View File

@@ -9,6 +9,6 @@ type ApiKey struct {
Value string `json:"value"` // API Key 的值
ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"`
UseProxy bool `json:"use_proxy"`
ProxyURL string `json:"proxy_url"`
LastUsedAt int64 `json:"last_used_at"` // 最后使用时间
}

View File

@@ -2,11 +2,14 @@ package vo
type ChatModel struct {
BaseVo
Platform string `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Weight int `json:"weight"`
Open bool `json:"open"`
Platform string `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Power int `json:"power"`
Open bool `json:"open"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
}

View File

@@ -5,6 +5,5 @@ import "chatplus/core/types"
type Config struct {
Id uint `json:"id"`
Key string `json:"key"`
ChatConfig types.ChatConfig `json:"chat_config"`
SystemConfig types.SystemConfig `json:"system_config"`
}

View File

@@ -1,20 +1,17 @@
package vo
import "chatplus/core/types"
type User struct {
BaseVo
Username string `json:"username"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Salt string `json:"salt"` // 密码盐
Power int `json:"power"` // 剩余算力
ChatConfig types.UserChatConfig `json:"chat_config"` // 聊天配置
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
ChatModels []string `json:"chat_models"` // AI模型集合
ExpiredTime int64 `json:"expired_time"` // 账户到期时间
Status bool `json:"status"` // 当前状态
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
Vip bool `json:"vip"`
Username string `json:"username"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Salt string `json:"salt"` // 密码盐
Power int `json:"power"` // 剩余算力
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
ChatModels []string `json:"chat_models"` // AI模型集合
ExpiredTime int64 `json:"expired_time"` // 账户到期时间
Status bool `json:"status"` // 当前状态
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
Vip bool `json:"vip"`
}

View File

@@ -88,7 +88,7 @@ type apiErrRes struct {
} `json:"error"`
}
func OpenAIRequest(db *gorm.DB, prompt string, proxy string) (string, error) {
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
var apiKey model.ApiKey
res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
if res.Error != nil {
@@ -104,8 +104,8 @@ func OpenAIRequest(db *gorm.DB, prompt string, proxy string) (string, error) {
var response apiRes
var errRes apiErrRes
client := req.C()
if apiKey.UseProxy && proxy != "" {
client.SetProxyURL(proxy)
if apiKey.ProxyURL != "" {
client.SetProxyURL(apiKey.ApiURL)
}
r, err := client.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"strconv"
"time"
"golang.org/x/crypto/sha3"
@@ -59,7 +60,7 @@ func Str2stamp(str string) int64 {
if len(str) == 0 {
return 0
}
layout := "2006-01-02 15:04:05"
t, err := time.Parse(layout, str)
if err != nil {
@@ -92,3 +93,11 @@ func InterfaceToString(value interface{}) string {
}
return JsonEncode(value)
}
func Str2Float(str string) float64 {
num, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0
}
return num
}