diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index efcd344e..07e2e13f 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -6,8 +6,9 @@ type ApiRequest struct {
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
- Messages []interface{} `json:"messages"`
- Functions []Function `json:"functions"`
+ Messages []interface{} `json:"messages,omitempty"`
+ Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
+ Functions []Function `json:"functions,omitempty"`
}
type Message struct {
@@ -34,12 +35,18 @@ type Delta struct {
// ChatSession 聊天会话对象
type ChatSession struct {
- SessionId string `json:"session_id"`
- ClientIP string `json:"client_ip"` // 客户端 IP
- Username string `json:"username"` // 当前登录的 username
- UserId uint `json:"user_id"` // 当前登录的 user ID
- ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
- Model string `json:"model"` // GPT 模型
+ SessionId string `json:"session_id"`
+ ClientIP string `json:"client_ip"` // 客户端 IP
+ Username string `json:"username"` // 当前登录的 username
+ UserId uint `json:"user_id"` // 当前登录的 user ID
+ ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
+ Model ChatModel `json:"model"` // GPT 模型
+}
+
+type ChatModel struct {
+ Id uint `json:"id"`
+ Platform Platform `json:"platform"`
+ Value string `json:"value"`
}
type MjTask struct {
diff --git a/api/core/types/config.go b/api/core/types/config.go
index 5589e243..3a1755f7 100644
--- a/api/core/types/config.go
+++ b/api/core/types/config.go
@@ -64,6 +64,7 @@ type RedisConfig struct {
Host string
Port int
Password string
+ DB int
}
func (c RedisConfig) Url() string {
@@ -99,14 +100,31 @@ type Session struct {
// ChatConfig 系统默认的聊天配置
type ChatConfig struct {
- ApiURL string `json:"api_url,omitempty"`
- Model string `json:"model"` // 默认模型
- Temperature float32 `json:"temperature"`
- MaxTokens int `json:"max_tokens"`
- EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
- EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
- ApiKey string `json:"api_key"`
- ContextDeep int `json:"context_deep"` // 上下文深度
+ OpenAI ModelAPIConfig `json:"open_ai"`
+ Azure ModelAPIConfig `json:"azure"`
+ ChatGML ModelAPIConfig `json:"chat_gml"`
+
+ EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
+ EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
+ ContextDeep int `json:"context_deep"` // 上下文深度
+}
+
+type Platform string
+
+const OpenAI = Platform("OpenAI")
+const Azure = Platform("Azure")
+const ChatGML = Platform("ChatGML")
+
+// UserChatConfig 用户的聊天配置
+type UserChatConfig struct {
+ ApiKeys map[Platform]string
+}
+
+type ModelAPIConfig struct {
+ ApiURL string `json:"api_url,omitempty"`
+ Temperature float32 `json:"temperature"`
+ MaxTokens int `json:"max_tokens"`
+ ApiKey string `json:"api_key"`
}
type SystemConfig struct {
@@ -115,6 +133,8 @@ type SystemConfig struct {
Models []string `json:"models"`
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
InitImgCalls int `json:"init_img_calls"`
+ VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数
EnabledRegister bool `json:"enabled_register"`
EnabledMsgService bool `json:"enabled_msg_service"`
+ EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能
}
diff --git a/api/core/types/function.go b/api/core/types/function.go
index 020ab4bb..c422118d 100644
--- a/api/core/types/function.go
+++ b/api/core/types/function.go
@@ -87,11 +87,15 @@ var InnerFunctions = []Function{
},
"ar": {
Type: "string",
- Description: "图片长宽比,如 --ar 4:3",
+ Description: "图片长宽比,默认值 16:9",
},
"niji": {
Type: "string",
- Description: "动漫模型版本,例如 --niji 5",
+ Description: "动漫模型版本,默认值空",
+ },
+ "v": {
+ Type: "string",
+ Description: "模型版本,默认值: 5.2",
},
},
Required: []string{},
diff --git a/api/go.mod b/api/go.mod
index 507a8a3b..a6ac94ed 100644
--- a/api/go.mod
+++ b/api/go.mod
@@ -7,9 +7,12 @@ require (
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/gin-contrib/sessions v0.0.5
github.com/gin-gonic/gin v1.9.1
+ github.com/go-redis/redis/v8 v8.11.5
+ github.com/golang-jwt/jwt/v5 v5.0.0
github.com/gorilla/websocket v1.5.0
github.com/imroc/req/v3 v3.37.2
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0
+ github.com/minio/minio-go/v7 v7.0.62
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
github.com/syndtr/goleveldb v1.0.0
go.uber.org/zap v1.23.0
@@ -21,7 +24,9 @@ require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
github.com/bytedance/sonic v1.9.1 // indirect
+ github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
+ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
@@ -41,7 +46,6 @@ require (
github.com/klauspost/compress v1.16.7 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
- github.com/minio/minio-go/v7 v7.0.62 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/onsi/ginkgo/v2 v2.10.0 // indirect
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
diff --git a/api/go.sum b/api/go.sum
index 45b58b34..8c496823 100644
--- a/api/go.sum
+++ b/api/go.sum
@@ -10,17 +10,22 @@ github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
+github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
+github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
+github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
@@ -39,6 +44,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
+github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
+github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
@@ -46,6 +53,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
+github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
+github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -77,7 +86,6 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
-github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
@@ -91,14 +99,10 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
-github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw=
-github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4=
github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I=
github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
-github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
-github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
@@ -124,9 +128,10 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
@@ -197,8 +202,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
-golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
@@ -210,8 +213,6 @@ golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
-golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
-golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -223,19 +224,14 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
-golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
-golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -253,9 +249,7 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
-gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go
index 06dd4db9..c6d977ef 100644
--- a/api/handler/admin/api_key_handler.go
+++ b/api/handler/admin/api_key_handler.go
@@ -28,7 +28,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
- UserId uint `json:"user_id"`
Value string `json:"value"`
LastUsedAt string `json:"last_used_at"`
CreatedAt int64 `json:"created_at"`
@@ -38,7 +37,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
return
}
- apiKey := model.ApiKey{Value: data.Value, UserId: data.UserId, LastUsedAt: utils.Str2stamp(data.LastUsedAt)}
+ apiKey := model.ApiKey{Value: data.Value, LastUsedAt: utils.Str2stamp(data.LastUsedAt)}
apiKey.Id = data.Id
if apiKey.Id > 0 {
apiKey.CreatedAt = time.Unix(data.CreatedAt, 0)
diff --git a/api/handler/admin/reward_handler.go b/api/handler/admin/reward_handler.go
index 5bfc42cd..b0470c37 100644
--- a/api/handler/admin/reward_handler.go
+++ b/api/handler/admin/reward_handler.go
@@ -46,7 +46,7 @@ func (h *RewardHandler) List(c *gin.Context) {
}
r.Id = v.Id
- r.Username = userMap[v.UserId].Username
+ r.Username = userMap[v.UserId].Mobile
r.CreatedAt = v.CreatedAt.Unix()
r.UpdatedAt = v.UpdatedAt.Unix()
rewards = append(rewards, r)
diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go
index fd254fe8..6fb5d4b1 100644
--- a/api/handler/admin/user_handler.go
+++ b/api/handler/admin/user_handler.go
@@ -8,8 +8,6 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
- "fmt"
-
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -100,22 +98,19 @@ func (h *UserHandler) Save(c *gin.Context) {
} else {
salt := utils.RandString(8)
u := model.User{
- Username: data.Username,
+ Mobile: data.Mobile,
Password: utils.GenPassword(data.Password, salt),
- Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(5)),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
- Mobile: data.Mobile,
ChatRoles: utils.JsonEncode(data.ChatRoles),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
- ChatConfig: utils.JsonEncode(types.ChatConfig{
- Temperature: h.App.ChatConfig.Temperature,
- MaxTokens: h.App.ChatConfig.MaxTokens,
- EnableContext: h.App.ChatConfig.EnableContext,
- EnableHistory: true,
- Model: h.App.ChatConfig.Model,
- ApiKey: "",
+ ChatConfig: utils.JsonEncode(types.UserChatConfig{
+ ApiKeys: map[types.Platform]string{
+ types.OpenAI: "",
+ types.Azure: "",
+ types.ChatGML: "",
+ },
}),
Calls: h.App.SysConfig.UserInitCalls,
}
diff --git a/api/handler/azure_handler.go b/api/handler/azure_handler.go
new file mode 100644
index 00000000..1776d6a4
--- /dev/null
+++ b/api/handler/azure_handler.go
@@ -0,0 +1,301 @@
+package handler
+
+import (
+ "bufio"
+ "chatplus/core/types"
+ "chatplus/store/model"
+ "chatplus/store/vo"
+ "chatplus/utils"
+ "context"
+ "encoding/json"
+ "fmt"
+ "gorm.io/gorm"
+ "io"
+ "strings"
+ "time"
+ "unicode/utf8"
+)
+
+// 将消息发送给 Azure API 并获取结果,通过 WebSocket 推送到客户端
+func (h *ChatHandler) sendAzureMessage(
+ chatCtx []interface{},
+ req types.ApiRequest,
+ userVo vo.User,
+ ctx context.Context,
+ session *types.ChatSession,
+ role model.ChatRole,
+ prompt string,
+ ws *types.WsClient) error {
+ promptCreatedAt := time.Now() // 记录提问时间
+ start := time.Now()
+ var apiKey string
+ response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
+ if err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ return nil
+ } else if strings.Contains(err.Error(), "no available key") {
+ utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+ return nil
+ } else {
+ logger.Error(err)
+ }
+
+ utils.ReplyMessage(ws, ErrorMsg)
+ utils.ReplyMessage(ws, "")
+ return err
+ } else {
+ defer response.Body.Close()
+ }
+
+ contentType := response.Header.Get("Content-Type")
+ if strings.Contains(contentType, "text/event-stream") {
+ replyCreatedAt := time.Now() // 记录回复时间
+ // 循环读取 Chunk 消息
+ var message = types.Message{}
+ var contents = make([]string, 0)
+ var functionCall = false
+ var functionName string
+ var arguments = make([]string, 0)
+ scanner := bufio.NewScanner(response.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.Contains(line, "data:") || len(line) < 30 {
+ continue
+ }
+
+ var responseBody = types.ApiResponse{}
+ err = json.Unmarshal([]byte(line[6:]), &responseBody)
+ if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
+ logger.Error(err, line)
+ utils.ReplyMessage(ws, ErrorMsg)
+ utils.ReplyMessage(ws, "")
+ break
+ }
+
+ fun := responseBody.Choices[0].Delta.FunctionCall
+ if functionCall && fun.Name == "" {
+ arguments = append(arguments, fun.Arguments)
+ continue
+ }
+
+ if !utils.IsEmptyValue(fun) {
+ functionName = fun.Name
+ f := h.App.Functions[functionName]
+ if f != nil {
+ functionCall = true
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
+ continue
+ }
+ }
+
+ if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
+ break
+ }
+
+ // 初始化 role
+ if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
+ message.Role = responseBody.Choices[0].Delta.Role
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+ continue
+ } else if responseBody.Choices[0].FinishReason != "" {
+ break // 输出完成或者输出中断了
+ } else {
+ content := responseBody.Choices[0].Delta.Content
+ contents = append(contents, utils.InterfaceToString(content))
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
+ })
+ }
+ } // end for
+
+ if err := scanner.Err(); err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ } else {
+ logger.Error("信息读取出错:", err)
+ }
+ }
+
+ if functionCall { // 调用函数完成任务
+ var params map[string]interface{}
+ _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
+ logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
+
+ // for creating image, check if the user's img_calls > 0
+ if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
+ utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
+ utils.ReplyMessage(ws, "")
+ } else {
+ f := h.App.Functions[functionName]
+ data, err := f.Invoke(params)
+ if err != nil {
+ msg := "调用函数出错:" + err.Error()
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: msg,
+ })
+ contents = append(contents, msg)
+ } else {
+ content := data
+ if functionName == types.FuncMidJourney {
+ key := utils.Sha256(data)
+ logger.Debug(data, ",", key)
+ // add task for MidJourney
+ h.App.MjTaskClients.Put(key, ws)
+ task := types.MjTask{
+ UserId: userVo.Id,
+ RoleId: role.Id,
+ Icon: "/images/avatar/mid_journey.png",
+ ChatId: session.ChatId,
+ }
+ err := h.leveldb.Put(types.TaskStorePrefix+key, task)
+ if err != nil {
+ logger.Error("error with store MidJourney task: ", err)
+ }
+ content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
+
+ // update user's img_calls
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
+ }
+
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: content,
+ })
+ contents = append(contents, content)
+ }
+ }
+ }
+
+ // 消息发送成功
+ if len(contents) > 0 {
+ // 更新用户的对话次数
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+
+ if message.Role == "" {
+ message.Role = "assistant"
+ }
+ message.Content = strings.Join(contents, "")
+ useMsg := types.Message{Role: "user", Content: prompt}
+
+ // 更新上下文消息,如果是调用函数则不需要更新上下文
+ if userVo.ChatConfig.EnableContext && functionCall == false {
+ chatCtx = append(chatCtx, useMsg) // 提问消息
+ chatCtx = append(chatCtx, message) // 回复消息
+ h.App.ChatContexts.Put(session.ChatId, chatCtx)
+ }
+
+ // 追加聊天记录
+ if userVo.ChatConfig.EnableHistory {
+ useContext := true
+ if functionCall {
+ useContext = false
+ }
+
+ // for prompt
+ promptToken, err := utils.CalcTokens(prompt, req.Model)
+ if err != nil {
+ logger.Error(err)
+ }
+ historyUserMsg := model.HistoryMessage{
+ UserId: userVo.Id,
+ ChatId: session.ChatId,
+ RoleId: role.Id,
+ Type: types.PromptMsg,
+ Icon: userVo.Avatar,
+ Content: prompt,
+ Tokens: promptToken,
+ UseContext: useContext,
+ }
+ historyUserMsg.CreatedAt = promptCreatedAt
+ historyUserMsg.UpdatedAt = promptCreatedAt
+ res := h.db.Save(&historyUserMsg)
+ if res.Error != nil {
+ logger.Error("failed to save prompt history message: ", res.Error)
+ }
+
+ // for reply
+ // 计算本次对话消耗的总 token 数量
+ var replyToken = 0
+ if functionCall { // 函数名 + 参数 token
+ tokens, _ := utils.CalcTokens(functionName, req.Model)
+ replyToken += tokens
+ tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
+ replyToken += tokens
+ } else {
+ replyToken, _ = utils.CalcTokens(message.Content, req.Model)
+ }
+
+ historyReplyMsg := model.HistoryMessage{
+ UserId: userVo.Id,
+ ChatId: session.ChatId,
+ RoleId: role.Id,
+ Type: types.ReplyMsg,
+ Icon: role.Icon,
+ Content: message.Content,
+ Tokens: replyToken,
+ UseContext: useContext,
+ }
+ historyReplyMsg.CreatedAt = replyCreatedAt
+ historyReplyMsg.UpdatedAt = replyCreatedAt
+ res = h.db.Create(&historyReplyMsg)
+ if res.Error != nil {
+ logger.Error("failed to save reply history message: ", res.Error)
+ }
+
+ // 计算本次对话消耗的总 token 数量
+ var totalTokens = 0
+ if functionCall { // prompt + 函数名 + 参数 token
+ totalTokens = promptToken + replyToken
+ } else {
+ totalTokens = replyToken + getTotalTokens(req)
+ }
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+ UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+ UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
+ }
+
+ // 保存当前会话
+ var chatItem model.ChatItem
+ res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
+ if res.Error != nil {
+ chatItem.ChatId = session.ChatId
+ chatItem.UserId = session.UserId
+ chatItem.RoleId = role.Id
+ chatItem.ModelId = session.Model.Id
+ if utf8.RuneCountInString(prompt) > 30 {
+ chatItem.Title = string([]rune(prompt)[:30]) + "..."
+ } else {
+ chatItem.Title = prompt
+ }
+ h.db.Create(&chatItem)
+ }
+ }
+ } else {
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ return fmt.Errorf("error with reading response: %v", err)
+ }
+ var res types.ApiError
+ err = json.Unmarshal(body, &res)
+ if err != nil {
+ return fmt.Errorf("error with decode response: %v", err)
+ }
+
+ if strings.Contains(res.Error.Message, "maximum context length") {
+ logger.Error(res.Error.Message)
+ utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
+ h.App.ChatContexts.Delete(session.ChatId)
+ return h.sendMessage(ctx, session, role, prompt, ws)
+ } else {
+ utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
+ }
+ }
+
+ return nil
+}
diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go
index 5bbff3fd..0dc932e7 100644
--- a/api/handler/chat_handler.go
+++ b/api/handler/chat_handler.go
@@ -1,7 +1,6 @@
package handler
import (
- "bufio"
"bytes"
"chatplus/core"
"chatplus/core/types"
@@ -14,16 +13,14 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
+ "github.com/gin-gonic/gin"
+ "github.com/go-redis/redis/v8"
+ "github.com/gorilla/websocket"
+ "gorm.io/gorm"
"net/http"
"net/url"
"strings"
"time"
- "unicode/utf8"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "gorm.io/gorm"
)
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
@@ -32,10 +29,11 @@ type ChatHandler struct {
BaseHandler
db *gorm.DB
leveldb *store.LevelDB
+ redis *redis.Client
}
-func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ChatHandler {
- handler := ChatHandler{db: db, leveldb: levelDB}
+func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client) *ChatHandler {
+ handler := ChatHandler{db: db, leveldb: levelDB, redis: redis}
handler.App = app
return &handler
}
@@ -53,7 +51,17 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
sessionId := c.Query("session_id")
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
- chatModel := c.Query("model")
+ modelId := h.GetInt(c, "model_id", 0)
+
+ client := types.NewWsClient(ws)
+ // get model info
+ var chatModel model.ChatModel
+ res := h.db.First(&chatModel, modelId)
+ if res.Error != nil || chatModel.Enabled == false {
+ utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
+ c.Abort()
+ return
+ }
session := h.App.ChatSession.Get(sessionId)
if session == nil {
@@ -66,7 +74,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session = &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
- Username: user.Username,
+ Username: user.Mobile,
UserId: user.Id,
}
h.App.ChatSession.Put(sessionId, session)
@@ -74,16 +82,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID
var chat model.ChatItem
- res := h.db.Where("chat_id=?", chatId).First(&chat)
+ res = h.db.Where("chat_id=?", chatId).First(&chat)
if res.Error == nil {
- chatModel = chat.Model
+ chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
}
session.ChatId = chatId
- session.Model = chatModel
+ session.Model = types.ChatModel{
+ Id: chatModel.Id,
+ Value: chatModel.Value,
+ Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
- client := types.NewWsClient(ws)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
@@ -133,9 +143,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}()
}
-// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
- promptCreatedAt := time.Now() // 记录提问时间
+ defer func() {
+ if r := recover(); r != nil {
+ logger.Error("Recover message from error: ", r)
+ }
+ }()
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
@@ -156,7 +169,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
- if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" {
+ if userVo.Calls <= 0 {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
utils.ReplyMessage(ws, "")
return nil
@@ -168,11 +181,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
var req = types.ApiRequest{
- Model: session.Model,
- Temperature: userVo.ChatConfig.Temperature,
- MaxTokens: userVo.ChatConfig.MaxTokens,
- Stream: true,
- Functions: types.InnerFunctions,
+ Model: session.Model.Value,
+ Stream: true,
+ }
+ switch session.Model.Platform {
+ case types.Azure:
+ req.Temperature = h.App.ChatConfig.Azure.Temperature
+ req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
+ break
+ case types.ChatGML:
+ req.Temperature = h.App.ChatConfig.ChatGML.Temperature
+ req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
+ break
+ default:
+ req.Temperature = h.App.ChatConfig.OpenAI.Temperature
+ req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
+ var functions = make([]types.Function, 0)
+ for _, f := range types.InnerFunctions {
+ if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
+ continue
+ }
+ functions = append(functions, f)
+ }
+ req.Functions = functions
}
// 加载聊天上下文
@@ -208,7 +239,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("created_at desc").Find(&historyMessages)
if res.Error == nil {
for _, msg := range historyMessages {
- if tokens+msg.Tokens >= types.ModelToTokens[session.Model] {
+ if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
break
}
tokens += msg.Tokens
@@ -232,341 +263,31 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
"role": "user",
"content": prompt,
})
- var apiKey string
- response, err := h.doRequest(ctx, userVo, &apiKey, req)
- if err != nil {
- if strings.Contains(err.Error(), "context canceled") {
- logger.Info("用户取消了请求:", prompt)
- return nil
- } else if strings.Contains(err.Error(), "no available key") {
- utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
- return nil
- } else {
- logger.Error(err)
- }
- utils.ReplyMessage(ws, ErrorMsg)
- utils.ReplyMessage(ws, "")
- return err
- } else {
- defer response.Body.Close()
- }
-
- contentType := response.Header.Get("Content-Type")
- if strings.Contains(contentType, "text/event-stream") {
- if true {
- replyCreatedAt := time.Now()
- // 循环读取 Chunk 消息
- var message = types.Message{}
- var contents = make([]string, 0)
- var functionCall = false
- var functionName string
- var arguments = make([]string, 0)
- reader := bufio.NewReader(response.Body)
- for {
- line, err := reader.ReadString('\n')
- if err != nil {
- if strings.Contains(err.Error(), "context canceled") {
- logger.Info("用户取消了请求:", prompt)
- } else if err != io.EOF {
- logger.Error("信息读取出错:", err)
- }
- break
- }
- if !strings.Contains(line, "data:") || len(line) < 30 {
- continue
- }
-
- var responseBody = types.ApiResponse{}
- err = json.Unmarshal([]byte(line[6:]), &responseBody)
- if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
- logger.Error(err, line)
- utils.ReplyMessage(ws, ErrorMsg)
- utils.ReplyMessage(ws, "")
- break
- }
-
- fun := responseBody.Choices[0].Delta.FunctionCall
- if functionCall && fun.Name == "" {
- arguments = append(arguments, fun.Arguments)
- continue
- }
-
- if !utils.IsEmptyValue(fun) {
- functionCall = true
- functionName = fun.Name
- f := h.App.Functions[functionName]
- utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
- utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
- continue
- }
-
- if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
- break
- }
-
- // 初始化 role
- if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
- message.Role = responseBody.Choices[0].Delta.Role
- utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
- continue
- } else if responseBody.Choices[0].FinishReason != "" {
- break // 输出完成或者输出中断了
- } else {
- content := responseBody.Choices[0].Delta.Content
- contents = append(contents, utils.InterfaceToString(content))
- utils.ReplyChunkMessage(ws, types.WsMessage{
- Type: types.WsMiddle,
- Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
- })
- }
- } // end for
-
- if functionCall { // 调用函数完成任务
- var params map[string]interface{}
- _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
- logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
-
- // for creating image, check if the user's img_calls > 0
- if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
- utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
- utils.ReplyMessage(ws, "")
- } else {
- f := h.App.Functions[functionName]
- data, err := f.Invoke(params)
- if err != nil {
- msg := "调用函数出错:" + err.Error()
- utils.ReplyChunkMessage(ws, types.WsMessage{
- Type: types.WsMiddle,
- Content: msg,
- })
- contents = append(contents, msg)
- } else {
- content := data
- if functionName == types.FuncMidJourney {
- key := utils.Sha256(data)
- logger.Debug(data, ",", key)
- // add task for MidJourney
- h.App.MjTaskClients.Put(key, ws)
- task := types.MjTask{
- UserId: userVo.Id,
- RoleId: role.Id,
- Icon: "/images/avatar/mid_journey.png",
- ChatId: session.ChatId,
- }
- err := h.leveldb.Put(types.TaskStorePrefix+key, task)
- if err != nil {
- logger.Error("error with store MidJourney task: ", err)
- }
- content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
-
- // update user's img_calls
- h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
- }
-
- utils.ReplyChunkMessage(ws, types.WsMessage{
- Type: types.WsMiddle,
- Content: content,
- })
- contents = append(contents, content)
- }
- }
- }
-
- // 消息发送成功
- if len(contents) > 0 {
- // 更新用户的对话次数
- if userVo.ChatConfig.ApiKey == "" { // 如果用户使用的是自己绑定的 API KEY 则不扣减对话次数
- h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
- }
-
- if message.Role == "" {
- message.Role = "assistant"
- }
- message.Content = strings.Join(contents, "")
- useMsg := types.Message{Role: "user", Content: prompt}
-
- // 更新上下文消息,如果是调用函数则不需要更新上下文
- if userVo.ChatConfig.EnableContext && functionCall == false {
- chatCtx = append(chatCtx, useMsg) // 提问消息
- chatCtx = append(chatCtx, message) // 回复消息
- h.App.ChatContexts.Put(session.ChatId, chatCtx)
- }
-
- // 追加聊天记录
- if userVo.ChatConfig.EnableHistory {
- useContext := true
- if functionCall {
- useContext = false
- }
-
- // for prompt
- promptToken, err := utils.CalcTokens(prompt, req.Model)
- if err != nil {
- logger.Error(err)
- }
- historyUserMsg := model.HistoryMessage{
- UserId: userVo.Id,
- ChatId: session.ChatId,
- RoleId: role.Id,
- Type: types.PromptMsg,
- Icon: user.Avatar,
- Content: prompt,
- Tokens: promptToken,
- UseContext: useContext,
- }
- historyUserMsg.CreatedAt = promptCreatedAt
- historyUserMsg.UpdatedAt = promptCreatedAt
- res := h.db.Save(&historyUserMsg)
- if res.Error != nil {
- logger.Error("failed to save prompt history message: ", res.Error)
- }
-
- // for reply
- // 计算本次对话消耗的总 token 数量
- var replyToken = 0
- if functionCall { // 函数名 + 参数 token
- tokens, _ := utils.CalcTokens(functionName, req.Model)
- replyToken += tokens
- tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
- replyToken += tokens
- } else {
- replyToken, _ = utils.CalcTokens(message.Content, req.Model)
- }
-
- historyReplyMsg := model.HistoryMessage{
- UserId: userVo.Id,
- ChatId: session.ChatId,
- RoleId: role.Id,
- Type: types.ReplyMsg,
- Icon: role.Icon,
- Content: message.Content,
- Tokens: replyToken,
- UseContext: useContext,
- }
- historyReplyMsg.CreatedAt = replyCreatedAt
- historyReplyMsg.UpdatedAt = replyCreatedAt
- res = h.db.Create(&historyReplyMsg)
- if res.Error != nil {
- logger.Error("failed to save reply history message: ", res.Error)
- }
-
- // 计算本次对话消耗的总 token 数量
- var totalTokens = 0
- if functionCall { // prompt + 函数名 + 参数 token
- totalTokens = promptToken + replyToken
- } else {
- totalTokens = replyToken + getTotalTokens(req)
- }
- //utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
- if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗
- h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
- totalTokens))
- }
- }
-
- // 保存当前会话
- var chatItem model.ChatItem
- res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
- if res.Error != nil {
- chatItem.ChatId = session.ChatId
- chatItem.UserId = session.UserId
- chatItem.RoleId = role.Id
- chatItem.Model = session.Model
- if utf8.RuneCountInString(prompt) > 30 {
- chatItem.Title = string([]rune(prompt)[:30]) + "..."
- } else {
- chatItem.Title = prompt
- }
- h.db.Create(&chatItem)
- }
- }
- }
- } else {
- body, err := io.ReadAll(response.Body)
- if err != nil {
- return fmt.Errorf("error with reading response: %v", err)
- }
- var res types.ApiError
- err = json.Unmarshal(body, &res)
- if err != nil {
- return fmt.Errorf("error with decode response: %v", err)
- }
-
- // OpenAI API 调用异常处理
- // TODO: 是否考虑重发消息?
- if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
- utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
- // 移除当前 API key
- h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
- } else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
- utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
- } else if strings.Contains(res.Error.Message, "This model's maximum context length") {
- logger.Error(res.Error.Message)
- utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
- h.App.ChatContexts.Delete(session.ChatId)
- return h.sendMessage(ctx, session, role, prompt, ws)
- } else {
- utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
- }
+ switch session.Model.Platform {
+ case types.Azure:
+ return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
+ case types.OpenAI:
+ return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
+ case types.ChatGML:
+ return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
}
-// 发送请求到 OpenAI 服务器
-// useOwnApiKey: 是否使用了用户自己的 API KEY
-func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
- var client *http.Client
- requestBody, err := json.Marshal(req)
- if err != nil {
- return nil, err
- }
- // 创建 HttpClient 请求对象
- request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
- if err != nil {
- return nil, err
- }
-
- request = request.WithContext(ctx)
- request.Header.Add("Content-Type", "application/json")
-
- proxyURL := h.App.Config.ProxyURL
- if proxyURL == "" {
- client = &http.Client{}
- } else { // 使用代理
- proxy, _ := url.Parse(proxyURL)
- client = &http.Client{
- Transport: &http.Transport{
- Proxy: http.ProxyURL(proxy),
- },
- }
- }
- // 查询当前用户是否导入了自己的 API KEY
- if user.ChatConfig.ApiKey != "" {
- logger.Info("使用用户自己的 API KEY: ", user.ChatConfig.ApiKey)
- *apiKey = user.ChatConfig.ApiKey
- } else { // 获取系统的 API KEY
- var key model.ApiKey
- res := h.db.Where("user_id = ?", 0).Order("last_used_at ASC").First(&key)
- if res.Error != nil {
- return nil, errors.New("no available key, please import key")
- }
- *apiKey = key.Value
- // 更新 API KEY 的最后使用时间
- h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
- }
-
- logger.Infof("Sending OpenAI request, KEY: %s, PROXY: %s, Model: %s", *apiKey, proxyURL, req.Model)
- request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
- return client.Do(request)
-}
-
// Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) {
- text := c.Query("text")
- md := c.Query("model")
- tokens, err := utils.CalcTokens(text, md)
+ var data struct {
+ Text string `json:"text"`
+ Model string `json:"model"`
+ }
+ if err := c.ShouldBindJSON(&data); err != nil {
+ resp.ERROR(c, types.InvalidArgs)
+ return
+ }
+
+ tokens, err := utils.CalcTokens(data.Text, data.Model)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -604,3 +325,73 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
}
resp.SUCCESS(c, types.OkMsg)
}
+
+// 发送请求到 OpenAI 服务器
+// useOwnApiKey: 是否使用了用户自己的 API KEY
+func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
+
+ var apiURL string
+ switch platform {
+ case types.Azure:
+ md := strings.Replace(req.Model, ".", "", 1)
+ apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
+ break
+ case types.ChatGML:
+ apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
+ req.Prompt = req.Messages
+ req.Messages = nil
+ break
+ default:
+ apiURL = h.App.ChatConfig.OpenAI.ApiURL
+ }
+ // 创建 HttpClient 请求对象
+ var client *http.Client
+ requestBody, err := json.Marshal(req)
+ if err != nil {
+ return nil, err
+ }
+ request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
+ if err != nil {
+ return nil, err
+ }
+
+ request = request.WithContext(ctx)
+ request.Header.Set("Content-Type", "application/json")
+ proxyURL := h.App.Config.ProxyURL
+ if proxyURL != "" && platform == types.OpenAI { // 使用代理
+ proxy, _ := url.Parse(proxyURL)
+ client = &http.Client{
+ Transport: &http.Transport{
+ Proxy: http.ProxyURL(proxy),
+ },
+ }
+ } else {
+ client = http.DefaultClient
+ }
+ var key model.ApiKey
+ res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
+ if res.Error != nil {
+ return nil, errors.New("no available key, please import key")
+ }
+ // 更新 API KEY 的最后使用时间
+ h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
+
+ logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, key.Value, proxyURL, req.Model)
+ switch platform {
+ case types.Azure:
+ request.Header.Set("api-key", key.Value)
+ break
+ case types.ChatGML:
+ token, err := h.getChatGLMToken(key.Value)
+ if err != nil {
+ return nil, err
+ }
+ logger.Info(token)
+ request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
+ break
+ default:
+ request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
+ }
+ *apiKey = key.Value
+ return client.Do(request)
+}
diff --git a/api/handler/chatglm_handler.go b/api/handler/chatglm_handler.go
new file mode 100644
index 00000000..bedd92e5
--- /dev/null
+++ b/api/handler/chatglm_handler.go
@@ -0,0 +1,239 @@
+package handler
+
+import (
+ "bufio"
+ "chatplus/core/types"
+ "chatplus/store/model"
+ "chatplus/store/vo"
+ "chatplus/utils"
+ "context"
+ "encoding/json"
+ "fmt"
+ "github.com/golang-jwt/jwt/v5"
+ "gorm.io/gorm"
+ "io"
+ "strings"
+ "time"
+ "unicode/utf8"
+)
+
+// 将消息发送给 ChatGLM API 并获取结果,通过 WebSocket 推送到客户端
+func (h *ChatHandler) sendChatGLMMessage(
+ chatCtx []interface{},
+ req types.ApiRequest,
+ userVo vo.User,
+ ctx context.Context,
+ session *types.ChatSession,
+ role model.ChatRole,
+ prompt string,
+ ws *types.WsClient) error {
+ promptCreatedAt := time.Now() // 记录提问时间
+ start := time.Now()
+ var apiKey string
+ response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
+ if err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ return nil
+ } else if strings.Contains(err.Error(), "no available key") {
+ utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+ return nil
+ } else {
+ logger.Error(err)
+ }
+
+ utils.ReplyMessage(ws, ErrorMsg)
+ utils.ReplyMessage(ws, "")
+ return err
+ } else {
+ defer response.Body.Close()
+ }
+
+ contentType := response.Header.Get("Content-Type")
+ if strings.Contains(contentType, "text/event-stream") {
+ replyCreatedAt := time.Now() // 记录回复时间
+ // 循环读取 Chunk 消息
+ var message = types.Message{}
+ var contents = make([]string, 0)
+ var event, content string
+ scanner := bufio.NewScanner(response.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if len(line) < 5 || strings.HasPrefix(line, "id:") {
+ continue
+ }
+ if strings.HasPrefix(line, "event:") {
+ event = line[6:]
+ continue
+ }
+
+ if strings.HasPrefix(line, "data:") {
+ content = line[5:]
+ }
+ switch event {
+ case "add":
+ if len(contents) == 0 {
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+ }
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: utils.InterfaceToString(content),
+ })
+ contents = append(contents, content)
+ case "finish":
+ break
+ case "error":
+ utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
+ break
+ case "interrupted":
+ utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
+ }
+
+ } // end for
+
+ if err := scanner.Err(); err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ } else {
+ logger.Error("信息读取出错:", err)
+ }
+ }
+
+ // 消息发送成功
+ if len(contents) > 0 {
+ // 更新用户的对话次数
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+
+ if message.Role == "" {
+ message.Role = "assistant"
+ }
+ message.Content = strings.Join(contents, "")
+ useMsg := types.Message{Role: "user", Content: prompt}
+
+ // 更新上下文消息,如果是调用函数则不需要更新上下文
+ if userVo.ChatConfig.EnableContext {
+ chatCtx = append(chatCtx, useMsg) // 提问消息
+ chatCtx = append(chatCtx, message) // 回复消息
+ h.App.ChatContexts.Put(session.ChatId, chatCtx)
+ }
+
+ // 追加聊天记录
+ if userVo.ChatConfig.EnableHistory {
+ // for prompt
+ promptToken, err := utils.CalcTokens(prompt, req.Model)
+ if err != nil {
+ logger.Error(err)
+ }
+ historyUserMsg := model.HistoryMessage{
+ UserId: userVo.Id,
+ ChatId: session.ChatId,
+ RoleId: role.Id,
+ Type: types.PromptMsg,
+ Icon: userVo.Avatar,
+ Content: prompt,
+ Tokens: promptToken,
+ UseContext: true,
+ }
+ historyUserMsg.CreatedAt = promptCreatedAt
+ historyUserMsg.UpdatedAt = promptCreatedAt
+ res := h.db.Save(&historyUserMsg)
+ if res.Error != nil {
+ logger.Error("failed to save prompt history message: ", res.Error)
+ }
+
+ // for reply
+ // 计算本次对话消耗的总 token 数量
+ var replyToken = 0
+ replyToken, _ = utils.CalcTokens(message.Content, req.Model)
+
+ historyReplyMsg := model.HistoryMessage{
+ UserId: userVo.Id,
+ ChatId: session.ChatId,
+ RoleId: role.Id,
+ Type: types.ReplyMsg,
+ Icon: role.Icon,
+ Content: message.Content,
+ Tokens: replyToken,
+ UseContext: true,
+ }
+ historyReplyMsg.CreatedAt = replyCreatedAt
+ historyReplyMsg.UpdatedAt = replyCreatedAt
+ res = h.db.Create(&historyReplyMsg)
+ if res.Error != nil {
+ logger.Error("failed to save reply history message: ", res.Error)
+ }
+
+ // 计算本次对话消耗的总 token 数量
+ var totalTokens = 0
+ totalTokens = replyToken + getTotalTokens(req)
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+ UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+ UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
+ }
+
+ // 保存当前会话
+ var chatItem model.ChatItem
+ res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
+ if res.Error != nil {
+ chatItem.ChatId = session.ChatId
+ chatItem.UserId = session.UserId
+ chatItem.RoleId = role.Id
+ chatItem.ModelId = session.Model.Id
+ if utf8.RuneCountInString(prompt) > 30 {
+ chatItem.Title = string([]rune(prompt)[:30]) + "..."
+ } else {
+ chatItem.Title = prompt
+ }
+ h.db.Create(&chatItem)
+ }
+ }
+ } else {
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ return fmt.Errorf("error with reading response: %v", err)
+ }
+
+ var res struct {
+ Code int `json:"code"`
+ Success bool `json:"success"`
+ Msg string `json:"msg"`
+ }
+ err = json.Unmarshal(body, &res)
+ if err != nil {
+ return fmt.Errorf("error with decode response: %v", err)
+ }
+ if !res.Success {
+ utils.ReplyMessage(ws, "请求 ChatGML 失败:"+res.Msg)
+ }
+ }
+
+ return nil
+}
+
+func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
+ ctx := context.Background()
+ tokenString, err := h.redis.Get(ctx, apiKey).Result()
+ if err == nil {
+ return tokenString, nil
+ }
+
+ expr := time.Hour * 2
+ key := strings.Split(apiKey, ".")
+ if len(key) != 2 {
+ return "", fmt.Errorf("invalid api key: %s", apiKey)
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "api_key": key[0],
+ "timestamp": time.Now().Unix(),
+ "exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
+ })
+ token.Header["alg"] = "HS256"
+ token.Header["sign_type"] = "SIGN"
+ delete(token.Header, "typ")
+ // Sign and get the complete encoded token as a string using the secret
+ tokenString, err = token.SignedString([]byte(key[1]))
+ h.redis.Set(ctx, apiKey, tokenString, expr)
+ return tokenString, err
+}
diff --git a/api/handler/openai_handler.go b/api/handler/openai_handler.go
new file mode 100644
index 00000000..68a6c6e4
--- /dev/null
+++ b/api/handler/openai_handler.go
@@ -0,0 +1,308 @@
+package handler
+
+import (
+ "bufio"
+ "chatplus/core/types"
+ "chatplus/store/model"
+ "chatplus/store/vo"
+ "chatplus/utils"
+ "context"
+ "encoding/json"
+ "fmt"
+ "gorm.io/gorm"
+ "io"
+ "strings"
+ "time"
+ "unicode/utf8"
+)
+
+// 将消息发送给 OpenAI API 并获取结果,通过 WebSocket 推送到客户端
+func (h *ChatHandler) sendOpenAiMessage(
+ chatCtx []interface{},
+ req types.ApiRequest,
+ userVo vo.User,
+ ctx context.Context,
+ session *types.ChatSession,
+ role model.ChatRole,
+ prompt string,
+ ws *types.WsClient) error {
+ promptCreatedAt := time.Now() // 记录提问时间
+ start := time.Now()
+ var apiKey string
+ response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+ logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
+ if err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ return nil
+ } else if strings.Contains(err.Error(), "no available key") {
+ utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+ return nil
+ } else {
+ logger.Error(err)
+ }
+
+ utils.ReplyMessage(ws, ErrorMsg)
+ utils.ReplyMessage(ws, "")
+ return err
+ } else {
+ defer response.Body.Close()
+ }
+
+ contentType := response.Header.Get("Content-Type")
+ if strings.Contains(contentType, "text/event-stream") {
+ replyCreatedAt := time.Now() // 记录回复时间
+ // 循环读取 Chunk 消息
+ var message = types.Message{}
+ var contents = make([]string, 0)
+ var functionCall = false
+ var functionName string
+ var arguments = make([]string, 0)
+ scanner := bufio.NewScanner(response.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.Contains(line, "data:") || len(line) < 30 {
+ continue
+ }
+
+ var responseBody = types.ApiResponse{}
+ err = json.Unmarshal([]byte(line[6:]), &responseBody)
+ if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
+ logger.Error(err, line)
+ utils.ReplyMessage(ws, ErrorMsg)
+ utils.ReplyMessage(ws, "")
+ break
+ }
+
+ fun := responseBody.Choices[0].Delta.FunctionCall
+ if functionCall && fun.Name == "" {
+ arguments = append(arguments, fun.Arguments)
+ continue
+ }
+
+ if !utils.IsEmptyValue(fun) {
+ functionName = fun.Name
+ f := h.App.Functions[functionName]
+ if f != nil {
+ functionCall = true
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
+ }
+ continue
+ }
+
+ if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
+ break
+ }
+
+ // 初始化 role
+ if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
+ message.Role = responseBody.Choices[0].Delta.Role
+ utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+ continue
+ } else if responseBody.Choices[0].FinishReason != "" {
+ break // 输出完成或者输出中断了
+ } else {
+ content := responseBody.Choices[0].Delta.Content
+ contents = append(contents, utils.InterfaceToString(content))
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
+ })
+ }
+ } // end for
+
+ if err := scanner.Err(); err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ } else {
+ logger.Error("信息读取出错:", err)
+ }
+ }
+
+ if functionCall { // 调用函数完成任务
+ var params map[string]interface{}
+ _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
+ logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
+
+ // for creating image, check if the user's img_calls > 0
+ if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
+ utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
+ utils.ReplyMessage(ws, "")
+ } else {
+ f := h.App.Functions[functionName]
+ data, err := f.Invoke(params)
+ if err != nil {
+ msg := "调用函数出错:" + err.Error()
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: msg,
+ })
+ contents = append(contents, msg)
+ } else {
+ content := data
+ if functionName == types.FuncMidJourney {
+ key := utils.Sha256(data)
+ logger.Debug(data, ",", key)
+ // add task for MidJourney
+ h.App.MjTaskClients.Put(key, ws)
+ task := types.MjTask{
+ UserId: userVo.Id,
+ RoleId: role.Id,
+ Icon: "/images/avatar/mid_journey.png",
+ ChatId: session.ChatId,
+ }
+ err := h.leveldb.Put(types.TaskStorePrefix+key, task)
+ if err != nil {
+ logger.Error("error with store MidJourney task: ", err)
+ }
+ content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
+
+ // update user's img_calls
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
+ }
+
+ utils.ReplyChunkMessage(ws, types.WsMessage{
+ Type: types.WsMiddle,
+ Content: content,
+ })
+ contents = append(contents, content)
+ }
+ }
+ }
+
+ // 消息发送成功
+ if len(contents) > 0 {
+ // 更新用户的对话次数
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+
+ if message.Role == "" {
+ message.Role = "assistant"
+ }
+ message.Content = strings.Join(contents, "")
+ useMsg := types.Message{Role: "user", Content: prompt}
+
+ // 更新上下文消息,如果是调用函数则不需要更新上下文
+ if userVo.ChatConfig.EnableContext && functionCall == false {
+ chatCtx = append(chatCtx, useMsg) // 提问消息
+ chatCtx = append(chatCtx, message) // 回复消息
+ h.App.ChatContexts.Put(session.ChatId, chatCtx)
+ }
+
+ // 追加聊天记录
+ if userVo.ChatConfig.EnableHistory {
+ useContext := true
+ if functionCall {
+ useContext = false
+ }
+
+ // for prompt
+ promptToken, err := utils.CalcTokens(prompt, req.Model)
+ if err != nil {
+ logger.Error(err)
+ }
+ historyUserMsg := model.HistoryMessage{
+ UserId: userVo.Id,
+ ChatId: session.ChatId,
+ RoleId: role.Id,
+ Type: types.PromptMsg,
+ Icon: userVo.Avatar,
+ Content: prompt,
+ Tokens: promptToken,
+ UseContext: useContext,
+ }
+ historyUserMsg.CreatedAt = promptCreatedAt
+ historyUserMsg.UpdatedAt = promptCreatedAt
+ res := h.db.Save(&historyUserMsg)
+ if res.Error != nil {
+ logger.Error("failed to save prompt history message: ", res.Error)
+ }
+
+ // for reply
+ // 计算本次对话消耗的总 token 数量
+ var replyToken = 0
+ if functionCall { // 函数名 + 参数 token
+ tokens, _ := utils.CalcTokens(functionName, req.Model)
+ replyToken += tokens
+ tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
+ replyToken += tokens
+ } else {
+ replyToken, _ = utils.CalcTokens(message.Content, req.Model)
+ }
+
+ historyReplyMsg := model.HistoryMessage{
+ UserId: userVo.Id,
+ ChatId: session.ChatId,
+ RoleId: role.Id,
+ Type: types.ReplyMsg,
+ Icon: role.Icon,
+ Content: message.Content,
+ Tokens: replyToken,
+ UseContext: useContext,
+ }
+ historyReplyMsg.CreatedAt = replyCreatedAt
+ historyReplyMsg.UpdatedAt = replyCreatedAt
+ res = h.db.Create(&historyReplyMsg)
+ if res.Error != nil {
+ logger.Error("failed to save reply history message: ", res.Error)
+ }
+
+ // 计算本次对话消耗的总 token 数量
+ var totalTokens = 0
+ if functionCall { // prompt + 函数名 + 参数 token
+ totalTokens = promptToken + replyToken
+ } else {
+ totalTokens = replyToken + getTotalTokens(req)
+ }
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+ UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens))
+ h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+ UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
+ }
+
+ // 保存当前会话
+ var chatItem model.ChatItem
+ res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
+ if res.Error != nil {
+ chatItem.ChatId = session.ChatId
+ chatItem.UserId = session.UserId
+ chatItem.RoleId = role.Id
+ chatItem.ModelId = session.Model.Id
+ if utf8.RuneCountInString(prompt) > 30 {
+ chatItem.Title = string([]rune(prompt)[:30]) + "..."
+ } else {
+ chatItem.Title = prompt
+ }
+ h.db.Create(&chatItem)
+ }
+ }
+ } else {
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ return fmt.Errorf("error with reading response: %v", err)
+ }
+ var res types.ApiError
+ err = json.Unmarshal(body, &res)
+ if err != nil {
+ return fmt.Errorf("error with decode response: %v", err)
+ }
+
+ // OpenAI API 调用异常处理
+ if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
+ utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
+ // 移除当前 API key
+ h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
+ } else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
+ utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
+ } else if strings.Contains(res.Error.Message, "This model's maximum context length") {
+ logger.Error(res.Error.Message)
+ utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
+ h.App.ChatContexts.Delete(session.ChatId)
+ return h.sendMessage(ctx, session, role, prompt, ws)
+ } else {
+ utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
+ }
+ }
+
+ return nil
+}
diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go
index 7e21dd44..79016ff8 100644
--- a/api/handler/user_handler.go
+++ b/api/handler/user_handler.go
@@ -9,7 +9,6 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
- "fmt"
"strings"
"time"
@@ -42,20 +41,18 @@ func NewUserHandler(
func (h *UserHandler) Register(c *gin.Context) {
// parameters process
var data struct {
- Username string `json:"username"`
- Password string `json:"password"`
Mobile string `json:"mobile"`
+ Password string `json:"password"`
Code int `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
- data.Username = strings.TrimSpace(data.Username)
data.Password = strings.TrimSpace(data.Password)
- if len(data.Username) < 5 {
- resp.ERROR(c, "用户名长度不能少于5个字符")
+ if len(data.Mobile) < 10 {
+ resp.ERROR(c, "请输入合法的手机号")
return
}
if len(data.Password) < 8 {
@@ -77,13 +74,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// check if the username is exists
var item model.User
- res := h.db.Where("username = ?", data.Username).First(&item)
- if res.RowsAffected > 0 {
- resp.ERROR(c, "用户名已存在")
- return
- }
-
- res = h.db.Where("mobile = ?", data.Mobile).First(&item)
+ res := h.db.Where("mobile = ?", data.Mobile).First(&item)
if res.RowsAffected > 0 {
resp.ERROR(c, "该手机号码以及被注册,请更换其他手机号")
return
@@ -99,21 +90,18 @@ func (h *UserHandler) Register(c *gin.Context) {
salt := utils.RandString(8)
user := model.User{
- Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
- Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(5)),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
Mobile: data.Mobile,
ChatRoles: utils.JsonEncode(roleKeys),
- ChatConfig: utils.JsonEncode(types.ChatConfig{
- Temperature: h.App.ChatConfig.Temperature,
- MaxTokens: h.App.ChatConfig.MaxTokens,
- EnableContext: h.App.ChatConfig.EnableContext,
- EnableHistory: true,
- Model: h.App.ChatConfig.Model,
- ApiKey: "",
+ ChatConfig: utils.JsonEncode(types.UserChatConfig{
+ ApiKeys: map[types.Platform]string{
+ types.OpenAI: "",
+ types.Azure: "",
+ types.ChatGML: "",
+ },
}),
Calls: h.App.SysConfig.UserInitCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
@@ -134,15 +122,15 @@ func (h *UserHandler) Register(c *gin.Context) {
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
var data struct {
- Username string
- Password string
+ Username string `json:"mobile"`
+ Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.User
- res := h.db.Where("username = ? OR mobile = ?", data.Username, data.Username).First(&user)
+ res := h.db.Where("mobile = ?", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户名不存在")
return
@@ -173,7 +161,7 @@ func (h *UserHandler) Login(c *gin.Context) {
h.db.Create(&model.UserLoginLog{
UserId: user.Id,
- Username: user.Username,
+ Username: user.Mobile,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
@@ -218,15 +206,13 @@ func (h *UserHandler) Session(c *gin.Context) {
}
type userProfile struct {
- Id uint `json:"id"`
- Username string `json:"username"`
- Nickname string `json:"nickname"`
- Mobile string `json:"mobile"`
- Avatar string `json:"avatar"`
- ChatConfig types.ChatConfig `json:"chat_config"`
- Calls int `json:"calls"`
- ImgCalls int `json:"img_calls"`
- Tokens int64 `json:"tokens"`
+ Id uint `json:"id"`
+ Mobile string `json:"mobile"`
+ Avatar string `json:"avatar"`
+ ChatConfig types.UserChatConfig `json:"chat_config"`
+ Calls int `json:"calls"`
+ ImgCalls int `json:"img_calls"`
+ TotalTokens int64 `json:"total_tokens"`
}
func (h *UserHandler) Profile(c *gin.Context) {
@@ -262,25 +248,9 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
return
}
h.db.First(&user, user.Id)
- user.Nickname = data.Nickname
oldAvatar := user.Avatar
user.Avatar = data.Avatar
-
- var chatConfig types.ChatConfig
- err = utils.JsonDecode(user.ChatConfig, &chatConfig)
- if err != nil {
- resp.ERROR(c, "用户配置解析失败")
- return
- }
-
- chatConfig.EnableHistory = data.ChatConfig.EnableHistory
- chatConfig.EnableContext = data.ChatConfig.EnableContext
- chatConfig.Model = data.ChatConfig.Model
- chatConfig.MaxTokens = data.ChatConfig.MaxTokens
- chatConfig.ApiKey = data.ChatConfig.ApiKey
- chatConfig.Temperature = data.ChatConfig.Temperature
-
- user.ChatConfig = utils.JsonEncode(chatConfig)
+ user.ChatConfig = utils.JsonEncode(data.ChatConfig)
res := h.db.Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新用户信息失败")
diff --git a/api/main.go b/api/main.go
index 1621d0db..477e5253 100644
--- a/api/main.go
+++ b/api/main.go
@@ -89,6 +89,7 @@ func main() {
fx.Provide(store.NewGormConfig),
fx.Provide(store.NewMysql),
fx.Provide(store.NewLevelDB),
+ fx.Provide(store.NewRedisClient),
// 创建 Ip2Region 查询对象
fx.Provide(func() (*xdb.Searcher, error) {
diff --git a/api/store/model/api_key.go b/api/store/model/api_key.go
index 88686d26..80f8d44e 100644
--- a/api/store/model/api_key.go
+++ b/api/store/model/api_key.go
@@ -3,7 +3,7 @@ package model
// ApiKey OpenAI API 模型
type ApiKey struct {
BaseModel
- UserId uint //用户ID,系统添加的用户 ID 为 0
+ Platform string
Value string // API Key 的值
LastUsedAt int64 // 最后使用时间
}
diff --git a/api/store/model/chat_history.go b/api/store/model/chat_history.go
index b1eb85e7..8fa09683 100644
--- a/api/store/model/chat_history.go
+++ b/api/store/model/chat_history.go
@@ -1,5 +1,7 @@
package model
+import "gorm.io/gorm"
+
type HistoryMessage struct {
BaseModel
ChatId string // 会话 ID
@@ -10,6 +12,7 @@ type HistoryMessage struct {
Tokens int
Content string
UseContext bool // 是否可以作为聊天上下文
+ DeletedAt gorm.DeletedAt
}
func (HistoryMessage) TableName() string {
diff --git a/api/store/model/chat_item.go b/api/store/model/chat_item.go
index 1cdf7ba0..f0b653d1 100644
--- a/api/store/model/chat_item.go
+++ b/api/store/model/chat_item.go
@@ -1,10 +1,13 @@
package model
+import "gorm.io/gorm"
+
type ChatItem struct {
BaseModel
- ChatId string `gorm:"column:chat_id;unique"` // 会话 ID
- UserId uint // 用户 ID
- RoleId uint // 角色 ID
- Model string // 会话模型
- Title string // 会话标题
+ ChatId string `gorm:"column:chat_id;unique"` // 会话 ID
+ UserId uint // 用户 ID
+ RoleId uint // 角色 ID
+ ModelId uint // 会话模型
+ Title string // 会话标题
+ DeletedAt gorm.DeletedAt
}
diff --git a/api/store/model/chat_model.go b/api/store/model/chat_model.go
new file mode 100644
index 00000000..89639e7d
--- /dev/null
+++ b/api/store/model/chat_model.go
@@ -0,0 +1,10 @@
+package model
+
+type ChatModel struct {
+ BaseModel
+ Platform string
+ Name string
+ Value string // API Key 的值
+ SortNum int
+ Enabled bool
+}
diff --git a/api/store/model/user.go b/api/store/model/user.go
index 76644685..1129307c 100644
--- a/api/store/model/user.go
+++ b/api/store/model/user.go
@@ -2,13 +2,11 @@ package model
type User struct {
BaseModel
- Username string `gorm:"index:username,unique"`
Mobile string
Password string
- Nickname string
Avatar string
Salt string // 密码盐
- Tokens int64 // 剩余tokens
+ TotalTokens int64 // 总消耗 tokens
Calls int // 剩余对话次数
ImgCalls int // 剩余绘图次数
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
diff --git a/api/store/redis.go b/api/store/redis.go
new file mode 100644
index 00000000..c72521d2
--- /dev/null
+++ b/api/store/redis.go
@@ -0,0 +1,20 @@
+package store
+
+import (
+ "chatplus/core/types"
+ "context"
+ "github.com/go-redis/redis/v8"
+)
+
+func NewRedisClient(config *types.AppConfig) (*redis.Client, error) {
+ client := redis.NewClient(&redis.Options{
+ Addr: config.Redis.Url(),
+ Password: config.Redis.Password,
+ DB: config.Redis.DB,
+ })
+ _, err := client.Ping(context.Background()).Result()
+ if err != nil {
+ return nil, err
+ }
+ return client, nil
+}
diff --git a/api/store/vo/api_key.go b/api/store/vo/api_key.go
index f11af1fd..77403e74 100644
--- a/api/store/vo/api_key.go
+++ b/api/store/vo/api_key.go
@@ -3,7 +3,7 @@ package vo
// ApiKey OpenAI API 模型
type ApiKey struct {
BaseVo
- UserId uint `json:"user_id"` //用户ID,系统添加的用户 ID 为 0
+ Platform string `json:"platform"`
Value string `json:"value"` // API Key 的值
LastUsedAt int64 `json:"last_used_at"` // 最后使用时间
}
diff --git a/api/store/vo/chat_history.go b/api/store/vo/chat_history.go
index 60ca8838..1f4cbd4e 100644
--- a/api/store/vo/chat_history.go
+++ b/api/store/vo/chat_history.go
@@ -11,7 +11,3 @@ type HistoryMessage struct {
Content string `json:"content"`
UseContext bool `json:"use_context"`
}
-
-func (HistoryMessage) TableName() string {
- return "chatgpt_chat_history"
-}
diff --git a/api/store/vo/chat_item.go b/api/store/vo/chat_item.go
index d856138f..eeb34031 100644
--- a/api/store/vo/chat_item.go
+++ b/api/store/vo/chat_item.go
@@ -2,10 +2,10 @@ package vo
type ChatItem struct {
BaseVo
- UserId uint `json:"user_id"`
- Icon string `json:"icon"`
- RoleId uint `json:"role_id"`
- ChatId string `json:"chat_id"`
- Model string `json:"model"`
- Title string `json:"title"`
+ UserId uint `json:"user_id"`
+ Icon string `json:"icon"`
+ RoleId uint `json:"role_id"`
+ ChatId string `json:"chat_id"`
+ ModelId uint `json:"model_id"`
+ Title string `json:"title"`
}
diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go
new file mode 100644
index 00000000..b42cdafa
--- /dev/null
+++ b/api/store/vo/chat_model.go
@@ -0,0 +1,9 @@
+package vo
+
+type ChatModel struct {
+ BaseVo
+ Platform string `json:"platform"`
+ Name string `json:"name"`
+ Value string `json:"value"`
+ Enabled bool `json:"enabled"`
+}
diff --git a/api/store/vo/user.go b/api/store/vo/user.go
index 37cd1626..bb13eec2 100644
--- a/api/store/vo/user.go
+++ b/api/store/vo/user.go
@@ -4,13 +4,11 @@ import "chatplus/core/types"
type User struct {
BaseVo
- Username string `json:"username"`
Mobile string `json:"mobile"`
- Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
- Salt string `json:"salt"` // 密码盐
- Tokens int64 `json:"tokens"` // 剩余tokens
- Calls int `json:"calls"` // 剩余对话次数
+ Salt string `json:"salt"` // 密码盐
+ TotalTokens int64 `json:"total_tokens"` // 总消耗tokens
+ Calls int `json:"calls"` // 剩余对话次数
ImgCalls int `json:"img_calls"`
ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
diff --git a/database/update-v3.1.0.sql b/database/update-v3.1.0.sql
new file mode 100644
index 00000000..70aef44b
--- /dev/null
+++ b/database/update-v3.1.0.sql
@@ -0,0 +1,31 @@
+ALTER TABLE `chatgpt_chat_items` CHANGE `model` `model_id` INT(11) NOT NULL DEFAULT '0' COMMENT '模型 ID';
+ALTER TABLE `chatgpt_api_keys` ADD `platform` CHAR(20) DEFAULT NULL COMMENT '平台' AFTER id;
+ALTER TABLE `chatgpt_users` CHANGE `tokens` `total_tokens` BIGINT NOT NULL DEFAULT '0' COMMENT '累计消耗 tokens';
+ALTER TABLE `chatgpt_chat_items` ADD `deleted_at` DATETIME NULL DEFAULT NULL AFTER `updated_at`;
+ALTER TABLE `chatgpt_chat_history` ADD `deleted_at` DATETIME NULL DEFAULT NULL AFTER `updated_at`;
+
+CREATE TABLE `chatgpt_chat_models` (
+ `id` int NOT NULL,
+ `platform` varchar(20) DEFAULT NULL COMMENT '模型平台',
+ `name` varchar(50) NOT NULL COMMENT '模型名称',
+ `value` varchar(50) NOT NULL COMMENT '模型值',
+ `sort_num` tinyint(1) NOT NULL COMMENT '排序数字',
+ `enabled` tinyint(1) NOT NULL DEFAULT '0' COMMENT '是否启用模型',
+ `created_at` datetime DEFAULT NULL,
+ `updated_at` datetime DEFAULT NULL
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='AI 模型表';
+ALTER TABLE `chatgpt_chat_models`
+ ADD PRIMARY KEY (`id`);
+ALTER TABLE `chatgpt_chat_models`
+ MODIFY `id` int NOT NULL AUTO_INCREMENT, AUTO_INCREMENT=7;
+
+INSERT INTO `chatgpt_chat_models` (`id`, `platform`, `name`, `value`, `sort_num`, `enabled`, `created_at`, `updated_at`) VALUES
+ (1, 'OpenAI', 'Bot GPT-3.5', 'gpt-3.5-turbo', 0, 1, '2023-08-23 12:06:36', '2023-09-02 16:49:36'),
+ (2, 'Azure', 'Bot Azure-3.5', 'gpt-3.5-turbo', 0, 1, '2023-08-23 12:15:30', '2023-09-02 16:49:46'),
+ (3, 'ChatGML', 'ChatGML-Pro', 'chatglm_pro', 3, 1, '2023-08-23 13:35:45', '2023-08-29 11:41:29'),
+ (5, 'ChatGML', 'ChatGLM-Std', 'chatglm_std', 2, 1, '2023-08-24 15:05:38', '2023-08-29 11:41:28'),
+ (6, 'ChatGML', 'ChatGLM-Lite', 'chatglm_lite', 4, 1, '2023-08-24 15:06:15', '2023-08-29 11:41:29');
+
+ALTER TABLE `chatgpt_users`
+DROP `username`,
+DROP `nickname`;
\ No newline at end of file
diff --git a/web/.env.development b/web/.env.development
new file mode 100644
index 00000000..a5d967bd
--- /dev/null
+++ b/web/.env.development
@@ -0,0 +1,7 @@
+VUE_APP_API_HOST=http://localhost:5678
+VUE_APP_WS_HOST=ws://localhost:5678
+VUE_APP_USER=geekmaster
+VUE_APP_PASS=12345678
+VUE_APP_ADMIN_USER=admin
+VUE_APP_ADMIN_PASS=admin123
+VUE_APP_TITLE="ChatGPT-PLUS V3"
diff --git a/web/.env.production b/web/.env.production
index 1ab0ff38..5481a75e 100644
--- a/web/.env.production
+++ b/web/.env.production
@@ -1,3 +1,7 @@
VUE_APP_API_HOST=
VUE_APP_WS_HOST=
-VUE_APP_BASE_URL=
+VUE_APP_USER=
+VUE_APP_PASS=
+VUE_APP_ADMIN_USER=
+VUE_APP_ADMIN_PASS=
+VUE_APP_TITLE="ChatGPT-PLUS V3"
diff --git a/web/.gitignore b/web/.gitignore
index 1614d33f..a4cad713 100644
--- a/web/.gitignore
+++ b/web/.gitignore
@@ -8,5 +8,4 @@ lerna-debug.log*
node_modules
dist
dist.tar.gz
-.env.development
diff --git a/web/src/App.vue b/web/src/App.vue
index 5b243cf8..610e4ac1 100644
--- a/web/src/App.vue
+++ b/web/src/App.vue
@@ -6,8 +6,7 @@
+import zhCn from 'element-plus/es/locale/lang/zh-cn';