mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 18:23:45 +08:00
接入 ChatGPT API
This commit is contained in:
@@ -1,9 +1,17 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"openai/types"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) Chat(c *gin.Context) {
|
func (s *Server) Chat(c *gin.Context) {
|
||||||
@@ -23,12 +31,113 @@ func (s *Server) Chat(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: 接受消息,调用 ChatGPT 返回消息
|
|
||||||
logger.Info(string(message))
|
logger.Info(string(message))
|
||||||
err = client.Send(message)
|
for {
|
||||||
if err != nil {
|
err = client.Send([]byte("H"))
|
||||||
logger.Error(err)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
// TODO: 根据会话请求,传入不同的用户 ID
|
||||||
|
//err = s.sendMessage("test", string(message), client)
|
||||||
|
//if err != nil {
|
||||||
|
// logger.Error(err)
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||||
|
var r = types.ApiRequest{
|
||||||
|
Model: "gpt-3.5-turbo",
|
||||||
|
Temperature: 0.9,
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
var history []types.Message
|
||||||
|
if v, ok := s.History[userId]; ok {
|
||||||
|
history = v
|
||||||
|
} else {
|
||||||
|
history = make([]types.Message, 0)
|
||||||
|
}
|
||||||
|
r.Messages = append(history, types.Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: text,
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("上下文历史消息:%+v", s.History[userId])
|
||||||
|
requestBody, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewBuffer(requestBody))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: API KEY 负载均衡
|
||||||
|
request.Header.Add("Content-Type", "application/json")
|
||||||
|
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKey[0]))
|
||||||
|
|
||||||
|
uri := url.URL{}
|
||||||
|
proxy, _ := uri.Parse(s.Config.ProxyURL)
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxy),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
response, err := client.Do(request)
|
||||||
|
var retryCount = 3
|
||||||
|
for err != nil {
|
||||||
|
if retryCount <= 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
response, err = client.Do(request)
|
||||||
|
retryCount--
|
||||||
|
}
|
||||||
|
|
||||||
|
var message = types.Message{}
|
||||||
|
var contents = make([]string, 0)
|
||||||
|
var responseBody = types.ApiResponse{}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(response.Body)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
fmt.Println(err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
break
|
||||||
|
} else if len(line) < 20 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 初始化 role
|
||||||
|
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||||
|
message.Role = responseBody.Choices[0].Delta.Role
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
contents = append(contents, responseBody.Choices[0].Delta.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
fmt.Print(responseBody.Choices[0].Delta.Content)
|
||||||
|
if responseBody.Choices[0].FinishReason != "" {
|
||||||
|
fmt.Println()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 追加历史消息
|
||||||
|
history = append(history, message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,11 +30,12 @@ func (s StaticFile) Open(name string) (fs.File, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Config *types.Config
|
Config *types.Config
|
||||||
|
History map[string][]types.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(config *types.Config) *Server {
|
func NewServer(config *types.Config) *Server {
|
||||||
return &Server{Config: config}
|
return &Server{Config: config, History: make(map[string][]types.Message, 16)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Run(webRoot embed.FS, path string) {
|
func (s *Server) Run(webRoot embed.FS, path string) {
|
||||||
|
|||||||
@@ -9,14 +9,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
OpenAi OpenAi
|
ProxyURL string
|
||||||
|
OpenAi OpenAi
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAi configs struct
|
// OpenAi configs struct
|
||||||
type OpenAi struct {
|
type OpenAi struct {
|
||||||
ApiKey string
|
ApiURL string
|
||||||
|
ApiKey []string
|
||||||
Model string
|
Model string
|
||||||
Temperature float32
|
Temperature float32
|
||||||
MaxTokens int
|
MaxTokens int
|
||||||
@@ -49,6 +51,8 @@ func NewDefaultConfig() *Config {
|
|||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
},
|
},
|
||||||
OpenAi: OpenAi{
|
OpenAi: OpenAi{
|
||||||
|
ApiURL: "https://api.openai.com/v1/chat/completions",
|
||||||
|
ApiKey: []string{""},
|
||||||
Model: "gpt-3.5-turbo",
|
Model: "gpt-3.5-turbo",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Temperature: 1.0,
|
Temperature: 1.0,
|
||||||
|
|||||||
25
types/gpt.go
Normal file
25
types/gpt.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
// ApiRequest API 请求实体
|
||||||
|
type ApiRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ApiResponse struct {
|
||||||
|
Choices []ChoiceItem `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChoiceItem API 响应实体
|
||||||
|
type ChoiceItem struct {
|
||||||
|
Delta Message `json:"delta"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type LockedMap struct {
|
|
||||||
lock sync.RWMutex
|
|
||||||
data map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLockedMap() *LockedMap {
|
|
||||||
return &LockedMap{
|
|
||||||
lock: sync.RWMutex{},
|
|
||||||
data: make(map[string]interface{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *LockedMap) Put(key string, value interface{}) {
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
|
|
||||||
m.data[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *LockedMap) Get(key string) interface{} {
|
|
||||||
m.lock.RLock()
|
|
||||||
defer m.lock.RUnlock()
|
|
||||||
|
|
||||||
return m.data[key]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *LockedMap) Delete(key string) {
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
|
|
||||||
delete(m.data, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *LockedMap) ToList() []interface{} {
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
|
|
||||||
var s = make([]interface{}, 0)
|
|
||||||
for _, v := range m.data {
|
|
||||||
s = append(s, v)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
@@ -108,12 +108,13 @@ export default defineComponent({
|
|||||||
const reader = new FileReader();
|
const reader = new FileReader();
|
||||||
reader.readAsText(event.data, "UTF-8");
|
reader.readAsText(event.data, "UTF-8");
|
||||||
reader.onload = () => {
|
reader.onload = () => {
|
||||||
this.chatData.push({
|
// this.chatData.push({
|
||||||
type: "reply",
|
// type: "reply",
|
||||||
id: randString(32),
|
// id: randString(32),
|
||||||
icon: 'images/gpt-icon.png',
|
// icon: 'images/gpt-icon.png',
|
||||||
content: reader.result
|
// content: reader.result
|
||||||
});
|
// });
|
||||||
|
this.chatData[this.chatData.length - 1]["content"] += reader.result
|
||||||
this.sending = false;
|
this.sending = false;
|
||||||
|
|
||||||
// 将聊天框的滚动条滑动到最底部
|
// 将聊天框的滚动条滑动到最底部
|
||||||
|
|||||||
Reference in New Issue
Block a user