From 59782e9e57f142ddc66ca7d6e960591f767bd11d Mon Sep 17 00:00:00 2001 From: RockYang Date: Sat, 18 Mar 2023 20:20:00 +0800 Subject: [PATCH] add config update API --- logger/logger.go | 4 ++- main.go | 8 ++--- server/chat_handler.go | 25 +++++++-------- server/config_handler.go | 66 ++++++++++++++++++++++++++++++++++++++++ server/server.go | 20 +++++++++--- types/config.go | 28 ++++++++++------- types/web.go | 28 +++++++++++++++++ web/.env.development | 2 +- web/src/views/Chat.vue | 2 +- 9 files changed, 148 insertions(+), 35 deletions(-) create mode 100644 server/config_handler.go create mode 100644 types/web.go diff --git a/logger/logger.go b/logger/logger.go index b0bc7684..0dd81e9f 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,6 +1,8 @@ package logger -import "go.uber.org/zap" +import ( + "go.uber.org/zap" +) var logger *zap.SugaredLogger diff --git a/main.go b/main.go index 2111a88b..fc965081 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( "github.com/mitchellh/go-homedir" logger2 "openai/logger" "openai/server" - config2 "openai/types" "os" "path/filepath" ) @@ -33,14 +32,15 @@ func main() { } } - // load service configs - config, err := config2.LoadConfig(filepath.Join(configDir, "/config.toml")) if err != nil { logger.Errorf("failed to load web types: %v", err) return } // start server - s := server.NewServer(config) + s, err := server.NewServer(filepath.Join(configDir, "/config.toml")) + if err != nil { + panic(err) + } s.Run(webRoot, "web") } diff --git a/server/chat_handler.go b/server/chat_handler.go index ff21660d..25efaec2 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -8,13 +8,15 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" + "math/rand" "net/http" "net/url" "openai/types" "time" ) -func (s *Server) Chat(c *gin.Context) { +// ChatHandle 处理聊天 WebSocket 请求 +func (s *Server) ChatHandle(c *gin.Context) { ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Fatal(err) @@ -32,19 +34,16 @@ func (s *Server) Chat(c *gin.Context) { } logger.Info(string(message)) - for { - err = client.Send([]byte("H")) - time.Sleep(time.Second) - } // TODO: 根据会话请求,传入不同的用户 ID - //err = s.sendMessage("test", string(message), client) - //if err != nil { - // logger.Error(err) - //} + err = s.sendMessage("test", string(message), client) + if err != nil { + logger.Error(err) + } } }() } +// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 func (s *Server) sendMessage(userId string, text string, ws Client) error { var r = types.ApiRequest{ Model: "gpt-3.5-turbo", @@ -75,8 +74,10 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { } // TODO: API KEY 负载均衡 + rand.Seed(time.Now().UnixNano()) + index := rand.Intn(len(s.Config.OpenAi.ApiKeys)) request.Header.Add("Content-Type", "application/json") - request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKey[0])) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKeys[index])) uri := url.URL{} proxy, _ := uri.Parse(s.Config.ProxyURL) @@ -125,19 +126,19 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { } else { contents = append(contents, responseBody.Choices[0].Delta.Content) } - + // 推送消息到客户端 err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content)) if err != nil { logger.Error(err) } fmt.Print(responseBody.Choices[0].Delta.Content) if responseBody.Choices[0].FinishReason != "" { - fmt.Println() break } } // 追加历史消息 history = append(history, message) + s.History[userId] = history return nil } diff --git a/server/config_handler.go b/server/config_handler.go new file mode 100644 index 00000000..a33b6778 --- /dev/null +++ b/server/config_handler.go @@ -0,0 +1,66 @@ +package server + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "net/http" + "openai/types" + "strconv" +) + +// ConfigSetHandle set configs +func (s *Server) ConfigSetHandle(c *gin.Context) { + var data map[string]string + err := json.NewDecoder(c.Request.Body).Decode(&data) + if err != nil { + logger.Errorf("Error decode json data: %s", err.Error()) + c.JSON(http.StatusBadRequest, nil) + return + } + // API key + if key, ok := data["api_key"]; ok && len(key) > 20 { + s.Config.OpenAi.ApiKeys = append(s.Config.OpenAi.ApiKeys, key) + } + + // proxy URL + if proxy, ok := data["proxy"]; ok { + s.Config.ProxyURL = proxy + } + + // Model + if model, ok := data["model"]; ok { + s.Config.OpenAi.Model = model + } + + // Temperature + if temperature, ok := data["temperature"]; ok { + v, err := strconv.ParseFloat(temperature, 32) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{ + Code: types.InvalidParams, + Message: "temperature must be a float parameter", + }) + return + } + s.Config.OpenAi.Temperature = float32(v) + } + + // max_tokens + if maxTokens, ok := data["max_tokens"]; ok { + v, err := strconv.Atoi(maxTokens) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{ + Code: types.InvalidParams, + Message: "max_tokens must be a int parameter", + }) + return + } + s.Config.OpenAi.MaxTokens = v + + } + + // 保存配置文件 + logger.Infof("Config: %+v", s.Config) + types.SaveConfig(s.Config, s.ConfigPath) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) +} diff --git a/server/server.go b/server/server.go index 36e60342..ba1fe235 100644 --- a/server/server.go +++ b/server/server.go @@ -30,12 +30,21 @@ func (s StaticFile) Open(name string) (fs.File, error) { } type Server struct { - Config *types.Config - History map[string][]types.Message + Config *types.Config + ConfigPath string + History map[string][]types.Message } -func NewServer(config *types.Config) *Server { - return &Server{Config: config, History: make(map[string][]types.Message, 16)} +func NewServer(configPath string) (*Server, error) { + // load service configs + config, err := types.LoadConfig(configPath) + if err != nil { + return nil, err + } + return &Server{ + Config: config, + ConfigPath: configPath, + History: make(map[string][]types.Message, 16)}, nil } func (s *Server) Run(webRoot embed.FS, path string) { @@ -46,7 +55,8 @@ func (s *Server) Run(webRoot embed.FS, path string) { engine.Use(AuthorizeMiddleware()) engine.GET("/hello", Hello) - engine.Any("/api/chat", s.Chat) + engine.Any("/api/chat", s.ChatHandle) + engine.POST("/api/config/set", s.ConfigSetHandle) // process front-end web static files engine.StaticFS("/chat", http.FS(StaticFile{ diff --git a/types/config.go b/types/config.go index 69061daa..90ebc285 100644 --- a/types/config.go +++ b/types/config.go @@ -2,6 +2,7 @@ package types import ( "bytes" + "fmt" "github.com/BurntSushi/toml" "net/http" "openai/utils" @@ -18,7 +19,7 @@ type Config struct { // OpenAi configs struct type OpenAi struct { ApiURL string - ApiKey []string + ApiKeys []string Model string Temperature float32 MaxTokens int @@ -52,7 +53,7 @@ func NewDefaultConfig() *Config { }, OpenAi: OpenAi{ ApiURL: "https://api.openai.com/v1/chat/completions", - ApiKey: []string{""}, + ApiKeys: []string{""}, Model: "gpt-3.5-turbo", MaxTokens: 1024, Temperature: 1.0, @@ -64,16 +65,10 @@ func LoadConfig(configFile string) (*Config, error) { var config *Config _, err := os.Stat(configFile) if err != nil { + fmt.Errorf("Error: %s", err.Error()) config = NewDefaultConfig() - // generate types file - buf := new(bytes.Buffer) - encoder := toml.NewEncoder(buf) - - if err := encoder.Encode(&config); err != nil { - return nil, err - } - - err := os.WriteFile(configFile, buf.Bytes(), 0644) + // save config + err := SaveConfig(config, configFile) if err != nil { return nil, err } @@ -81,9 +76,20 @@ func LoadConfig(configFile string) (*Config, error) { return config, nil } _, err = toml.DecodeFile(configFile, &config) + fmt.Println(config) if err != nil { return nil, err } return config, err } + +func SaveConfig(config *Config, configFile string) error { + buf := new(bytes.Buffer) + encoder := toml.NewEncoder(buf) + if err := encoder.Encode(&config); err != nil { + return err + } + + return os.WriteFile(configFile, buf.Bytes(), 0644) +} diff --git a/types/web.go b/types/web.go new file mode 100644 index 00000000..858b51fe --- /dev/null +++ b/types/web.go @@ -0,0 +1,28 @@ +package types + +// BizVo 业务返回 VO +type BizVo struct { + Code BizCode `json:"code"` + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + Total int `json:"total,omitempty"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// WsVo Websocket 信息 VO +type WsVo struct { + Stop bool + Content string +} + +type BizCode int + +const ( + Success = BizCode(0) + Failed = BizCode(1) + InvalidParams = BizCode(101) // 非法参数 + NotAuthorized = BizCode(400) // 未授权 + + OkMsg = "Success" +) diff --git a/web/.env.development b/web/.env.development index 1de81d2c..c6d32f62 100644 --- a/web/.env.development +++ b/web/.env.development @@ -1 +1 @@ -VUE_APP_API_HOST=127.0.0.1:5678 \ No newline at end of file +VUE_APP_WS_HOST=ws://127.0.0.1:5678 \ No newline at end of file diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index 075cb28f..e4413ddf 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -99,7 +99,7 @@ export default defineComponent({ window.addEventListener('resize', this.windowResize); // 初始化 WebSocket 对象 - const socket = new WebSocket('ws://172.22.11.200:5678/api/chat'); + const socket = new WebSocket(process.env.VUE_APP_WS_HOST+'/api/chat'); socket.addEventListener('open', () => { console.log('WebSocket 连接已打开'); });