mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-11 11:43:43 +08:00
add config update API
This commit is contained in:
@@ -8,13 +8,15 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"openai/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) Chat(c *gin.Context) {
|
||||
// ChatHandle 处理聊天 WebSocket 请求
|
||||
func (s *Server) ChatHandle(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
@@ -32,19 +34,16 @@ func (s *Server) Chat(c *gin.Context) {
|
||||
}
|
||||
|
||||
logger.Info(string(message))
|
||||
for {
|
||||
err = client.Send([]byte("H"))
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
// TODO: 根据会话请求,传入不同的用户 ID
|
||||
//err = s.sendMessage("test", string(message), client)
|
||||
//if err != nil {
|
||||
// logger.Error(err)
|
||||
//}
|
||||
err = s.sendMessage("test", string(message), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
var r = types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo",
|
||||
@@ -75,8 +74,10 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
}
|
||||
|
||||
// TODO: API KEY 负载均衡
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
index := rand.Intn(len(s.Config.OpenAi.ApiKeys))
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKey[0]))
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKeys[index]))
|
||||
|
||||
uri := url.URL{}
|
||||
proxy, _ := uri.Parse(s.Config.ProxyURL)
|
||||
@@ -125,19 +126,19 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
} else {
|
||||
contents = append(contents, responseBody.Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
// 推送消息到客户端
|
||||
err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
fmt.Print(responseBody.Choices[0].Delta.Content)
|
||||
if responseBody.Choices[0].FinishReason != "" {
|
||||
fmt.Println()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 追加历史消息
|
||||
history = append(history, message)
|
||||
s.History[userId] = history
|
||||
return nil
|
||||
}
|
||||
|
||||
66
server/config_handler.go
Normal file
66
server/config_handler.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"openai/types"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ConfigSetHandle handles a config-update request: it decodes a flat JSON
// object of string key/value pairs from the request body, applies the
// recognized keys to the in-memory config, and persists the result to
// s.ConfigPath.
//
// Recognized keys: "api_key", "proxy", "model", "temperature", "max_tokens".
// Unrecognized keys are silently ignored.
func (s *Server) ConfigSetHandle(c *gin.Context) {
	var data map[string]string
	err := json.NewDecoder(c.Request.Body).Decode(&data)
	if err != nil {
		logger.Errorf("Error decode json data: %s", err.Error())
		c.JSON(http.StatusBadRequest, nil)
		return
	}
	// API key: appended to the existing pool, never replacing it. The length
	// guard filters obviously invalid keys.
	// NOTE(review): 20 looks like a heuristic minimum key length — confirm.
	if key, ok := data["api_key"]; ok && len(key) > 20 {
		s.Config.OpenAi.ApiKeys = append(s.Config.OpenAi.ApiKeys, key)
	}

	// proxy URL, stored verbatim (no validation here)
	if proxy, ok := data["proxy"]; ok {
		s.Config.ProxyURL = proxy
	}

	// Model name, stored verbatim
	if model, ok := data["model"]; ok {
		s.Config.OpenAi.Model = model
	}

	// Temperature: must parse as a float. Parse failures are reported with
	// HTTP 200 plus a business-level error code in the BizVo envelope
	// (the project's convention for parameter errors).
	if temperature, ok := data["temperature"]; ok {
		v, err := strconv.ParseFloat(temperature, 32)
		if err != nil {
			c.JSON(http.StatusOK, types.BizVo{
				Code:    types.InvalidParams,
				Message: "temperature must be a float parameter",
			})
			return
		}
		s.Config.OpenAi.Temperature = float32(v)
	}

	// max_tokens: must parse as an integer, same error convention as above.
	if maxTokens, ok := data["max_tokens"]; ok {
		v, err := strconv.Atoi(maxTokens)
		if err != nil {
			c.JSON(http.StatusOK, types.BizVo{
				Code:    types.InvalidParams,
				Message: "max_tokens must be a int parameter",
			})
			return
		}
		s.Config.OpenAi.MaxTokens = v

	}

	// Save the config file.
	// NOTE(review): any error from types.SaveConfig is ignored here — confirm
	// its signature and surface a save failure to the client if it can fail.
	logger.Infof("Config: %+v", s.Config)
	types.SaveConfig(s.Config, s.ConfigPath)
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
}
|
||||
@@ -30,12 +30,21 @@ func (s StaticFile) Open(name string) (fs.File, error) {
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
Config *types.Config
|
||||
History map[string][]types.Message
|
||||
Config *types.Config
|
||||
ConfigPath string
|
||||
History map[string][]types.Message
|
||||
}
|
||||
|
||||
func NewServer(config *types.Config) *Server {
|
||||
return &Server{Config: config, History: make(map[string][]types.Message, 16)}
|
||||
func NewServer(configPath string) (*Server, error) {
|
||||
// load service configs
|
||||
config, err := types.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{
|
||||
Config: config,
|
||||
ConfigPath: configPath,
|
||||
History: make(map[string][]types.Message, 16)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Run(webRoot embed.FS, path string) {
|
||||
@@ -46,7 +55,8 @@ func (s *Server) Run(webRoot embed.FS, path string) {
|
||||
engine.Use(AuthorizeMiddleware())
|
||||
|
||||
engine.GET("/hello", Hello)
|
||||
engine.Any("/api/chat", s.Chat)
|
||||
engine.Any("/api/chat", s.ChatHandle)
|
||||
engine.POST("/api/config/set", s.ConfigSetHandle)
|
||||
|
||||
// process front-end web static files
|
||||
engine.StaticFS("/chat", http.FS(StaticFile{
|
||||
|
||||
Reference in New Issue
Block a user