upgrade to v4.0.4

This commit is contained in:
RockYang
2024-05-07 16:32:05 +08:00
parent 3ab5459778
commit be1580e949
96 changed files with 3315 additions and 625 deletions

View File

@@ -30,7 +30,7 @@ func (h *ChatHandler) sendAzureMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@@ -47,7 +47,7 @@ func (h *ChatHandler) sendBaiduMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@@ -6,6 +6,7 @@ import (
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
@@ -35,15 +36,17 @@ var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
}
}
@@ -68,9 +71,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
modelId := h.GetInt(c, "model_id", 0)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
modelId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
res := h.DB.First(&chatModel, modelId)
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -111,15 +125,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
h.Init()
@@ -235,7 +243,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
break
}
var tools = make([]interface{}, 0)
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
@@ -244,15 +252,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
required := parameters["required"]
delete(parameters, "required")
tools = append(tools, gin.H{
"type": "function",
"function": gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
})
}
// Fixed: compatible for gpt4-turbo-xxx model
if !strings.HasPrefix(req.Model, "gpt-4-turbo-") {
tool.Function.Required = required
}
tools = append(tools, tool)
}
if len(tools) > 0 {
@@ -332,6 +345,34 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Content: prompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
text := prompt
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
data = append(data, gin.H{
"type": "text",
"text": text,
})
content = data
} else {
content = prompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
@@ -339,6 +380,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
})
}
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
@@ -426,13 +469,29 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Debug().Where("id", session.Model.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Debug().Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
}
var apiURL string
switch platform {
switch session.Model.Platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
@@ -455,7 +514,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
if session.Model.Platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
@@ -479,7 +538,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
var proxyURL string
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
@@ -490,8 +548,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform {
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure:
request.Header.Set("api-key", apiKey.Value)
break

View File

@@ -31,7 +31,7 @@ func (h *ChatHandler) sendChatGLMMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@@ -31,7 +31,7 @@ func (h *ChatHandler) sendOpenAiMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
@@ -74,6 +74,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrImg)
break
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
@@ -98,8 +102,10 @@ func (h *ChatHandler) sendOpenAiMessage(
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg)
}
continue
}

View File

@@ -45,7 +45,7 @@ func (h *ChatHandler) sendQWenMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {

View File

@@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"html/template"
"io"
"net/http"
@@ -69,7 +70,15 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
}
// use the last unused key
if res.Error != nil {
res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil