mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
add config update API
This commit is contained in:
parent
396b7440fa
commit
59782e9e57
@ -1,6 +1,8 @@
|
|||||||
package logger
|
package logger
|
||||||
|
|
||||||
import "go.uber.org/zap"
|
import (
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
var logger *zap.SugaredLogger
|
var logger *zap.SugaredLogger
|
||||||
|
|
||||||
|
8
main.go
8
main.go
@ -5,7 +5,6 @@ import (
|
|||||||
"github.com/mitchellh/go-homedir"
|
"github.com/mitchellh/go-homedir"
|
||||||
logger2 "openai/logger"
|
logger2 "openai/logger"
|
||||||
"openai/server"
|
"openai/server"
|
||||||
config2 "openai/types"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
)
|
)
|
||||||
@ -33,14 +32,15 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// load service configs
|
|
||||||
config, err := config2.LoadConfig(filepath.Join(configDir, "/config.toml"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("failed to load web types: %v", err)
|
logger.Errorf("failed to load web types: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// start server
|
// start server
|
||||||
s := server.NewServer(config)
|
s, err := server.NewServer(filepath.Join(configDir, "/config.toml"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
s.Run(webRoot, "web")
|
s.Run(webRoot, "web")
|
||||||
}
|
}
|
||||||
|
@ -8,13 +8,15 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"io"
|
"io"
|
||||||
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"openai/types"
|
"openai/types"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) Chat(c *gin.Context) {
|
// ChatHandle 处理聊天 WebSocket 请求
|
||||||
|
func (s *Server) ChatHandle(c *gin.Context) {
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
@ -32,19 +34,16 @@ func (s *Server) Chat(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info(string(message))
|
logger.Info(string(message))
|
||||||
for {
|
|
||||||
err = client.Send([]byte("H"))
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
|
||||||
// TODO: 根据会话请求,传入不同的用户 ID
|
// TODO: 根据会话请求,传入不同的用户 ID
|
||||||
//err = s.sendMessage("test", string(message), client)
|
err = s.sendMessage("test", string(message), client)
|
||||||
//if err != nil {
|
if err != nil {
|
||||||
// logger.Error(err)
|
logger.Error(err)
|
||||||
//}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||||
func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||||
var r = types.ApiRequest{
|
var r = types.ApiRequest{
|
||||||
Model: "gpt-3.5-turbo",
|
Model: "gpt-3.5-turbo",
|
||||||
@ -75,8 +74,10 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: API KEY 负载均衡
|
// TODO: API KEY 负载均衡
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
index := rand.Intn(len(s.Config.OpenAi.ApiKeys))
|
||||||
request.Header.Add("Content-Type", "application/json")
|
request.Header.Add("Content-Type", "application/json")
|
||||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKey[0]))
|
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKeys[index]))
|
||||||
|
|
||||||
uri := url.URL{}
|
uri := url.URL{}
|
||||||
proxy, _ := uri.Parse(s.Config.ProxyURL)
|
proxy, _ := uri.Parse(s.Config.ProxyURL)
|
||||||
@ -125,19 +126,19 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
|||||||
} else {
|
} else {
|
||||||
contents = append(contents, responseBody.Choices[0].Delta.Content)
|
contents = append(contents, responseBody.Choices[0].Delta.Content)
|
||||||
}
|
}
|
||||||
|
// 推送消息到客户端
|
||||||
err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
|
err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
fmt.Print(responseBody.Choices[0].Delta.Content)
|
fmt.Print(responseBody.Choices[0].Delta.Content)
|
||||||
if responseBody.Choices[0].FinishReason != "" {
|
if responseBody.Choices[0].FinishReason != "" {
|
||||||
fmt.Println()
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加历史消息
|
// 追加历史消息
|
||||||
history = append(history, message)
|
history = append(history, message)
|
||||||
|
s.History[userId] = history
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
66
server/config_handler.go
Normal file
66
server/config_handler.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"openai/types"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigSetHandle set configs
|
||||||
|
func (s *Server) ConfigSetHandle(c *gin.Context) {
|
||||||
|
var data map[string]string
|
||||||
|
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error decode json data: %s", err.Error())
|
||||||
|
c.JSON(http.StatusBadRequest, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// API key
|
||||||
|
if key, ok := data["api_key"]; ok && len(key) > 20 {
|
||||||
|
s.Config.OpenAi.ApiKeys = append(s.Config.OpenAi.ApiKeys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxy URL
|
||||||
|
if proxy, ok := data["proxy"]; ok {
|
||||||
|
s.Config.ProxyURL = proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model
|
||||||
|
if model, ok := data["model"]; ok {
|
||||||
|
s.Config.OpenAi.Model = model
|
||||||
|
}
|
||||||
|
|
||||||
|
// Temperature
|
||||||
|
if temperature, ok := data["temperature"]; ok {
|
||||||
|
v, err := strconv.ParseFloat(temperature, 32)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, types.BizVo{
|
||||||
|
Code: types.InvalidParams,
|
||||||
|
Message: "temperature must be a float parameter",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.Config.OpenAi.Temperature = float32(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// max_tokens
|
||||||
|
if maxTokens, ok := data["max_tokens"]; ok {
|
||||||
|
v, err := strconv.Atoi(maxTokens)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, types.BizVo{
|
||||||
|
Code: types.InvalidParams,
|
||||||
|
Message: "max_tokens must be a int parameter",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.Config.OpenAi.MaxTokens = v
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存配置文件
|
||||||
|
logger.Infof("Config: %+v", s.Config)
|
||||||
|
types.SaveConfig(s.Config, s.ConfigPath)
|
||||||
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||||
|
}
|
@ -30,12 +30,21 @@ func (s StaticFile) Open(name string) (fs.File, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Config *types.Config
|
Config *types.Config
|
||||||
History map[string][]types.Message
|
ConfigPath string
|
||||||
|
History map[string][]types.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(config *types.Config) *Server {
|
func NewServer(configPath string) (*Server, error) {
|
||||||
return &Server{Config: config, History: make(map[string][]types.Message, 16)}
|
// load service configs
|
||||||
|
config, err := types.LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Server{
|
||||||
|
Config: config,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
History: make(map[string][]types.Message, 16)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Run(webRoot embed.FS, path string) {
|
func (s *Server) Run(webRoot embed.FS, path string) {
|
||||||
@ -46,7 +55,8 @@ func (s *Server) Run(webRoot embed.FS, path string) {
|
|||||||
engine.Use(AuthorizeMiddleware())
|
engine.Use(AuthorizeMiddleware())
|
||||||
|
|
||||||
engine.GET("/hello", Hello)
|
engine.GET("/hello", Hello)
|
||||||
engine.Any("/api/chat", s.Chat)
|
engine.Any("/api/chat", s.ChatHandle)
|
||||||
|
engine.POST("/api/config/set", s.ConfigSetHandle)
|
||||||
|
|
||||||
// process front-end web static files
|
// process front-end web static files
|
||||||
engine.StaticFS("/chat", http.FS(StaticFile{
|
engine.StaticFS("/chat", http.FS(StaticFile{
|
||||||
|
@ -2,6 +2,7 @@ package types
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
"net/http"
|
"net/http"
|
||||||
"openai/utils"
|
"openai/utils"
|
||||||
@ -18,7 +19,7 @@ type Config struct {
|
|||||||
// OpenAi configs struct
|
// OpenAi configs struct
|
||||||
type OpenAi struct {
|
type OpenAi struct {
|
||||||
ApiURL string
|
ApiURL string
|
||||||
ApiKey []string
|
ApiKeys []string
|
||||||
Model string
|
Model string
|
||||||
Temperature float32
|
Temperature float32
|
||||||
MaxTokens int
|
MaxTokens int
|
||||||
@ -52,7 +53,7 @@ func NewDefaultConfig() *Config {
|
|||||||
},
|
},
|
||||||
OpenAi: OpenAi{
|
OpenAi: OpenAi{
|
||||||
ApiURL: "https://api.openai.com/v1/chat/completions",
|
ApiURL: "https://api.openai.com/v1/chat/completions",
|
||||||
ApiKey: []string{""},
|
ApiKeys: []string{""},
|
||||||
Model: "gpt-3.5-turbo",
|
Model: "gpt-3.5-turbo",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Temperature: 1.0,
|
Temperature: 1.0,
|
||||||
@ -64,16 +65,10 @@ func LoadConfig(configFile string) (*Config, error) {
|
|||||||
var config *Config
|
var config *Config
|
||||||
_, err := os.Stat(configFile)
|
_, err := os.Stat(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Errorf("Error: %s", err.Error())
|
||||||
config = NewDefaultConfig()
|
config = NewDefaultConfig()
|
||||||
// generate types file
|
// save config
|
||||||
buf := new(bytes.Buffer)
|
err := SaveConfig(config, configFile)
|
||||||
encoder := toml.NewEncoder(buf)
|
|
||||||
|
|
||||||
if err := encoder.Encode(&config); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err := os.WriteFile(configFile, buf.Bytes(), 0644)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -81,9 +76,20 @@ func LoadConfig(configFile string) (*Config, error) {
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
_, err = toml.DecodeFile(configFile, &config)
|
_, err = toml.DecodeFile(configFile, &config)
|
||||||
|
fmt.Println(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, err
|
return config, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveConfig(config *Config, configFile string) error {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
encoder := toml.NewEncoder(buf)
|
||||||
|
if err := encoder.Encode(&config); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(configFile, buf.Bytes(), 0644)
|
||||||
|
}
|
||||||
|
28
types/web.go
Normal file
28
types/web.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
// BizVo 业务返回 VO
|
||||||
|
type BizVo struct {
|
||||||
|
Code BizCode `json:"code"`
|
||||||
|
Page int `json:"page,omitempty"`
|
||||||
|
PageSize int `json:"page_size,omitempty"`
|
||||||
|
Total int `json:"total,omitempty"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WsVo Websocket 信息 VO
|
||||||
|
type WsVo struct {
|
||||||
|
Stop bool
|
||||||
|
Content string
|
||||||
|
}
|
||||||
|
|
||||||
|
type BizCode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Success = BizCode(0)
|
||||||
|
Failed = BizCode(1)
|
||||||
|
InvalidParams = BizCode(101) // 非法参数
|
||||||
|
NotAuthorized = BizCode(400) // 未授权
|
||||||
|
|
||||||
|
OkMsg = "Success"
|
||||||
|
)
|
@ -1 +1 @@
|
|||||||
VUE_APP_API_HOST=127.0.0.1:5678
|
VUE_APP_WS_HOST=ws://127.0.0.1:5678
|
@ -99,7 +99,7 @@ export default defineComponent({
|
|||||||
window.addEventListener('resize', this.windowResize);
|
window.addEventListener('resize', this.windowResize);
|
||||||
|
|
||||||
// 初始化 WebSocket 对象
|
// 初始化 WebSocket 对象
|
||||||
const socket = new WebSocket('ws://172.22.11.200:5678/api/chat');
|
const socket = new WebSocket(process.env.VUE_APP_WS_HOST+'/api/chat');
|
||||||
socket.addEventListener('open', () => {
|
socket.addEventListener('open', () => {
|
||||||
console.log('WebSocket 连接已打开');
|
console.log('WebSocket 连接已打开');
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user