mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-11 03:33:48 +08:00
接入 ChatGPT API
This commit is contained in:
@@ -1,9 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"openai/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) Chat(c *gin.Context) {
|
||||
@@ -23,12 +31,113 @@ func (s *Server) Chat(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 接受消息,调用 ChatGPT 返回消息
|
||||
logger.Info(string(message))
|
||||
err = client.Send(message)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
for {
|
||||
err = client.Send([]byte("H"))
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
// TODO: 根据会话请求,传入不同的用户 ID
|
||||
//err = s.sendMessage("test", string(message), client)
|
||||
//if err != nil {
|
||||
// logger.Error(err)
|
||||
//}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
var r = types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo",
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: true,
|
||||
}
|
||||
var history []types.Message
|
||||
if v, ok := s.History[userId]; ok {
|
||||
history = v
|
||||
} else {
|
||||
history = make([]types.Message, 0)
|
||||
}
|
||||
r.Messages = append(history, types.Message{
|
||||
Role: "user",
|
||||
Content: text,
|
||||
})
|
||||
|
||||
logger.Info("上下文历史消息:%+v", s.History[userId])
|
||||
requestBody, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: API KEY 负载均衡
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKey[0]))
|
||||
|
||||
uri := url.URL{}
|
||||
proxy, _ := uri.Parse(s.Config.ProxyURL)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
response, err := client.Do(request)
|
||||
var retryCount = 3
|
||||
for err != nil {
|
||||
if retryCount <= 0 {
|
||||
return err
|
||||
}
|
||||
response, err = client.Do(request)
|
||||
retryCount--
|
||||
}
|
||||
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var responseBody = types.ApiResponse{}
|
||||
|
||||
reader := bufio.NewReader(response.Body)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
break
|
||||
} else if len(line) < 20 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
continue
|
||||
}
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
continue
|
||||
} else {
|
||||
contents = append(contents, responseBody.Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
fmt.Print(responseBody.Choices[0].Delta.Content)
|
||||
if responseBody.Choices[0].FinishReason != "" {
|
||||
fmt.Println()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 追加历史消息
|
||||
history = append(history, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user