diff --git a/server/chat_handler.go b/server/chat_handler.go index 7c1741c4..ff21660d 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -1,9 +1,17 @@ package server import ( + "bufio" + "bytes" + "encoding/json" + "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" + "io" "net/http" + "net/url" + "openai/types" + "time" ) func (s *Server) Chat(c *gin.Context) { @@ -23,12 +31,113 @@ func (s *Server) Chat(c *gin.Context) { return } - // TODO: 接受消息,调用 ChatGPT 返回消息 logger.Info(string(message)) - err = client.Send(message) - if err != nil { - logger.Error(err) + for { + err = client.Send([]byte("H")) + time.Sleep(time.Second) } + // TODO: 根据会话请求,传入不同的用户 ID + //err = s.sendMessage("test", string(message), client) + //if err != nil { + // logger.Error(err) + //} } }() } + +func (s *Server) sendMessage(userId string, text string, ws Client) error { + var r = types.ApiRequest{ + Model: "gpt-3.5-turbo", + Temperature: 0.9, + MaxTokens: 1024, + Stream: true, + } + var history []types.Message + if v, ok := s.History[userId]; ok { + history = v + } else { + history = make([]types.Message, 0) + } + r.Messages = append(history, types.Message{ + Role: "user", + Content: text, + }) + + logger.Info("上下文历史消息:%+v", s.History[userId]) + requestBody, err := json.Marshal(r) + if err != nil { + return err + } + + request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return err + } + + // TODO: API KEY 负载均衡 + request.Header.Add("Content-Type", "application/json") + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKey[0])) + + uri := url.URL{} + proxy, _ := uri.Parse(s.Config.ProxyURL) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + response, err := client.Do(request) + var retryCount = 3 + for err != nil { + if retryCount <= 0 { + return err + } + response, err = client.Do(request) + retryCount-- + } + + var message = types.Message{} + var contents = make([]string, 0) + var responseBody = types.ApiResponse{} + + reader := bufio.NewReader(response.Body) + for { + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + fmt.Println(err) + break + } + + if line == "" { + break + } else if len(line) < 20 { + continue + } + + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil { + fmt.Println(err) + continue + } + // 初始化 role + if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { + message.Role = responseBody.Choices[0].Delta.Role + continue + } else { + contents = append(contents, responseBody.Choices[0].Delta.Content) + } + + err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content)) + if err != nil { + logger.Error(err) + } + fmt.Print(responseBody.Choices[0].Delta.Content) + if responseBody.Choices[0].FinishReason != "" { + fmt.Println() + break + } + } + + // 追加历史消息 + history = append(history, message) + return nil +} diff --git a/server/server.go b/server/server.go index e71bf31b..36e60342 100644 --- a/server/server.go +++ b/server/server.go @@ -30,11 +30,12 @@ func (s StaticFile) Open(name string) (fs.File, error) { } type Server struct { - Config *types.Config + Config *types.Config + History map[string][]types.Message } func NewServer(config *types.Config) *Server { - return &Server{Config: config} + return &Server{Config: config, History: make(map[string][]types.Message, 16)} } func (s *Server) Run(webRoot embed.FS, path string) { diff --git a/types/config.go b/types/config.go index baaa9555..69061daa 100644 --- a/types/config.go +++ b/types/config.go @@ -9,14 +9,16 @@ import ( ) type Config struct { - Listen string - Session Session - OpenAi OpenAi + Listen string + Session Session + ProxyURL string + OpenAi OpenAi } // OpenAi configs struct type OpenAi struct { - ApiKey string + ApiURL string + ApiKey []string Model string Temperature float32 MaxTokens int @@ -49,6 +51,8 @@ func NewDefaultConfig() *Config { SameSite: http.SameSiteLaxMode, }, OpenAi: OpenAi{ + ApiURL: "https://api.openai.com/v1/chat/completions", + ApiKey: []string{""}, Model: "gpt-3.5-turbo", MaxTokens: 1024, Temperature: 1.0, diff --git a/types/gpt.go b/types/gpt.go new file mode 100644 index 00000000..fe6fe2ca --- /dev/null +++ b/types/gpt.go @@ -0,0 +1,25 @@ +package types + +// ApiRequest API 请求实体 +type ApiRequest struct { + Model string `json:"model"` + Temperature float32 `json:"temperature"` + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` + Messages []Message `json:"messages"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ApiResponse struct { + Choices []ChoiceItem `json:"choices"` +} + +// ChoiceItem API 响应实体 +type ChoiceItem struct { + Delta Message `json:"delta"` + FinishReason string `json:"finish_reason"` +} diff --git a/types/types.go b/types/types.go deleted file mode 100644 index e7a1823c..00000000 --- a/types/types.go +++ /dev/null @@ -1,49 +0,0 @@ -package types - -import ( - "sync" -) - -type LockedMap struct { - lock sync.RWMutex - data map[string]interface{} -} - -func NewLockedMap() *LockedMap { - return &LockedMap{ - lock: sync.RWMutex{}, - data: make(map[string]interface{}), - } -} - -func (m *LockedMap) Put(key string, value interface{}) { - m.lock.Lock() - defer m.lock.Unlock() - - m.data[key] = value -} - -func (m *LockedMap) Get(key string) interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - - return m.data[key] -} - -func (m *LockedMap) Delete(key string) { - m.lock.Lock() - defer m.lock.Unlock() - - delete(m.data, key) -} - -func (m *LockedMap) ToList() []interface{} { - m.lock.Lock() - defer m.lock.Unlock() - - var s = make([]interface{}, 0) - for _, v := range m.data { - s = append(s, v) - } - return s -} diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index 72fec88b..075cb28f 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -108,12 +108,13 @@ export default defineComponent({ const reader = new FileReader(); reader.readAsText(event.data, "UTF-8"); reader.onload = () => { - this.chatData.push({ - type: "reply", - id: randString(32), - icon: 'images/gpt-icon.png', - content: reader.result - }); + // this.chatData.push({ + // type: "reply", + // id: randString(32), + // icon: 'images/gpt-icon.png', + // content: reader.result + // }); + this.chatData[this.chatData.length - 1]["content"] += reader.result this.sending = false; // 将聊天框的滚动条滑动到最底部