From 65ad5fb632408f9d0b4ea39e80708adb1c024fb8 Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 4 Apr 2023 09:05:17 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=B3=BB=E7=BB=9F=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=20API=20=E5=8F=82=E6=95=B0=E8=A7=A3=E6=9E=90=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=82=B9=E5=8D=A1=E6=89=A3=E8=B4=B9=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E6=B2=A1=E6=9C=89=E5=9B=9E=E5=A4=8D=E7=AD=94?= =?UTF-8?q?=E6=A1=88=E4=B8=8D=E8=AE=B0=E6=89=A3=E8=B4=B9=E6=AC=A1=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat_handler.go | 56 ++++++------ server/config_handler.go | 191 +++++++++++++++++++-------------------- server/server.go | 6 +- types/config.go | 2 + 4 files changed, 127 insertions(+), 128 deletions(-) diff --git a/server/chat_handler.go b/server/chat_handler.go index d7647836..e8adf4ae 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -224,37 +224,41 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro } _ = response.Body.Close() // 关闭资源 - // 当前 Username 调用次数减 1 - if user.MaxCalls > 0 { - user.RemainingCalls -= 1 - _ = PutUser(*user) - } + // 消息发送成功 + if len(contents) > 0 { + // 当前 Username 调用次数减 1 + if user.MaxCalls > 0 { + user.RemainingCalls -= 1 + _ = PutUser(*user) + } - if message.Role == "" { - message.Role = "assistant" - } - // 追加上下文消息 - useMsg := types.Message{Role: "user", Content: prompt} - context = append(context, useMsg) - message.Content = strings.Join(contents, "") + if message.Role == "" { + message.Role = "assistant" + } + // 追加上下文消息 + useMsg := types.Message{Role: "user", Content: prompt} + context = append(context, useMsg) + message.Content = strings.Join(contents, "") - // 更新上下文消息 - if s.Config.Chat.EnableContext { - context = append(context, message) - s.ChatContexts[ctxKey] = types.ChatContext{ - Messages: context, - LastAccessTime: time.Now().Unix(), + // 更新上下文消息 + if s.Config.Chat.EnableContext { + context = append(context, message) + s.ChatContexts[ctxKey] = types.ChatContext{ + Messages: context, + LastAccessTime: time.Now().Unix(), + } + } + + // 追加历史消息 + if user.EnableHistory { + err = AppendChatHistory(user.Name, role.Key, useMsg) + if err != nil { + return err + } + err = AppendChatHistory(user.Name, role.Key, message) } } - // 追加历史消息 - if user.EnableHistory { - err = AppendChatHistory(user.Name, role.Key, useMsg) - if err != nil { - return err - } - err = AppendChatHistory(user.Name, role.Key, message) - } return err } diff --git a/server/config_handler.go b/server/config_handler.go index bc65ac6a..b55fff89 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -6,12 +6,22 @@ import ( "net/http" "openai/types" "openai/utils" - "strconv" ) +func (s *Server) TestHandle(c *gin.Context) { + var data map[string]interface{} + err := json.NewDecoder(c.Request.Body).Decode(&data) + if err != nil { + logger.Errorf("Error decode json data: %s", err.Error()) + c.JSON(http.StatusBadRequest, nil) + return + } + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: data}) +} + // ConfigSetHandle set configs func (s *Server) ConfigSetHandle(c *gin.Context) { - var data map[string]string + var data map[string]interface{} err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -19,77 +29,30 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { return } - // Model if model, ok := data["model"]; ok { - s.Config.Chat.Model = model + s.Config.Chat.Model = model.(string) } - if accessKey, ok := data["access_key"]; ok { - s.Config.AccessKey = accessKey + s.Config.AccessKey = accessKey.(string) } - // Temperature if temperature, ok := data["temperature"]; ok { - v, err := strconv.ParseFloat(temperature, 32) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "temperature must be a float parameter", - }) - return - } - s.Config.Chat.Temperature = float32(v) + s.Config.Chat.Temperature = temperature.(float32) } - // max_users if maxTokens, ok := data["max_tokens"]; ok { - v, err := strconv.Atoi(maxTokens) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "max_users must be a int parameter", - }) - return - } - s.Config.Chat.MaxTokens = v + s.Config.Chat.MaxTokens = maxTokens.(int) } - // enable Context if enableContext, ok := data["enable_context"]; ok { - v, err := strconv.ParseBool(enableContext) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "enable_context must be a bool parameter", - }) - return - } - s.Config.Chat.EnableContext = v + s.Config.Chat.EnableContext = enableContext.(bool) } - if expireTime, ok := data["chat_context_expire_time"]; ok { - v, err := strconv.Atoi(expireTime) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "chat_context_expire_time must be a integer parameter", - }) - return - } - s.Config.Chat.ChatContextExpireTime = v + s.Config.Chat.ChatContextExpireTime = expireTime.(int) } - // enable auth if enableAuth, ok := data["enable_auth"]; ok { - v, err := strconv.ParseBool(enableAuth) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "enable_auth must be a bool parameter", - }) - return - } - s.Config.EnableAuth = v + s.Config.EnableAuth = enableAuth.(bool) } // 保存配置文件 @@ -99,7 +62,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config}) } // SetDebugHandle 开启/关闭调试模式 @@ -179,33 +142,41 @@ func (s *Server) BatchAddUserHandle(c *gin.Context) { } func (s *Server) SetUserHandle(c *gin.Context) { - var data types.User + var data map[string]interface{} err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) return } - // 参数处理 - if data.Name == "" || data.MaxCalls < 0 { + var user *types.User + if name, ok := data["name"]; ok { + user, err = GetUser(name.(string)) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "User not found"}) + return + } + } + var maxCalls int + if v, ok := data["max_calls"]; ok { + maxCalls = v.(int) + } + if maxCalls < 0 { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) return - } - - user, err := GetUser(data.Name) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username not found"}) - return - } - - if data.MaxCalls > 0 { - user.RemainingCalls += data.MaxCalls - user.MaxCalls + } else if maxCalls > 0 { + user.RemainingCalls += maxCalls - user.MaxCalls if user.RemainingCalls < 0 { user.RemainingCalls = 0 } } - user.MaxCalls = data.MaxCalls - user.EnableHistory = data.EnableHistory + + if v, ok := data["status"]; ok { + user.Status = v.(bool) + } + if v, ok := data["enable_history"]; ok { + user.EnableHistory = v.(bool) + } err = PutUser(*user) if err != nil { @@ -218,7 +189,9 @@ func (s *Server) SetUserHandle(c *gin.Context) { // RemoveUserHandle 删除 Username func (s *Server) RemoveUserHandle(c *gin.Context) { - var data types.User + var data struct { + Name string `json:"name"` + } err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) @@ -241,15 +214,17 @@ func (s *Server) GetUserListHandle(c *gin.Context) { // AddApiKeyHandle 添加一个 API key func (s *Server) AddApiKeyHandle(c *gin.Context) { - var data map[string]string + var data struct { + ApiKey string `json:"api_key"` + } err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) c.JSON(http.StatusBadRequest, nil) return } - if key, ok := data["api_key"]; ok && len(key) > 20 { - s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key) + if len(data.ApiKey) > 20 { + s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, data.ApiKey) } // 保存配置文件 @@ -264,7 +239,9 @@ func (s *Server) AddApiKeyHandle(c *gin.Context) { // RemoveApiKeyHandle 移除一个 API key func (s *Server) RemoveApiKeyHandle(c *gin.Context) { - var data map[string]string + var data struct { + ApiKey string `json:"api_key"` + } err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -272,11 +249,9 @@ func (s *Server) RemoveApiKeyHandle(c *gin.Context) { return } - if key, ok := data["api_key"]; ok { - for i, v := range s.Config.Chat.ApiKeys { - if v == key { - s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys[:i], s.Config.Chat.ApiKeys[i+1:]...) - } + for i, v := range s.Config.Chat.ApiKeys { + if v == data.ApiKey { + s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys[:i], s.Config.Chat.ApiKeys[i+1:]...) } } @@ -341,7 +316,7 @@ func (s *Server) GetChatRoleHandle(c *gin.Context) { // SetChatRoleHandle 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作 func (s *Server) SetChatRoleHandle(c *gin.Context) { - var data types.ChatRole + var data map[string]interface{} err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -349,18 +324,38 @@ func (s *Server) SetChatRoleHandle(c *gin.Context) { return } - if data.Key == "" { + var key string + if v, ok := data["key"]; !ok { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Must specified the role key"}) return + } else { + key = v.(string) } - _, err = GetChatRole(data.Key) + role, err := GetChatRole(key) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Role key not exists"}) return } - err = PutChatRole(data) + if v, ok := data["name"]; ok { + role.Name = v.(string) + } + if v, ok := data["hello_msg"]; ok { + role.HelloMsg = v.(string) + } + if v, ok := data["icon"]; ok { + role.Icon = v.(string) + } + if v, ok := data["enable"]; ok { + role.Enable = v.(bool) + } + if v, ok := data["context"]; ok { + bytes, _ := json.Marshal(v) + _ = json.Unmarshal(bytes, &role.Context) + } + + err = PutChatRole(*role) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config"}) return @@ -371,7 +366,9 @@ func (s *Server) SetChatRoleHandle(c *gin.Context) { // AddProxyHandle 添加一个代理 func (s *Server) AddProxyHandle(c *gin.Context) { - var data map[string]string + var data struct { + Proxy string `json:"proxy"` + } err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -379,9 +376,9 @@ func (s *Server) AddProxyHandle(c *gin.Context) { return } - if proxy, ok := data["proxy"]; ok { - if !utils.ContainsStr(s.Config.ProxyURL, proxy) { - s.Config.ProxyURL = append(s.Config.ProxyURL, proxy) + if data.Proxy != "" { + if !utils.ContainsStr(s.Config.ProxyURL, data.Proxy) { + s.Config.ProxyURL = append(s.Config.ProxyURL, data.Proxy) } } @@ -397,7 +394,9 @@ func (s *Server) AddProxyHandle(c *gin.Context) { // RemoveProxyHandle 删除一个代理 func (s *Server) RemoveProxyHandle(c *gin.Context) { - var data map[string]string + var data struct { + Proxy string `json:"proxy"` + } err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -405,12 +404,10 @@ func (s *Server) RemoveProxyHandle(c *gin.Context) { return } - if proxy, ok := data["proxy"]; ok { - for i, v := range s.Config.ProxyURL { - if v == proxy { - s.Config.ProxyURL = append(s.Config.ProxyURL[:i], s.Config.ProxyURL[i+1:]...) - break - } + for i, v := range s.Config.ProxyURL { + if v == data.Proxy { + s.Config.ProxyURL = append(s.Config.ProxyURL[:i], s.Config.ProxyURL[i+1:]...) + break } } diff --git a/server/server.go b/server/server.go index 98440126..d9f9f3ae 100644 --- a/server/server.go +++ b/server/server.go @@ -81,7 +81,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { engine.Use(AuthorizeMiddleware(s)) engine.Use(Recover) - engine.GET("/hello", Hello) + engine.POST("/test", s.TestHandle) engine.GET("/api/session/get", s.GetSessionHandle) engine.POST("/api/login", s.LoginHandle) engine.Any("/api/chat", s.ChatHandle) @@ -287,7 +287,3 @@ func (s *Server) LoginHandle(c *gin.Context) { s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Username: data.Token, SessionId: sessionId} c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId}) } - -func Hello(c *gin.Context) { - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: "HELLO, ChatGPT !!!"}) -} diff --git a/types/config.go b/types/config.go index 5e429538..1fddfb60 100644 --- a/types/config.go +++ b/types/config.go @@ -18,6 +18,8 @@ type User struct { MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制 RemainingCalls int `json:"remaining_calls"` // 剩余调用次数 EnableHistory bool `json:"enable_history"` // 是否启用聊天记录 + Status bool `json:"status"` // 当前状态 + ApiKey string `json:"api_key"` // OpenAI API KEY } // Chat configs struct