diff --git a/api/core/types/config.go b/api/core/types/config.go index 355b08e6..2d712f85 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -145,6 +145,7 @@ const ChatGLM = Platform("ChatGLM") const Baidu = Platform("Baidu") const XunFei = Platform("XunFei") const QWen = Platform("QWen") +const Ollama = Platform("Ollama") type SystemConfig struct { Title string `json:"title,omitempty"` diff --git a/api/go.mod b/api/go.mod index 095eb217..5549d546 100644 --- a/api/go.mod +++ b/api/go.mod @@ -37,6 +37,8 @@ require ( github.com/go-ole/go-ole v1.2.6 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/google/go-querystring v1.1.0 // indirect + github.com/levigross/grequests v0.0.0-20231203190023-9c307ef1f48d // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/numcpus v0.7.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect @@ -86,12 +88,12 @@ require ( go.uber.org/dig v1.16.1 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect - golang.org/x/mod v0.11.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sync v0.3.0 // indirect - golang.org/x/text v0.12.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.26.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.10.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect @@ -111,7 +113,7 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/fx v1.19.3 go.uber.org/multierr v1.6.0 // indirect - golang.org/x/crypto v0.12.0 - golang.org/x/sys v0.15.0 // indirect + golang.org/x/crypto v0.24.0 + 
golang.org/x/sys v0.21.0 // indirect gorm.io/gorm v1.25.1 ) diff --git a/api/go.sum b/api/go.sum index 64ea4f3a..13fb41bb 100644 --- a/api/go.sum +++ b/api/go.sum @@ -75,8 +75,11 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= @@ -119,6 +122,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/levigross/grequests v0.0.0-20231203190023-9c307ef1f48d h1:8fVmm2qScPn4JAF/YdTtqrPP3n58FgZ4GbKTNfaPuRs= +github.com/levigross/grequests v0.0.0-20231203190023-9c307ef1f48d/go.mod h1:dFu6nuJHC3u9kCDcyGrEL7LwhK2m6Mt+alyiiIjDrRY= github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 
h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs= github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= @@ -247,6 +252,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= @@ -256,6 +263,8 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -265,12 +274,16 @@ golang.org/x/net 
v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -288,6 +301,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -299,6 +314,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -307,6 +324,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/api/handler/chatimpl/chat_handler.go 
b/api/handler/chatimpl/chat_handler.go
index 8244bde6..198f26d3 100644
--- a/api/handler/chatimpl/chat_handler.go
+++ b/api/handler/chatimpl/chat_handler.go
@@ -235,7 +235,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 		Stream: true,
 	}
 	switch session.Model.Platform {
-	case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
+	case types.Azure, types.ChatGLM, types.Baidu, types.XunFei, types.Ollama:
 		req.Temperature = session.Model.Temperature
 		req.MaxTokens = session.Model.MaxTokens
 		break
@@ -401,6 +401,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 	case types.QWen:
 		return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
+	case types.Ollama:
+		return h.sendOllamaMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 	}
 	utils.ReplyChunkMessage(ws, types.WsMessage{
 		Type: types.WsMiddle,
diff --git a/api/handler/chatimpl/ollama_handler.go b/api/handler/chatimpl/ollama_handler.go
new file mode 100644
index 00000000..4df36401
--- /dev/null
+++ b/api/handler/chatimpl/ollama_handler.go
@@ -0,0 +1,351 @@
+package chatimpl
+
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+// * Copyright 2023 The Geek-AI Authors. All rights reserved.
+// * Use of this source code is governed by a Apache-2.0 license
+// * that can be found in the LICENSE file.
+// * @Author yangjian102621@163.com
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import (
+	"bufio"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"html/template"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+	"unicode/utf8"
+
+	"github.com/levigross/grequests"
+
+	"geekai/core/types"
+	"geekai/store/model"
+	"geekai/store/vo"
+	"geekai/utils"
+)
+
+// ollamaResp is one JSON object of the newline-delimited stream returned by
+// the Ollama generate API.
+type ollamaResp struct {
+	Id    string `json:"id"`
+	Model string `json:"model"`
+
+	CreatedAt  string `json:"created_at"`
+	Response   string `json:"response"`
+	Done       bool   `json:"done"`
+	DoneReason string `json:"done_reason"`
+	Context    []int  `json:"context"`
+
+	TotalDuration      int64 `json:"total_duration"`       // total time spent generating the response
+	LoadDuration       int64 `json:"load_duration"`        // time spent loading the model, in nanoseconds
+	PromptEvalCount    int   `json:"prompt_eval_count"`    // number of tokens in the prompt
+	PromptEvalDuration int64 `json:"prompt_eval_duration"` // time spent evaluating the prompt, in nanoseconds
+	EvalCount          int64 `json:"eval_count"`           // number of tokens in the generated response
+	EvalDuration       int64 `json:"eval_duration"`        // time spent generating the response, in nanoseconds
+}
+
+// sendOllamaMessage sends prompt to the Ollama API and relays the answer to
+// the websocket client. An NDJSON answer is streamed chunk by chunk; any other
+// content type is handed to processOllamaJsonResponse. A failed request is
+// reported through processError and its error is returned.
+func (h *ChatHandler) sendOllamaMessage(
+	chatCtx []types.Message,
+	req types.ApiRequest,
+	userVo vo.User,
+	ctx context.Context,
+	session *types.ChatSession,
+	role model.ChatRole,
+	prompt string,
+	ws *types.WsClient) error {
+
+	promptCreatedAt := time.Now() // when the question was asked
+	start := time.Now()
+
+	response, err := h.sendOllamaRequest(session, prompt)
+	logger.Info("HTTP请求完成,耗时:", time.Since(start))
+	if err != nil {
+		// response is nil on failure: report and stop before touching it.
+		h.processError(err, prompt, ws)
+		return err
+	}
+	defer response.Body.Close()
+
+	contentType := response.Header.Get("Content-Type")
+	if strings.Contains(contentType, "application/x-ndjson") {
+		h.processOllamaStreamResponse(chatCtx, req, userVo, response, ws, prompt, session, role, promptCreatedAt)
+		return nil
+	}
+
+	return h.processOllamaJsonResponse(response, ws)
+}
+
+// sendOllamaRequest posts prompt to the API URL of the key selected for the
+// session and returns the raw HTTP response. On success the caller owns the
+// response body and must close it; on failure the response is nil.
+func (h *ChatHandler) sendOllamaRequest(session *types.ChatSession, prompt string) (*http.Response, error) {
+	apiKey, err := h.queryApiKey(session)
+	if err != nil {
+		return nil, err
+	}
+
+	// todo add context to request body
+	postData := map[string]interface{}{
+		"model":  session.Model.Value,
+		"stream": true,
+		"prompt": prompt,
+	}
+	headers := map[string]string{
+		"Content-Type":  "application/json",
+		"Authorization": "Bearer " + apiKey.Value,
+	}
+
+	ro := &grequests.RequestOptions{
+		JSON:    postData,
+		Headers: headers,
+	}
+	resp, err := grequests.Post(apiKey.ApiURL, ro)
+	if err != nil {
+		return nil, err
+	}
+
+	if !resp.Ok {
+		reqErr := resp.Error
+		if reqErr == nil {
+			// resp.Error may be nil when only the HTTP status is bad; returning
+			// it as-is would give the caller a nil response and a nil error.
+			reqErr = fmt.Errorf("ollama API request failed with status code %d", resp.StatusCode)
+		}
+		// The body is not handed to the caller on this path, so release it here.
+		if resp.RawResponse != nil {
+			resp.RawResponse.Body.Close()
+		}
+		return nil, reqErr
+	}
+
+	return resp.RawResponse, nil
+}
+
+// queryApiKey selects the API key for the session: the enabled key bound to
+// the chat model when there is one, otherwise the enabled chat key of the
+// model's platform with the smallest last_used_at. The selected key's
+// last_used_at column is then set to the current Unix time.
+func (h *ChatHandler) queryApiKey(session *types.ChatSession) (*model.ApiKey, error) {
+	apiKey := &model.ApiKey{}
+
+	// if the chat model is bound to a key, use it directly
+	if session.Model.KeyId > 0 {
+		h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
+	}
+	// otherwise fall back to the least recently used key
+	if apiKey.Id == 0 {
+		h.DB.Debug().Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
+	}
+	if apiKey.Id == 0 {
+		return nil, errors.New("no available key, please import key")
+	}
+
+	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
+
+	return apiKey, nil
+}
+
+// processOllamaStreamResponse reads the NDJSON stream line by line, pushes
+// every text fragment to the websocket client and, once the API reports a
+// normal stop, records the complete exchange.
+func (h *ChatHandler) processOllamaStreamResponse(
+	chatCtx []types.Message, req types.ApiRequest, userVo vo.User,
+	response *http.Response, ws *types.WsClient, prompt string,
+	session *types.ChatSession, role model.ChatRole, promptCreatedAt time.Time) {
+
+	replyCreatedAt := time.Now() // when the reply started
+	var message = types.Message{}
+	var content strings.Builder // reply text accumulated across chunks
+
+	scanner := bufio.NewScanner(response.Body)
+	// The final line carries the whole context array and can outgrow the
+	// scanner's default 64 KiB token limit.
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+
+	// read the chunk messages one by one
+	for scanner.Scan() {
+		line := scanner.Text()
+		if strings.TrimSpace(line) == "" {
+			continue
+		}
+
+		var resp ollamaResp
+		if err := utils.JsonDecode(line, &resp); err != nil {
+			// log the offending line itself so the failure can be diagnosed
+			logger.Error("error with parse data line: ", line)
+			utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
+			break
+		}
+
+		if resp.Done {
+			if resp.DoneReason != "stop" {
+				utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.DoneReason))
+				break
+			}
+
+			utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
+			// store the generated text and its token count, not the context
+			// array or the prompt token count
+			message.Content = content.String()
+			replyTokens := int(resp.EvalCount)
+
+			// the message was delivered, now do the bookkeeping
+			h.recordInfoAfterSendMessage(chatCtx, req, userVo, message, prompt, session, role, promptCreatedAt, replyTokens, replyCreatedAt)
+			break
+		}
+
+		// NOTE(review): WsStart is only sent for chunks carrying an id — confirm
+		// the endpoint actually sets this field.
+		if len(resp.Id) > 0 {
+			utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+		}
+
+		if len(resp.Response) > 0 {
+			content.WriteString(resp.Response)
+			utils.ReplyChunkMessage(ws, types.WsMessage{
+				Type:    types.WsMiddle,
+				Content: utils.InterfaceToString(resp.Response),
+			})
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "context canceled") {
+			logger.Info("用户取消了请求:", prompt)
+		} else {
+			logger.Error("信息读取出错:", err)
+		}
+	}
+}
+
+// processOllamaJsonResponse handles a reply that is not an NDJSON stream: it
+// decodes the body as an error report and forwards the message to the
+// websocket client.
+// NOTE(review): the error_code/error_msg field names are assumed to match
+// what the endpoint returns — confirm against the Ollama API.
+func (h *ChatHandler) processOllamaJsonResponse(response *http.Response, ws *types.WsClient) error {
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		return fmt.Errorf("error with reading response: %v", err)
+	}
+
+	var res struct {
+		Code int    `json:"error_code"`
+		Msg  string `json:"error_msg"`
+	}
+	err = json.Unmarshal(body, &res)
+	if err != nil {
+		return fmt.Errorf("error with decode response: %v", err)
+	}
+	utils.ReplyMessage(ws, "请求Ollama大模型 API 失败:"+res.Msg)
+	return nil
+}
+
+// recordInfoAfterSendMessage persists a finished exchange: it updates the
+// cached chat context when that is enabled, saves the prompt and the reply as
+// history messages, calls subUserPower and creates the chat item when looking
+// it up by chat_id fails.
+func (h *ChatHandler) recordInfoAfterSendMessage(chatCtx []types.Message, req types.ApiRequest, userVo vo.User, message types.Message, prompt string, session *types.ChatSession, role model.ChatRole, promptCreatedAt time.Time, replyTokens int, replyCreatedAt time.Time) {
+	if message.Role == "" {
+		message.Role = "assistant"
+	}
+
+	useMsg := types.Message{Role: "user", Content: prompt}
+
+	// update the cached context messages
+	if h.App.SysConfig.EnableContext {
+		chatCtx = append(chatCtx, useMsg)  // the question
+		chatCtx = append(chatCtx, message) // the reply
+		h.App.ChatContexts.Put(session.ChatId, chatCtx)
+	}
+
+	// append the chat history
+	// for prompt
+	promptToken, err := utils.CalcTokens(prompt, req.Model)
+	if err != nil {
+		logger.Error(err)
+	}
+	historyUserMsg := model.ChatMessage{
+		UserId:     userVo.Id,
+		ChatId:     session.ChatId,
+		RoleId:     role.Id,
+		Type:       types.PromptMsg,
+		Icon:       userVo.Avatar,
+		Content:    template.HTMLEscapeString(prompt),
+		Tokens:     promptToken,
+		UseContext: true,
+		Model:      req.Model,
+	}
+	historyUserMsg.CreatedAt = promptCreatedAt
+	historyUserMsg.UpdatedAt = promptCreatedAt
+	res := h.DB.Save(&historyUserMsg)
+	if res.Error != nil {
+		logger.Error("failed to save prompt history message: ", res.Error)
+	}
+
+	// for reply
+	// todo rebuild the tokens
+	historyReplyMsg := model.ChatMessage{
+		UserId:     userVo.Id,
+		ChatId:     session.ChatId,
+		RoleId:     role.Id,
+		Type:       types.ReplyMsg,
+		Icon:       role.Icon,
+		Content:    message.Content,
+		Tokens:     replyTokens,
+		UseContext: true,
+		Model:      req.Model,
+	}
+	historyReplyMsg.CreatedAt = replyCreatedAt
+	historyReplyMsg.UpdatedAt = replyCreatedAt
+	res = h.DB.Create(&historyReplyMsg)
+	if res.Error != nil {
+		logger.Error("failed to save reply history message: ", res.Error)
+	}
+
+	// update the user's power
+	h.subUserPower(userVo, session, promptToken, replyTokens)
+
+	// save the current chat if it has no chat item yet
+	var chatItem model.ChatItem
+	res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
+	if res.Error != nil {
+		chatItem.ChatId = session.ChatId
+		chatItem.UserId = session.UserId
+		chatItem.RoleId = role.Id
+		chatItem.ModelId = session.Model.Id
+		if utf8.RuneCountInString(prompt) > 30 {
+			chatItem.Title = string([]rune(prompt)[:30]) + "..."
+		} else {
+			chatItem.Title = prompt
+		}
+		chatItem.Model = req.Model
+		h.DB.Create(&chatItem)
+	}
+}
+
+// processError reports a failed request. A cancelled request is only logged,
+// a missing API key gets a dedicated reply, and any other error is logged and
+// answered with the generic error message and image.
+func (h *ChatHandler) processError(err error, prompt string, ws *types.WsClient) {
+	switch {
+	case strings.Contains(err.Error(), "context canceled"):
+		logger.Info("用户取消了请求:", prompt)
+		return
+	case strings.Contains(err.Error(), "no available key"):
+		utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+		return
+	}
+
+	logger.Error(err)
+	utils.ReplyMessage(ws, ErrorMsg)
+	utils.ReplyMessage(ws, ErrImg)
+}
diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue
index fec5d14c..33b9e0db 100644
--- a/web/src/views/admin/ApiKey.vue
+++ b/web/src/views/admin/ApiKey.vue
@@ -174,6 +174,11 @@ const platforms = ref([
     value: "QWen",
     api_url: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
   },
+  {
+    name: "【Meta】Ollama",
+    value: "Ollama",
+    api_url: "http://localhost:8080/ollama/api/generate"
+  },
 ])
 const types = ref([
   {name: "聊天", value: "chat"},
diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue
index 7eea77bd..582c8cab 100644
--- a/web/src/views/admin/ChatModel.vue
+++ b/web/src/views/admin/ChatModel.vue
@@ -221,6 +221,7 @@ const platforms = ref([
   {name: "【百度】文心一言", value: "Baidu"},
   {name: "【微软】Azure", value: "Azure"},
   {name: "【阿里】通义千问", value: "QWen"},
+  {name: "【Meta】Ollama", value: "Ollama"},
 ])