diff --git a/api/core/app_server.go b/api/core/app_server.go index 3af4445e..16c864c4 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -28,7 +28,7 @@ type AppServer struct { // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId - ChatClients *types.LMap[string, *types.WsClient] // Websocket 连接集合 + ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function } diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 4e10b5b7..60858b84 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -15,11 +15,6 @@ type Message struct { FunctionCall FunctionCall `json:"function_call"` } -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - type ApiResponse struct { Choices []ChoiceItem `json:"choices"` } diff --git a/api/core/types/function.go b/api/core/types/function.go new file mode 100644 index 00000000..18d64715 --- /dev/null +++ b/api/core/types/function.go @@ -0,0 +1,18 @@ +package types + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type Function struct { + Name string + Description string + Parameters []Parameter +} + +type Parameter struct { + Type string + Required []string + Properties map[string]interface{} +} diff --git a/api/go.mod b/api/go.mod index 8a341c75..05ecab65 100644 --- a/api/go.mod +++ b/api/go.mod @@ -33,8 +33,8 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect go.uber.org/dig v1.16.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/net v0.9.0 // indirect + golang.org/x/text v0.9.0 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/ini.v1 v1.66.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect @@ -59,6 +59,6 @@ require ( go.uber.org/fx v1.19.3 go.uber.org/multierr v1.6.0 // indirect golang.org/x/crypto v0.6.0 - golang.org/x/sys v0.5.0 // indirect + golang.org/x/sys v0.7.0 // indirect gorm.io/gorm v1.25.1 ) diff --git a/api/go.sum b/api/go.sum index 2f19b477..e6a87d6c 100644 --- a/api/go.sum +++ b/api/go.sum @@ -140,16 +140,16 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 7c64d3a7..2c7dec5d 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -5,6 +5,7 @@ import ( "bytes" "chatplus/core" "chatplus/core/types" + "chatplus/service/function" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -29,11 +30,12 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" type ChatHandler struct { BaseHandler - db *gorm.DB + db *gorm.DB + funcZaoBao *function.FuncZaoBao } -func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { - handler := ChatHandler{db: db} +func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler { + handler := ChatHandler{db: db, funcZaoBao: zaoBao} handler.App = app return &handler } @@ -192,7 +194,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession Content: prompt, }) var apiKey string - response, err := h.doRequest(ctx, userVo, &apiKey, req) + response, err := h.fakeRequest(ctx, userVo, &apiKey, req) if err != nil { if strings.Contains(err.Error(), "context canceled") { logger.Info("用户取消了请求:", prompt) @@ -211,13 +213,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession defer response.Body.Close() } - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { + //contentType := response.Header.Get("Content-Type") + //if strings.Contains(contentType, "text/event-stream") || true { + if true { replyCreatedAt := time.Now() // 循环读取 Chunk 消息 var message = types.Message{} var contents = make([]string, 0) - var responseBody = types.ApiResponse{} + var functionCall = false + var functionName string + var arguments = make([]string, 0) reader := bufio.NewReader(response.Body) for { line, err := reader.ReadString('\n') @@ -229,10 +234,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession } break } - if !strings.Contains(line, "data:") { + if !strings.Contains(line, "data:") || len(line) < 30 { continue } + var responseBody = types.ApiResponse{} err = json.Unmarshal([]byte(line[6:]), &responseBody) if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 logger.Error(err, line) @@ -241,6 +247,24 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession break } + fun := responseBody.Choices[0].Delta.FunctionCall + if functionCall && fun.Name == "" { + arguments = append(arguments, fun.Arguments) + continue + } + + if !utils.IsEmptyValue(fun) { + functionCall = true + functionName = fun.Name + replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)}) + continue + } + + if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 + break + } + // 初始化 role if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { message.Role = responseBody.Choices[0].Delta.Role @@ -258,6 +282,23 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession } } // end for + if functionCall { // 调用函数完成任务 + // TODO 调用函数完成任务 + data, err := h.funcZaoBao.Fetch() + if err != nil { + replyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: "调用函数出错", + }) + } else { + replyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: data, + }) + } + contents = append(contents, data) + } + // 消息发送成功 if len(contents) > 0 { // 更新用户的对话次数 @@ -272,8 +313,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession message.Content = strings.Join(contents, "") useMsg := types.Message{Role: "user", Content: prompt} - // 更新上下文消息 - if userVo.ChatConfig.EnableContext { + // 更新上下文消息,如果是调用函数则不需要更新上下文 + if userVo.ChatConfig.EnableContext && functionCall == false { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) @@ -401,8 +442,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin if proxyURL == "" { client = &http.Client{} } else { // 使用代理 - uri := url.URL{} - proxy, _ := uri.Parse(proxyURL) + proxy, _ := url.Parse(proxyURL) client = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxy), @@ -429,6 +469,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin return client.Do(request) } +func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) { + link := "https://img.r9it.com/chatgpt/response" + client := &http.Client{} + request, _ := http.NewRequest(http.MethodGet, link, nil) + return client.Do(request) +} + // 回复客户片段端消息 func replyChunkMessage(client types.Client, message types.WsMessage) { msg, err := json.Marshal(message) diff --git a/api/main.go b/api/main.go index 25dca9e4..4eedaab9 100644 --- a/api/main.go +++ b/api/main.go @@ -7,6 +7,7 @@ import ( "chatplus/handler/admin" logger2 "chatplus/logger" "chatplus/service" + "chatplus/service/function" "chatplus/store" "context" "embed" @@ -99,6 +100,12 @@ func main() { return xdb.NewWithBuffer(cBuff) }), + // 创建函数 + fx.Provide(func() *function.FuncZaoBao { + token := os.Getenv("AL_API_TOKEN") + return function.NewZaoBao(token) + }), + // 创建控制器 fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewUserHandler), diff --git a/api/service/function/zao_bao.go b/api/service/function/zao_bao.go new file mode 100644 index 00000000..851425d0 --- /dev/null +++ b/api/service/function/zao_bao.go @@ -0,0 +1,51 @@ +package function + +import ( + "chatplus/utils" + "fmt" + "strings" +) + +// 每日早报函数实现 + +type FuncZaoBao struct { + apiURL string + token string +} + +func NewZaoBao(token string) *FuncZaoBao { + return &FuncZaoBao{apiURL: "https://v2.alapi.cn/api/zaobao", token: token} +} + +type resVo struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + Date string `json:"date"` + News []string `json:"news"` + WeiYu string `json:"weiyu"` + } `json:"data"` +} + +func (f *FuncZaoBao) Fetch() (string, error) { + + url := fmt.Sprintf("%s?format=json&token=%s", f.apiURL, f.token) + bytes, err := utils.HttpGet(url, "") + if err != nil { + return "", err + } + var res resVo + err = utils.JsonDecode(string(bytes), &res) + if err != nil { + return "", err + } + + if res.Code != 200 { + return "", fmt.Errorf("call api fail: %s", res.Msg) + } + builder := make([]string, 0) + builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.Date)) + builder = append(builder, res.Data.News...) + builder = append(builder, fmt.Sprintf("%s", res.Data.WeiYu)) + return strings.Join(builder, "\n\n"), nil +} diff --git a/api/test/data.txt b/api/test/data.txt deleted file mode 100644 index abfa444c..00000000 --- a/api/test/data.txt +++ /dev/null @@ -1,32 +0,0 @@ -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"role":"assistant","content":null,"function_call":{"name":"browser","arguments":""}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"{\n"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":" "}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":" \""}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"keyword"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"\":"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":" \""}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"特"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"斯"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"拉"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"股"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"价"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"\"\n"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"function_call":{"arguments":"}"}},"finish_reason":null}]} - -data: {"id":"chatcmpl-7aKdKN9CaL7g0CI1wJv6PHxqNKZc8","object":"chat.completion.chunk","created":1688893478,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{},"finish_reason":"function_call"}]} - -data: [DONE] - diff --git a/api/test/test.go b/api/test/test.go index 56c6fd3a..c6ec4b0a 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -9,20 +9,17 @@ import ( "context" "encoding/json" "fmt" + "github.com/lionsoul2014/ip2region/binding/golang/xdb" + "github.com/pkoukk/tiktoken-go" "io" "log" "net/http" "os" "strings" "time" - - "github.com/lionsoul2014/ip2region/binding/golang/xdb" - "github.com/pkoukk/tiktoken-go" ) func main() { - err := extractFunction() - fmt.Println(err) } // Http client 取消操作 @@ -163,32 +160,43 @@ func testAesEncrypt() { } func extractFunction() error { - open, err := os.Open("data.txt") + open, err := os.Open("res/data.txt") if err != nil { return err } reader := bufio.NewReader(open) - //var contents = make([]string, 0) - var responseBody = types.ApiResponse{} - //var functionCall = false + var contents = make([]string, 0) + var functionCall = false + var functionName string for { line, err := reader.ReadString('\n') if err != nil { - return err + break } if !strings.Contains(line, "data:") { continue } + var responseBody = types.ApiResponse{} err = json.Unmarshal([]byte(line[6:]), &responseBody) if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 - return err + break } - if !utils.IsEmptyValue(responseBody.Choices[0].Delta.FunctionCall) { - //functionCall = true - fmt.Println("函数调用") + function := responseBody.Choices[0].Delta.FunctionCall + if functionCall && function.Name == "" { + contents = append(contents, function.Arguments) + continue + } + + if !utils.IsEmptyValue(function) { + functionCall = true + functionName = function.Name continue } } + + fmt.Println("函数名称: ", functionName) + fmt.Println(strings.Join(contents, "")) + return err } diff --git a/api/utils/http.go b/api/utils/http.go new file mode 100644 index 00000000..80dfc7d1 --- /dev/null +++ b/api/utils/http.go @@ -0,0 +1,68 @@ +package utils + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/url" +) + +func HttpGet(uri string, proxy string) ([]byte, error) { + var client *http.Client + if proxy == "" { + client = &http.Client{} + } else { + proxy, _ := url.Parse(proxy) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } + + req, err := http.NewRequest("GET", uri, nil) + if err != nil { + return nil, err + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(resp.Body) +} + +func HttpPost(uri string, params map[string]interface{}, proxy string) ([]byte, error) { + data, err := json.Marshal(params) + if err != nil { + return nil, err + } + + var client *http.Client + if proxy == "" { + client = &http.Client{} + } else { + proxy, _ := url.Parse(proxy) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } + + req, err := http.NewRequest("POST", uri, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(resp.Body) +}