mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-14 05:03:45 +08:00
The 'stop generate' and 'regenerate response' function is ready
This commit is contained in:
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -59,12 +60,13 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
delete(s.ChatClients, sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
//replyMessage(client, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false)
|
||||
//replyMessage(client, "", true)
|
||||
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||
err = s.sendMessage(session, chatRole, string(message), client, false)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.ReqCancelFunc[sessionId] = cancel
|
||||
// 回复消息
|
||||
err = s.sendMessage(ctx, session, chatRole, string(message), client, false)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -73,7 +75,13 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error {
|
||||
func (s *Server) sendMessage(ctx context.Context, session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error {
|
||||
cancel := s.ReqCancelFunc[session.SessionId]
|
||||
defer func() {
|
||||
cancel()
|
||||
delete(s.ReqCancelFunc, session.SessionId)
|
||||
}()
|
||||
|
||||
user, err := GetUser(session.Username)
|
||||
if err != nil {
|
||||
replyMessage(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false)
|
||||
@@ -98,38 +106,38 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
|
||||
replyMessage(ws, "", true)
|
||||
return nil
|
||||
}
|
||||
var r = types.ApiRequest{
|
||||
var req = types.ApiRequest{
|
||||
Model: s.Config.Chat.Model,
|
||||
Temperature: s.Config.Chat.Temperature,
|
||||
MaxTokens: s.Config.Chat.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
var context []types.Message
|
||||
var chatCtx []types.Message
|
||||
var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key)
|
||||
if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext {
|
||||
context = v.Messages
|
||||
chatCtx = v.Messages
|
||||
} else {
|
||||
context = role.Context
|
||||
chatCtx = role.Context
|
||||
}
|
||||
|
||||
if s.DebugMode {
|
||||
logger.Infof("会话上下文:%+v", context)
|
||||
logger.Infof("会话上下文:%+v", chatCtx)
|
||||
}
|
||||
|
||||
req.Messages = append(chatCtx, types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
var retryCount = 5
|
||||
var retryCount = 5 // 重试次数
|
||||
var response *http.Response
|
||||
var apiKey string
|
||||
var failedKey = ""
|
||||
var failedProxyURL = ""
|
||||
for retryCount > 0 {
|
||||
r.Messages = append(context, types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
requestBody, err := json.Marshal(r)
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -139,6 +147,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
|
||||
return err
|
||||
}
|
||||
|
||||
request = request.WithContext(ctx)
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
|
||||
proxyURL := s.getProxyURL(failedProxyURL)
|
||||
@@ -164,7 +173,10 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
|
||||
response, err = client.Do(request)
|
||||
if err == nil {
|
||||
break
|
||||
} else if strings.Contains(err.Error(), "context canceled") {
|
||||
return errors.New("用户取消了请求:" + prompt)
|
||||
} else {
|
||||
logger.Error("HTTP API 请求失败:" + err.Error())
|
||||
failedKey = apiKey
|
||||
failedProxyURL = proxyURL
|
||||
}
|
||||
@@ -205,14 +217,14 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
|
||||
_ = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
|
||||
// 重发当前消息
|
||||
return s.sendMessage(session, role, prompt, ws, false)
|
||||
return s.sendMessage(ctx, session, role, prompt, ws, false)
|
||||
|
||||
// 上下文超出长度了
|
||||
} else if strings.Contains(line, "This model's maximum context length is 4097 tokens") {
|
||||
logger.Infof("会话上下文长度超出限制, Username: %s", user.Name)
|
||||
// 重置上下文,重发当前消息
|
||||
delete(s.ChatContexts, ctxKey)
|
||||
return s.sendMessage(session, role, prompt, ws, true)
|
||||
return s.sendMessage(ctx, session, role, prompt, ws, true)
|
||||
} else if !strings.Contains(line, "data:") {
|
||||
continue
|
||||
}
|
||||
@@ -246,7 +258,18 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
|
||||
IsHelloMsg: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 监控取消信号
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 结束输出
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
|
||||
_ = response.Body.Close()
|
||||
return errors.New("用户取消了请求:" + prompt)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
} // end for
|
||||
_ = response.Body.Close() // 关闭资源
|
||||
|
||||
// 消息发送成功
|
||||
@@ -260,27 +283,26 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
// 追加上下文消息
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
context = append(context, useMsg)
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息
|
||||
if s.Config.Chat.EnableContext {
|
||||
context = append(context, message)
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
s.ChatContexts[ctxKey] = types.ChatContext{
|
||||
Messages: context,
|
||||
Messages: chatCtx,
|
||||
LastAccessTime: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// 追加历史消息
|
||||
if user.EnableHistory {
|
||||
err = AppendChatHistory(user.Name, role.Key, useMsg)
|
||||
err = AppendChatHistory(user.Name, role.Key, useMsg) // 提问消息
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AppendChatHistory(user.Name, role.Key, message)
|
||||
err = AppendChatHistory(user.Name, role.Key, message) // 回复消息
|
||||
}
|
||||
}
|
||||
|
||||
@@ -431,3 +453,12 @@ func (s *Server) ClearHistoryHandle(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
|
||||
}
|
||||
|
||||
// StopGenerateHandle 停止生成
|
||||
func (s *Server) StopGenerateHandle(c *gin.Context) {
|
||||
sessionId := c.GetHeader(types.TokenName)
|
||||
cancel := s.ReqCancelFunc[sessionId]
|
||||
cancel()
|
||||
delete(s.ReqCancelFunc, sessionId)
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"github.com/gin-contrib/sessions"
|
||||
@@ -39,10 +40,11 @@ type Server struct {
|
||||
|
||||
// 保存 Websocket 会话 Username, 每个 Username 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
ChatSession map[string]types.ChatSession //map[sessionId]User
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
ChatClients map[string]*WsClient // Websocket 连接集合
|
||||
DebugMode bool // 是否开启调试模式
|
||||
ChatSession map[string]types.ChatSession //map[sessionId]User
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
ChatClients map[string]*WsClient // Websocket 连接集合
|
||||
ReqCancelFunc map[string]context.CancelFunc // HttpClient 请求取消 handle function
|
||||
DebugMode bool // 是否开启调试模式
|
||||
}
|
||||
|
||||
func NewServer(configPath string) (*Server, error) {
|
||||
@@ -67,6 +69,7 @@ func NewServer(configPath string) (*Server, error) {
|
||||
ChatContexts: make(map[string]types.ChatContext, 16),
|
||||
ChatSession: make(map[string]types.ChatSession),
|
||||
ChatClients: make(map[string]*WsClient),
|
||||
ReqCancelFunc: make(map[string]context.CancelFunc),
|
||||
ApiKeyAccessStat: make(map[string]int64),
|
||||
}, nil
|
||||
}
|
||||
@@ -83,17 +86,18 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
|
||||
engine.Use(AuthorizeMiddleware(s))
|
||||
engine.Use(Recover)
|
||||
|
||||
engine.POST("/test", s.TestHandle)
|
||||
engine.GET("/api/session/get", s.GetSessionHandle)
|
||||
engine.POST("/api/login", s.LoginHandle)
|
||||
engine.POST("/api/logout", s.LogoutHandle)
|
||||
engine.Any("/api/chat", s.ChatHandle)
|
||||
engine.POST("test", s.TestHandle)
|
||||
engine.GET("api/session/get", s.GetSessionHandle)
|
||||
engine.POST("api/login", s.LoginHandle)
|
||||
engine.POST("api/logout", s.LogoutHandle)
|
||||
engine.Any("api/chat", s.ChatHandle)
|
||||
engine.POST("api/chat/stop", s.StopGenerateHandle)
|
||||
engine.POST("api/chat/history", s.GetChatHistoryHandle)
|
||||
engine.POST("api/chat/history/clear", s.ClearHistoryHandle)
|
||||
|
||||
engine.POST("/api/config/set", s.ConfigSetHandle)
|
||||
engine.GET("/api/config/chat-roles/get", s.GetChatRoleListHandle)
|
||||
engine.GET("/api/config/chat-roles/add", s.AddChatRoleHandle)
|
||||
engine.POST("api/config/set", s.ConfigSetHandle)
|
||||
engine.GET("api/config/chat-roles/get", s.GetChatRoleListHandle)
|
||||
engine.GET("api/config/chat-roles/add", s.AddChatRoleHandle)
|
||||
engine.POST("api/config/user/add", s.AddUserHandle)
|
||||
engine.POST("api/config/user/batch-add", s.BatchAddUserHandle)
|
||||
engine.POST("api/config/user/set", s.SetUserHandle)
|
||||
|
||||
Reference in New Issue
Block a user