Compare commits

...

78 Commits

Author SHA1 Message Date
JustSong
91b80ae879 fix: remove extra space 2024-05-07 23:57:34 +08:00
JustSong
2720e1a358 feat: support minimax's 6.5 models (close #1395) 2024-04-30 02:23:14 +08:00
JustSong
71f4403fd5 feat: add together.ai support (#1298) 2024-04-30 02:16:53 +08:00
JustSong
1f76c80553 fix: fix aws claude panic (#1384) 2024-04-29 22:49:06 +08:00
JustSong
7e027d2bd0 fix: fix minimax prompt & completion tokens is empty (#1391) 2024-04-29 22:35:47 +08:00
JustSong
30f373b623 fix: fix usage is empty (close #1391) 2024-04-29 22:29:13 +08:00
plusye
1c2654320e fix: fix getPreConsumedQuota (#1312) 2024-04-27 16:07:06 +08:00
caixinjiang
6cffb116b7 fix: fix zhipu embedding error when input is array but not string (#1306)
* fix zhipu embedding error when input is array but not string

* fix: only use the first one

---------

Co-authored-by: 蔡新疆 <cxj@icc.link>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-27 16:05:14 +08:00
Qiying Wang
a84c7b38b7 fix: claude stream response parse (#1334) 2024-04-27 15:58:07 +08:00
tylinux
1bd14af47b feat: use mapped model name to test (#1370) 2024-04-27 15:53:20 +08:00
NongMO
6170b91d1c feat: support for the ollama vision model (#1376)
* feat: support for the ollama vision model

`llava` model, pass test

* Update main.go

format code

* chore: remove useless log

---------

Co-authored-by: nongqiqin <nongqiqin@tipdm.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-27 15:47:27 +08:00
JustSong
04b49aa0ec chore: use StringContent() to convert response to text 2024-04-27 15:41:02 +08:00
Wei Tingjiang
ef88497f25 fix: refactor Gemini adaptor to support streaming content generation (#1382) 2024-04-27 15:39:59 +08:00
JustSong
007906216d feat: support DeepL's model (close #1126) 2024-04-27 13:37:22 +08:00
JustSong
e64e7707a0 feat: support cohere's web search 2024-04-27 00:06:43 +08:00
JustSong
ea210b6ed7 chore: update ollama models 2024-04-26 23:12:39 +08:00
JustSong
9026ec7510 feat: support cloudflare now 2024-04-26 23:05:48 +08:00
JustSong
c317872097 feat: support deepseek now 2024-04-26 00:48:53 +08:00
JustSong
da0842272c fix: add model to response (close #1362) 2024-04-24 22:19:58 +08:00
JustSong
0a650b85b4 chore: update berry 2024-04-24 22:08:47 +08:00
Ghostz
24f026d18e feat: add cohere support (#1355)
* support cohere

* chore: tiny improvements

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-24 21:50:01 +08:00
tylinux
cb33e8aad5 fix: fix default theme blank screen when edit channel again (#1363)
* fix: throw exception after submit channel edit

* fix: replace with destructuring assignment
2024-04-24 21:29:48 +08:00
Wei Tingjiang
779b747e9e feat: add function and tools support for Gemini (#1358)
* Update model.go

* Support Gemini tool_calls.

* Fix gemini tool calls (also keep support functions).

* Fixed the problem of arguments not being stringified.

Fix panic: candidate.Content.Parts out of range
2024-04-24 21:26:45 +08:00
JustSong
3d149fedf4 chore: do not hardcode context key 2024-04-21 19:43:23 +08:00
JustSong
83517f687c chore: move config key to package ctxkey 2024-04-21 18:55:25 +08:00
JustSong
e30ebda0fe chore: move config key to package ctxkey 2024-04-21 18:55:13 +08:00
JustSong
d87c55f542 chore: render unknown channel type 2024-04-21 18:54:35 +08:00
JustSong
e5b3e37c46 feat: support bot prefix for coze 2024-04-21 18:04:56 +08:00
JustSong
8de489cf06 feat: support coze now 2024-04-21 17:59:57 +08:00
JustSong
d14e4aa01b fix: key is wrongly updated 2024-04-21 17:38:39 +08:00
JustSong
541182102e fix: ignore empty choice response for azure (close #1324) 2024-04-21 16:22:28 +08:00
JustSong
b2679cca65 fix: fix preview completion ratio (close #1326) 2024-04-21 15:57:01 +08:00
JustSong
8572fac7a2 fix: add back chat dropdown item for chatgpt next web (close #1330) 2024-04-21 15:50:35 +08:00
tylinux
a2a00dfbc3 feat: groq support Llama3 now (#1333)
* feat: groq support Llama3 now

* fix: update model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-21 14:53:03 +08:00
JustSong
129282f4a9 fix: fix wrong log type 2024-04-21 14:36:48 +08:00
Laisky.Cai
a873cbd392 fix: logger race (#1339)
- Refactor logger using sync.Once to improve performance
- Initiate log setup in a goroutine to prevent blocking
- Integrate gin.DefaultErrorWriter and gin.DefaultWriter for logging
- Introduce request ID generation for better request tracking
- Simplify setup logic by removing redundant variables and code
2024-04-21 14:35:51 +08:00
JustSong
35ba1da984 fix: fix cannot submit aws claude config (close #1343) 2024-04-21 11:04:34 +08:00
JustSong
2369025842 fix: use prefix to match more json response 2024-04-20 01:15:33 +08:00
JustSong
f452bd481e ci: update ci condition 2024-04-20 00:57:50 +08:00
JustSong
ddee58df36 fix: fix loading 2024-04-20 00:54:34 +08:00
JustSong
520a62e704 docs: update readme 2024-04-20 00:43:07 +08:00
Laisky.Cai
fc9a784950 feat: support aws bedrockruntime claude3 (#1328)
* feat: support aws bedrockruntime claude3

closes #622, closes #749, closes #1300

* fix: convert to aws claude model id

* fix: Update AWS adapter to handle stream completions and calculate usage metrics

Based on the file summaries provided, here are the important bullet points for the commit message:

- Add functionality to handle stream completion events from AWS in the relay/adaptor/aws/main.go file
- Marshall AWS response to OpenAI format and calculate usage metrics in the same file
- Implement a custom render function for streaming events in the same file
- Improve error handling for JSON unmarshalling and marshalling errors in the same file

* fix: Implement AWS handler with usage tracking and error handling

- Implemented streaming response handling for AWS handler
- Set response content type to text/event-stream
- Added error handling for failed marshaling/unmarshaling
- Updated return values to include `relaymodel.ErrorWithStatusCode` and `relaymodel.Usage`
- Improved error handling and response formatting for AWS adaptor

* fix: Refactor AWS Adapter for Improved Model Mapping and Error Handling

* Refactor AWS adapter to improve model management
  - Replace hardcoded model list in `adapter.go` with a function to get models from `awsModelIDMap`
  - Update `GetModelList` function to return model list directly
  - Add `GetChannelName` function to get channel name from `Adaptor` object
* Improve error handling and code organization in main.go
  - Replace switch statement with a map to map AWS model IDs to OpenAI model IDs
  - Return an error if the model is not found in the map
  - Use a single return statement instead of wrapping multiple return statements in the `awsModelID` function
  - Add a new error message for when the model is not found in the map in the `Handler` function

* fix: bug fix

* chore: change variable name & package

* chore: change variable name

* perf: update config related code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-20 00:40:47 +08:00
dependabot[bot]
1a0b039bcf chore(deps): bump golang.org/x/net from 0.17.0 to 0.23.0 (#1335)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.17.0 to 0.23.0.
- [Commits](https://github.com/golang/net/compare/v0.17.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-19 21:52:51 +08:00
JustSong
7bf61f9165 fix: fix retry not working (close #1314) 2024-04-15 23:09:12 +08:00
JustSong
a10232f43a feat: add gpt-4-turbo support (close #1304) 2024-04-13 11:39:31 +08:00
JustSong
af543ab8ec docs: update readme 2024-04-06 20:50:43 +08:00
JustSong
e086da05b1 feat: able to change gemini version (close #1211) 2024-04-06 20:48:22 +08:00
JustSong
3af4649b52 fix: only check model when request path in whitelist 2024-04-06 20:42:35 +08:00
GAI Group
52c32c0b4a chore: resolve the issue of onclick event scope for custom Lark button (#1281)
chore: Resolve the issue of onclick event scope for custom Lark button
2024-04-06 20:08:05 +08:00
Buer
3fe2863ff7 feat: berry theme update & bug fix (#1282)
* ️ improve: delete google fonts

* ️ improve: Optimized priority input handling in TableRow component.

* 🔖 chore: channel batch add

*  feat: add dark mod

*  feat: support token limit ip range and models

*  feat: add MessagePusher

*  feat: add lark login
2024-04-06 19:44:23 +08:00
JustSong
acf8cb6248 chore: update default nextweb link 2024-04-06 11:47:31 +08:00
JustSong
572fc9ffb8 fix: fix stepfun model ratio & id 2024-04-06 10:43:54 +08:00
GAI Group
569c04acb0 fix: fix Lark icon button style (#1279) 2024-04-06 10:18:59 +08:00
JustSong
961b4108e6 chore: fix refactor caused typo 2024-04-06 02:12:50 +08:00
JustSong
0b8ccb94eb chore: reorganize common package 2024-04-06 02:03:59 +08:00
JustSong
f586ae0ad8 chore: remove helper & util subpackage for relay 2024-04-06 01:50:12 +08:00
JustSong
24ed170e7b chore: reorganize adaptor related package 2024-04-06 01:36:48 +08:00
JustSong
f70506eac1 chore: reorganize relay related package 2024-04-06 01:31:44 +08:00
JustSong
8f4d78e24d chore: reorganize billing related package 2024-04-06 01:26:48 +08:00
JustSong
cd2707692f chore: reorganize billing related package 2024-04-06 01:09:23 +08:00
JustSong
2ab7d25a80 chore: reorganize helper related package 2024-04-06 01:02:35 +08:00
JustSong
f9d914873f chore: reorganize constant related package 2024-04-06 00:44:33 +08:00
JustSong
880e12c855 feat: support cogview-3 2024-04-06 00:30:08 +08:00
JustSong
0cb224e62e chore: fix typo 2024-04-05 23:55:25 +08:00
JustSong
a44fb5d482 fix: fix channel model list is empty 2024-04-05 23:44:57 +08:00
JustSong
eec41849ec chore: fix ali image implementation 2024-04-05 18:25:57 +08:00
Mo
d4347e7a35 feat: support Ali stable-diffusion-xl and wanx-v1 model (#1240)
* Fix ali ConvertRequest function to use baidu keyword

* Support Ali stable-diffusion-xl and wanx-v1 model

* Support Ali stable-diffusion-xl and wanx-v1 model

* Support Ali stable-diffusion-xl and wanx-v1 model

* chore: update ali constants and model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-04-05 18:09:54 +08:00
manjieqi
b50b43eb65 feat: update baidu model name & ratio (#1277) 2024-04-05 17:30:48 +08:00
JustSong
348adc2b02 feat: able to set multiple subnets 2024-04-05 17:25:28 +08:00
JustSong
dcf24b98dc chore: update berry copy 2024-04-05 14:28:38 +08:00
JustSong
af679e04f4 chore: sort channel type for berry 2024-04-05 14:23:39 +08:00
JustSong
93cbca6a9f chore: update show notice duration 2024-04-05 14:14:21 +08:00
JustSong
840ef80d94 fix: do not try to parse model when requesting /v1/models (close #1272) 2024-04-05 12:50:31 +08:00
JustSong
9a2662af0d feat: show token info when quota is not enough (close #1274) 2024-04-05 12:42:14 +08:00
JustSong
77f9e75654 fix: fix IsValidSubnet 2024-04-05 12:40:03 +08:00
JustSong
5b41f57423 feat: support stepfun's models 2024-04-05 12:32:05 +08:00
JustSong
0bb7db0b44 fix: do not detect quota field in error message (close #1276) 2024-04-05 12:11:50 +08:00
JustSong
4d61b9937b feat: support feishu login now 2024-04-05 12:10:43 +08:00
240 changed files with 5967 additions and 2301 deletions

View File

@@ -3,7 +3,7 @@ name: Publish Docker image (amd64, English)
on:
push:
tags:
- '*'
- 'v*.*.*'
workflow_dispatch:
inputs:
name:

View File

@@ -3,7 +3,7 @@ name: Publish Docker image (amd64)
on:
push:
tags:
- '*'
- 'v*.*.*'
workflow_dispatch:
inputs:
name:

View File

@@ -3,7 +3,7 @@ name: Publish Docker image (arm64)
on:
push:
tags:
- '*'
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:

View File

@@ -5,7 +5,7 @@ permissions:
on:
push:
tags:
- '*'
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:

View File

@@ -5,7 +5,7 @@ permissions:
on:
push:
tags:
- '*'
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:

View File

@@ -5,7 +5,7 @@ permissions:
on:
push:
tags:
- '*'
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:

View File

@@ -65,7 +65,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
## 功能
1. 支持多种大模型:
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude)
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+ [x] [Mistral 系列模型](https://mistral.ai/)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
@@ -81,6 +81,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
+ [x] [阶跃星辰](https://platform.stepfun.com/)
+ [x] [Coze](https://www.coze.com/)
+ [x] [Cohere](https://cohere.com/)
+ [x] [DeepSeek](https://www.deepseek.com/)
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+ [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -105,6 +112,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 支持使用飞书进行授权登录。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
@@ -361,28 +369,29 @@ graph LR
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
11. 例子:`CHANNEL_TEST_FREQUENCY=1440`
12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
14. 请求频率限制:
15. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
15. 编码器缓存设置:
16. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
17. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
18. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)
20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`
21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
23. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌
17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
18. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
19. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
20. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`
21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)
22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`
25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -4,6 +4,7 @@ import (
"github.com/songquanpeng/one-api/common/env"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -51,9 +52,9 @@ var EmailDomainWhitelist = []string{
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true"
var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true"
var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true"
var LogConsumeEnabled = true
@@ -66,6 +67,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -138,3 +142,5 @@ var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1")

View File

@@ -4,116 +4,3 @@ import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
const (
RoleGuestUser = 0
RoleCommonUser = 1
RoleAdminUser = 10
RoleRootUser = 100
)
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
TokenStatusExpired = 3
TokenStatusExhausted = 4
)
const (
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
RedemptionCodeStatusDisabled = 2 // also don't use 0
RedemptionCodeStatusUsed = 3 // also don't use 0
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
const (
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
}
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

22
common/ctxkey/key.go Normal file
View File

@@ -0,0 +1,22 @@
package ctxkey
const (
Config = "config"
Id = "id"
Username = "username"
Role = "role"
Status = "status"
Channel = "channel"
ChannelId = "channel_id"
SpecificChannelId = "specific_channel_id"
RequestModel = "request_model"
ConvertedRequest = "converted_request"
OriginalModel = "original_model"
Group = "group"
ModelMapping = "model_mapping"
ChannelName = "channel_name"
TokenId = "token_id"
TokenName = "token_name"
BaseURL = "base_url"
AvailableModels = "available_models"
)

View File

@@ -2,16 +2,15 @@ package helper
import (
"fmt"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
"html/template"
"log"
"math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
func OpenBrowser(url string) {
@@ -79,31 +78,6 @@ func Bytes2Size(num int64) string {
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
@@ -128,65 +102,13 @@ func IntMax(a int, b int) int {
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string {
return GetTimeString() + GetRandomNumberString(8)
return GetTimeString() + random.GetRandomNumberString(8)
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString(RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID)
}
func Max(a int, b int) int {

5
common/helper/key.go Normal file
View File

@@ -0,0 +1,5 @@
package helper
const (
RequestIdKey = "X-Oneapi-Request-Id"
)

15
common/helper/time.go Normal file
View File

@@ -0,0 +1,15 @@
package helper
import (
"fmt"
"time"
)
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}

View File

@@ -16,7 +16,7 @@ import (
)
// Regex to match data URL pattern
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)

View File

@@ -1,7 +1,3 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@@ -3,15 +3,16 @@ package logger
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"io"
"log"
"os"
"path/filepath"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
)
const (
@@ -21,28 +22,20 @@ const (
loggerError = "ERR"
)
var setupLogLock sync.Mutex
var setupLogWorking bool
var setupLogOnce sync.Once
func SetupLogger() {
if LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
setupLogOnce.Do(func() {
if LogDir != "" {
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
})
}
func SysLog(s string) {
@@ -94,18 +87,13 @@ func logHelper(ctx context.Context, level string, msg string) {
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
id := ctx.Value(helper.RequestIdKey)
if id == nil {
id = helper.GenRequestID()
}
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
if !setupLogWorking {
setupLogWorking = true
go func() {
SetupLogger()
}()
}
SetupLogger()
}
func FatalLog(v ...any) {

View File

@@ -5,14 +5,26 @@ import (
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"net"
"strings"
)
func IsValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
return fmt.Errorf("failed to parse subnet: %w", err)
func splitSubnets(subnets string) []string {
res := strings.Split(subnets, ",")
for i := 0; i < len(res); i++ {
res[i] = strings.TrimSpace(res[i])
}
return res
}
func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
func isValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
return fmt.Errorf("failed to parse subnet: %w", err)
}
return nil
}
func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
_, ipNet, err := net.ParseCIDR(subnet)
if err != nil {
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
@@ -20,3 +32,21 @@ func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
}
return ipNet.Contains(net.ParseIP(ip))
}
func IsValidSubnets(subnets string) error {
for _, subnet := range splitSubnets(subnets) {
if err := isValidSubnet(subnet); err != nil {
return err
}
}
return nil
}
func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
for _, subnet := range splitSubnets(subnets) {
if isIpInSubnet(ctx, ip, subnet) {
return true
}
}
return false
}

View File

@@ -13,7 +13,7 @@ func TestIsIpInSubnet(t *testing.T) {
ip2 := "125.216.250.89"
subnet := "192.168.0.0/24"
Convey("TestIsIpInSubnet", t, func() {
So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
})
}

View File

@@ -1,8 +0,0 @@
package common
import "math/rand"
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

61
common/random/main.go Normal file
View File

@@ -0,0 +1,61 @@
package random
import (
"github.com/google/uuid"
"math/rand"
"strings"
"time"
)
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

View File

@@ -1,4 +1,4 @@
package controller
package auth
import (
"bytes"
@@ -7,10 +7,10 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) {
}
}
if user.Status != common.UserStatusEnabled {
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
controller.SetupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {
@@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := helper.GetRandomString(12)
state := random.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {

200
controller/auth/lark.go Normal file
View File

@@ -0,0 +1,200 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type LarkOAuthResponse struct {
AccessToken string `json:"access_token"`
}
type LarkUser struct {
Name string `json:"name"`
OpenID string `json:"open_id"`
}
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.LarkClientId,
"client_secret": config.LarkClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LarkOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
var larkUser LarkUser
err = json.NewDecoder(res2.Body).Decode(&larkUser)
if err != nil {
return nil, err
}
return &larkUser, nil
}
func LarkOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LarkBind(c)
return
}
code := c.Query("code")
larkUser, err := getLarkUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LarkId: larkUser.OpenID,
}
if model.IsLarkIdAlreadyTaken(user.LarkId) {
err := user.FillUserByLarkId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
if larkUser.Name != "" {
user.DisplayName = larkUser.Name
} else {
user.DisplayName = "Lark User"
}
user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func LarkBind(c *gin.Context) {
code := c.Query("code")
larkUser, err := getLarkUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LarkId: larkUser.OpenID,
}
if model.IsLarkIdAlreadyTaken(user.LarkId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该飞书账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LarkId = larkUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -1,12 +1,13 @@
package controller
package auth
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -83,8 +84,8 @@ func WeChatAuth(c *gin.Context) {
if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -102,14 +103,14 @@ func WeChatAuth(c *gin.Context) {
}
}
if user.Status != common.UserStatusEnabled {
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
controller.SetupLogin(&user, c)
}
func WeChatBind(c *gin.Context) {
@@ -136,7 +137,7 @@ func WeChatBind(c *gin.Context) {
})
return
}
id := c.GetInt("id")
id := c.GetInt(ctxkey.Id)
user := model.User{
Id: id,
}

View File

@@ -3,6 +3,7 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
@@ -14,13 +15,13 @@ func GetSubscription(c *gin.Context) {
var token *model.Token
var expiredTime int64
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
} else {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId)
if err != nil {
usedQuota, err = model.GetUserUsedQuota(userId)
@@ -64,11 +65,11 @@ func GetUsage(c *gin.Context) {
var err error
var token *model.Token
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
} else {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {

View File

@@ -4,12 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
"io"
"net/http"
"strconv"
@@ -96,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := util.HTTPClient.Do(req)
res, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}
@@ -204,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
case common.ChannelTypeOpenAI:
case channeltype.OpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case common.ChannelTypeAzure:
case channeltype.Azure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
case channeltype.Custom:
baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI:
case channeltype.CloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:
case channeltype.OpenAISB:
return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
case channeltype.AIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
case channeltype.API2GPT:
return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D:
case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
default:
return 0, errors.New("尚未实现")
@@ -301,11 +301,11 @@ func updateAllChannelsBalance() error {
return err
}
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
if channel.Status != model.ChannelStatusEnabled {
continue
}
// TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
continue
}
balance, err := updateChannelBalance(channel)

View File

@@ -5,17 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"net/http/httptest"
@@ -25,6 +14,20 @@ import (
"sync"
"time"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
)
@@ -53,27 +56,37 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
cfg, _ := channel.LoadConfig()
c.Set(ctxkey.Config, cfg)
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
meta := meta.GetByContext(c)
apiType := channeltype.ToAPIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := adaptor.GetModelList()[0]
if !strings.Contains(channel.Models, modelName) {
var modelName string
modelList := adaptor.GetModelList()
modelMap := channel.GetModelMapping()
if len(modelList) != 0 {
modelName = modelList[0]
}
if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
if modelMap != nil && modelMap[modelName] != "" {
modelName = modelMap[modelName]
}
}
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return err, nil
}
@@ -81,14 +94,15 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
if err != nil {
return err, nil
}
logger.SysLog(string(jsonData))
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp)
if resp != nil && resp.StatusCode != http.StatusOK {
err := controller.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
@@ -171,7 +185,7 @@ func testChannels(notify bool, scope string) error {
}
go func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel)
tok := time.Now()
@@ -184,10 +198,10 @@ func testChannels(notify bool, scope string) error {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error())
}
}
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)

View File

@@ -2,13 +2,13 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"net/http"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName := range common.GroupRatio {
for groupName := range billingratio.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -3,6 +3,7 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -41,7 +42,7 @@ func GetUserLogs(c *gin.Context) {
if p < 0 {
p = 0
}
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
@@ -83,7 +84,7 @@ func SearchAllLogs(c *gin.Context) {
func SearchUserLogs(c *gin.Context) {
keyword := c.Query("keyword")
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -122,7 +123,7 @@ func GetLogsStat(c *gin.Context) {
}
func GetLogsSelfStat(c *gin.Context) {
username := c.GetString("username")
username := c.GetString(ctxkey.Username)
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)

View File

@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,

View File

@@ -3,13 +3,14 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
"strings"
)
@@ -41,8 +42,8 @@ type OpenAIModels struct {
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var models []OpenAIModels
var modelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() {
@@ -62,15 +63,15 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
for i := 0; i < apitype.Dummy; i++ {
if i == apitype.AIProxyLibrary {
continue
}
adaptor := helper.GetAdaptor(i)
adaptor := relay.GetAdaptor(i)
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
models = append(models, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
@@ -82,12 +83,12 @@ func init() {
}
}
for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure {
if channelType == channeltype.Azure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{
models = append(models, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
@@ -98,14 +99,14 @@ func init() {
})
}
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
modelsMap = make(map[string]OpenAIModels)
for _, model := range models {
modelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
meta := &util.RelayMeta{
for i := 1; i < channeltype.Dummy; i++ {
adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
meta := &meta.Meta{
ChannelType: i,
}
adaptor.Init(meta)
@@ -121,13 +122,20 @@ func DashboardListModels(c *gin.Context) {
})
}
func ListAllModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": models,
})
}
func ListModels(c *gin.Context) {
ctx := c.Request.Context()
var availableModels []string
if c.GetString("available_models") != "" {
availableModels = strings.Split(c.GetString("available_models"), ",")
if c.GetString(ctxkey.AvailableModels) != "" {
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
} else {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
userGroup, _ := model.CacheGetUserGroup(userId)
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
}
@@ -136,7 +144,7 @@ func ListModels(c *gin.Context) {
modelSet[availableModel] = true
}
availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range openAIModels {
for _, model := range models {
if _, ok := modelSet[model.Id]; ok {
modelSet[model.Id] = false
availableOpenAIModels = append(availableOpenAIModels, model)
@@ -162,7 +170,7 @@ func ListModels(c *gin.Context) {
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
if model, ok := modelsMap[modelId]; ok {
c.JSON(200, model)
} else {
Error := relaymodel.Error{
@@ -179,7 +187,7 @@ func RetrieveModel(c *gin.Context) {
func GetUserAvailableModels(c *gin.Context) {
ctx := c.Request.Context()
id := c.GetInt("id")
id := c.GetInt(ctxkey.Id)
userGroup, err := model.CacheGetUserGroup(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -3,7 +3,9 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -106,9 +108,9 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
key := helper.GetUUID()
key := random.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
UserId: c.GetInt(ctxkey.Id),
Name: redemption.Name,
Key: key,
CreatedTime: helper.GetTimestamp(),

View File

@@ -7,31 +7,31 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
)
// https://platform.openai.com/docs/api-reference/chat
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech:
case relaymode.AudioSpeech:
fallthrough
case constant.RelayModeAudioTranslation:
case relaymode.AudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
@@ -41,23 +41,23 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
channelId := c.GetInt("channel_id")
bizErr := relay(c, relayMode)
channelId := c.GetInt(ctxkey.ChannelId)
bizErr := relayHelper(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
return
}
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
@@ -66,7 +66,7 @@ func Relay(c *gin.Context) {
for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
break
}
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
@@ -76,13 +76,13 @@ func Relay(c *gin.Context) {
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode)
bizErr = relayHelper(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString("channel_name")
channelName := c.GetString(ctxkey.ChannelName)
go processChannelRelayError(ctx, channelId, channelName, bizErr)
}
if bizErr != nil {
@@ -97,7 +97,7 @@ func Relay(c *gin.Context) {
}
func shouldRetry(c *gin.Context, statusCode int) bool {
if _, ok := c.Get("specific_channel_id"); ok {
if _, ok := c.Get(ctxkey.SpecificChannelId); ok {
return false
}
if statusCode == http.StatusTooManyRequests {
@@ -118,7 +118,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)

View File

@@ -3,17 +3,18 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
)
func GetAllTokens(c *gin.Context) {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
@@ -38,7 +39,7 @@ func GetAllTokens(c *gin.Context) {
}
func SearchTokens(c *gin.Context) {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
keyword := c.Query("keyword")
tokens, err := model.SearchUserTokens(userId, keyword)
if err != nil {
@@ -58,7 +59,7 @@ func SearchTokens(c *gin.Context) {
func GetToken(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -83,8 +84,8 @@ func GetToken(c *gin.Context) {
}
func GetTokenStatus(c *gin.Context) {
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
tokenId := c.GetInt(ctxkey.TokenId)
userId := c.GetInt(ctxkey.Id)
token, err := model.GetTokenByIds(tokenId, userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -111,7 +112,7 @@ func validateToken(c *gin.Context, token model.Token) error {
return fmt.Errorf("令牌名称过长")
}
if token.Subnet != nil && *token.Subnet != "" {
err := network.IsValidSubnet(*token.Subnet)
err := network.IsValidSubnets(*token.Subnet)
if err != nil {
return fmt.Errorf("无效的网段:%s", err.Error())
}
@@ -139,9 +140,9 @@ func AddToken(c *gin.Context) {
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
UserId: c.GetInt(ctxkey.Id),
Name: token.Name,
Key: helper.GenerateKey(),
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
@@ -168,7 +169,7 @@ func AddToken(c *gin.Context) {
func DeleteToken(c *gin.Context) {
id, _ := strconv.Atoi(c.Param("id"))
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
err := model.DeleteTokenById(id, userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -185,7 +186,7 @@ func DeleteToken(c *gin.Context) {
}
func UpdateToken(c *gin.Context) {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
statusOnly := c.Query("status_only")
token := model.Token{}
err := c.ShouldBindJSON(&token)
@@ -212,15 +213,15 @@ func UpdateToken(c *gin.Context) {
})
return
}
if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
if token.Status == model.TokenStatusEnabled {
if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
})
return
}
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",

View File

@@ -5,7 +5,8 @@ import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -58,11 +59,11 @@ func Login(c *gin.Context) {
})
return
}
setupLogin(&user, c)
SetupLogin(&user, c)
}
// setup session & cookies and then return user info
func setupLogin(user *model.User, c *gin.Context) {
func SetupLogin(user *model.User, c *gin.Context) {
session := sessions.Default(c)
session.Set("id", user.Id)
session.Set("username", user.Username)
@@ -238,8 +239,8 @@ func GetUser(c *gin.Context) {
})
return
}
myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser {
myRole := c.GetInt(ctxkey.Role)
if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权获取同级或更高等级用户的信息",
@@ -255,7 +256,7 @@ func GetUser(c *gin.Context) {
}
func GetUserDashboard(c *gin.Context) {
id := c.GetInt("id")
id := c.GetInt(ctxkey.Id)
now := time.Now()
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
@@ -278,7 +279,7 @@ func GetUserDashboard(c *gin.Context) {
}
func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id")
id := c.GetInt(ctxkey.Id)
user, err := model.GetUserById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -287,7 +288,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
user.AccessToken = helper.GetUUID()
user.AccessToken = random.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -314,7 +315,7 @@ func GenerateAccessToken(c *gin.Context) {
}
func GetAffCode(c *gin.Context) {
id := c.GetInt("id")
id := c.GetInt(ctxkey.Id)
user, err := model.GetUserById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -324,7 +325,7 @@ func GetAffCode(c *gin.Context) {
return
}
if user.AffCode == "" {
user.AffCode = helper.GetRandomString(4)
user.AffCode = random.GetRandomString(4)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -342,7 +343,7 @@ func GetAffCode(c *gin.Context) {
}
func GetSelf(c *gin.Context) {
id := c.GetInt("id")
id := c.GetInt(ctxkey.Id)
user, err := model.GetUserById(id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -387,15 +388,15 @@ func UpdateUser(c *gin.Context) {
})
return
}
myRole := c.GetInt("role")
if myRole <= originUser.Role && myRole != common.RoleRootUser {
myRole := c.GetInt(ctxkey.Role)
if myRole <= originUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
})
return
}
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@@ -445,7 +446,7 @@ func UpdateSelf(c *gin.Context) {
}
cleanUser := model.User{
Id: c.GetInt("id"),
Id: c.GetInt(ctxkey.Id),
Username: user.Username,
Password: user.Password,
DisplayName: user.DisplayName,
@@ -509,7 +510,7 @@ func DeleteSelf(c *gin.Context) {
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
if user.Role == common.RoleRootUser {
if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除超级管理员账户",
@@ -611,7 +612,7 @@ func ManageUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser {
if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
@@ -620,8 +621,8 @@ func ManageUser(c *gin.Context) {
}
switch req.Action {
case "disable":
user.Status = common.UserStatusDisabled
if user.Role == common.RoleRootUser {
user.Status = model.UserStatusDisabled
if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法禁用超级管理员用户",
@@ -629,9 +630,9 @@ func ManageUser(c *gin.Context) {
return
}
case "enable":
user.Status = common.UserStatusEnabled
user.Status = model.UserStatusEnabled
case "delete":
if user.Role == common.RoleRootUser {
if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法删除超级管理员用户",
@@ -646,37 +647,37 @@ func ManageUser(c *gin.Context) {
return
}
case "promote":
if myRole != common.RoleRootUser {
if myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "普通管理员用户无法提升其他用户为管理员",
})
return
}
if user.Role >= common.RoleAdminUser {
if user.Role >= model.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是管理员",
})
return
}
user.Role = common.RoleAdminUser
user.Role = model.RoleAdminUser
case "demote":
if user.Role == common.RoleRootUser {
if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法降级超级管理员用户",
})
return
}
if user.Role == common.RoleCommonUser {
if user.Role == model.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是普通用户",
})
return
}
user.Role = common.RoleCommonUser
user.Role = model.RoleCommonUser
}
if err := user.Update(false); err != nil {
@@ -730,7 +731,7 @@ func EmailBind(c *gin.Context) {
})
return
}
if user.Role == common.RoleRootUser {
if user.Role == model.RoleRootUser {
config.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{

86
go.mod
View File

@@ -1,70 +1,84 @@
module github.com/songquanpeng/one-api
// +heroku goVersion go1.18
go 1.18
go 1.20
require (
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/gin-contrib/cors v1.7.1
github.com/gin-contrib/gzip v1.0.0
github.com/gin-contrib/sessions v1.0.0
github.com/gin-contrib/static v1.1.1
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0
github.com/go-playground/validator/v10 v10.19.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.1
github.com/jinzhu/copier v0.4.0
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.6
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.22.0
golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5
gorm.io/gorm v1.25.9
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.11.5 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.3 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/gorilla/context v1.1.2 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/sessions v1.2.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

240
go.sum
View File

@@ -1,136 +1,133 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k=
github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw=
github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg=
github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs=
github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps=
github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk=
github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE=
github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y=
github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U=
github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs=
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k=
github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY=
github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
@@ -138,81 +135,54 @@ github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8=
gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -12,7 +12,7 @@ import (
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/router"
"os"
"strconv"
@@ -71,7 +71,7 @@ func main() {
}
if config.MemoryCacheEnabled {
logger.SysLog("memory cache enabled")
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache()
}
if config.MemoryCacheEnabled {

View File

@@ -4,8 +4,8 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
@@ -45,7 +45,7 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
@@ -72,19 +72,19 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser)
authHelper(c, model.RoleCommonUser)
}
}
func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleAdminUser)
authHelper(c, model.RoleAdminUser)
}
}
func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleRootUser)
authHelper(c, model.RoleRootUser)
}
}
@@ -102,7 +102,7 @@ func TokenAuth() func(c *gin.Context) {
return
}
if token.Subnet != nil && *token.Subnet != "" {
if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) {
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s当前 ip%s", *token.Subnet, c.ClientIP()))
return
}
@@ -117,24 +117,24 @@ func TokenAuth() func(c *gin.Context) {
return
}
requestModel, err := getRequestModel(c)
if err != nil {
if err != nil && shouldCheckModel(c) {
abortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
c.Set("request_model", requestModel)
c.Set(ctxkey.RequestModel, requestModel)
if token.Models != nil && *token.Models != "" {
c.Set("available_models", *token.Models)
c.Set(ctxkey.AvailableModels, *token.Models)
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
return
}
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
c.Set(ctxkey.Id, token.UserId)
c.Set(ctxkey.TokenId, token.Id)
c.Set(ctxkey.TokenName, token.Name)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])
c.Set(ctxkey.SpecificChannelId, parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
@@ -143,3 +143,19 @@ func TokenAuth() func(c *gin.Context) {
c.Next()
}
}
func shouldCheckModel(c *gin.Context) bool {
if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
return true
}
return false
}

View File

@@ -3,9 +3,10 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"net/http"
"strconv"
)
@@ -16,12 +17,12 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
userId := c.GetInt(ctxkey.Id)
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
c.Set(ctxkey.Group, userGroup)
var requestModel string
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
channelId, ok := c.Get(ctxkey.SpecificChannelId)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -33,12 +34,12 @@ func Distribute() func(c *gin.Context) {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
if channel.Status != model.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
requestModel := c.GetString("request_model")
requestModel = c.GetString(ctxkey.RequestModel)
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
@@ -57,28 +58,36 @@ func Distribute() func(c *gin.Context) {
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry
c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name)
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
cfg, _ := channel.LoadConfig()
// this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
case channeltype.Azure:
if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other
}
case channeltype.Xunfei:
if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other
}
case channeltype.Gemini:
if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other
}
case channeltype.AIProxyLibrary:
if cfg.LibraryID == "" {
cfg.LibraryID = channel.Other
}
case channeltype.Ali:
if cfg.Plugin == "" {
cfg.Plugin = channel.Other
}
}
c.Set(ctxkey.Config, cfg)
}

View File

@@ -3,14 +3,14 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/helper"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[logger.RequestIdKey].(string)
requestID = param.Keys[helper.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@@ -4,16 +4,15 @@ import (
"context"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Set(helper.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(logger.RequestIdKey, id)
c.Header(helper.RequestIdKey, id)
c.Next()
}
}

View File

@@ -12,7 +12,7 @@ import (
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)),
"type": "one_api_error",
},
})

View File

@@ -57,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Enabled: channel.Status == ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"math/rand"
"sort"
"strconv"
@@ -172,7 +173,7 @@ var channelSyncLock sync.RWMutex
func InitChannelCache() {
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
@@ -247,7 +248,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior
idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels))
idx = random.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil

View File

@@ -3,13 +3,19 @@ package model
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
@@ -32,6 +38,16 @@ type Channel struct {
Config string `json:"config"`
}
type ChannelConfig struct {
Region string `json:"region,omitempty"`
SK string `json:"sk,omitempty"`
AK string `json:"ak,omitempty"`
UserID string `json:"user_id,omitempty"`
APIVersion string `json:"api_version,omitempty"`
LibraryID string `json:"library_id,omitempty"`
Plugin string `json:"plugin,omitempty"`
}
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
var channels []*Channel
var err error
@@ -39,7 +55,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
case "all":
err = DB.Order("id desc").Find(&channels).Error
case "disabled":
err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
@@ -155,20 +171,20 @@ func (channel *Channel) Delete() error {
return err
}
func (channel *Channel) LoadConfig() (map[string]string, error) {
func (channel *Channel) LoadConfig() (ChannelConfig, error) {
var cfg ChannelConfig
if channel.Config == "" {
return nil, nil
return cfg, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
return cfg, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
if err != nil {
logger.SysError("failed to update ability status: " + err.Error())
}
@@ -199,6 +215,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)

View File

@@ -7,6 +7,7 @@ import (
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@@ -31,10 +32,10 @@ func CreateRootAccountIfNeed() error {
rootUser := User{
Username: "root",
Password: hashedPassword,
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
Role: RoleRootUser,
Status: UserStatusEnabled,
DisplayName: "Root User",
AccessToken: helper.GetUUID(),
AccessToken: random.GetUUID(),
Quota: 500000000000000,
}
DB.Create(&rootUser)
@@ -44,7 +45,7 @@ func CreateRootAccountIfNeed() error {
Id: 1,
UserId: rootUser.Id,
Key: config.InitialRootToken,
Status: common.TokenStatusEnabled,
Status: TokenStatusEnabled,
Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),

View File

@@ -1,9 +1,9 @@
package model
import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"strconv"
"strings"
"time"
@@ -66,9 +66,9 @@ func InitOptionMap() {
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
@@ -82,7 +82,7 @@ func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
if option.Key == "ModelRatio" {
option.Value = common.AddNewMissingRatio(option.Value)
option.Value = billingratio.AddNewMissingRatio(option.Value)
}
err := updateOptionMap(option.Key, option.Value)
if err != nil {
@@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) {
config.GitHubClientId = value
case "GitHubClientSecret":
config.GitHubClientSecret = value
case "LarkClientId":
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "Footer":
config.Footer = value
case "SystemName":
@@ -205,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) {
case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
err = billingratio.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
err = billingratio.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
err = billingratio.UpdateCompletionRatioByJSONString(value)
case "TopUpLink":
config.TopUpLink = value
case "ChatLink":

View File

@@ -8,6 +8,12 @@ import (
"gorm.io/gorm"
)
const (
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
RedemptionCodeStatusDisabled = 2 // also don't use 0
RedemptionCodeStatusUsed = 3 // also don't use 0
)
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
@@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
if err != nil {
return errors.New("无效的兑换码")
}
if redemption.Status != common.RedemptionCodeStatusEnabled {
if redemption.Status != RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
return err
}
redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
redemption.Status = RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
})

View File

@@ -11,6 +11,13 @@ import (
"gorm.io/gorm"
)
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
TokenStatusExpired = 3
TokenStatusExhausted = 4
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
@@ -62,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
return nil, errors.New("令牌验证失败")
}
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
if token.Status == TokenStatusExhausted {
return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id)
} else if token.Status == TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
if token.Status != TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
token.Status = TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -83,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted
token.Status = TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())

View File

@@ -6,12 +6,25 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"strings"
)
const (
RoleGuestUser = 0
RoleCommonUser = 1
RoleAdminUser = 10
RoleRootUser = 100
)
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
@@ -24,6 +37,7 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -41,21 +55,21 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
@@ -107,8 +121,8 @@ func (user *User) Insert(inviterId int) error {
}
}
user.Quota = config.QuotaForNewUser
user.AccessToken = helper.GetUUID()
user.AffCode = helper.GetRandomString(4)
user.AccessToken = random.GetUUID()
user.AffCode = random.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
@@ -137,9 +151,9 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
if user.Status == common.UserStatusDisabled {
if user.Status == UserStatusDisabled {
blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled {
} else if user.Status == UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error
@@ -151,8 +165,8 @@ func (user *User) Delete() error {
return errors.New("id 为空!")
}
blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
user.Status = common.UserStatusDeleted
user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
user.Status = UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err
}
@@ -176,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) {
}
}
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
if !okay || user.Status != UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
}
return nil
@@ -206,6 +220,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByLarkId() error {
if user.LarkId == "" {
return errors.New("lark id 为空!")
}
DB.Where(User{LarkId: user.LarkId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -234,6 +256,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -260,7 +286,7 @@ func IsAdmin(userId int) bool {
logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
return user.Role >= RoleAdminUser
}
func IsUserEnabled(userId int) (bool, error) {
@@ -272,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) {
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
return user.Status == UserStatusEnabled, nil
}
func ValidateAccessToken(token string) (user *User) {
@@ -345,7 +371,7 @@ func decreaseUserQuota(id int, quota int64) (err error) {
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
return email
}

View File

@@ -2,7 +2,6 @@ package monitor
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
@@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) {
// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("渠道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
@@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) {
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道(#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
@@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) {
// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)

62
monitor/manage.go Normal file
View File

@@ -0,0 +1,62 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
)
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *model.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}

60
relay/adaptor.go Normal file
View File

@@ -0,0 +1,60 @@
package relay
import (
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
"github.com/songquanpeng/one-api/relay/adaptor/ali"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws"
"github.com/songquanpeng/one-api/relay/adaptor/baidu"
"github.com/songquanpeng/one-api/relay/adaptor/cloudflare"
"github.com/songquanpeng/one-api/relay/adaptor/cohere"
"github.com/songquanpeng/one-api/relay/adaptor/coze"
"github.com/songquanpeng/one-api/relay/adaptor/deepl"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
"github.com/songquanpeng/one-api/relay/adaptor/zhipu"
"github.com/songquanpeng/one-api/relay/apitype"
)
func GetAdaptor(apiType int) adaptor.Adaptor {
switch apiType {
case apitype.AIProxyLibrary:
return &aiproxy.Adaptor{}
case apitype.Ali:
return &ali.Adaptor{}
case apitype.Anthropic:
return &anthropic.Adaptor{}
case apitype.AwsClaude:
return &aws.Adaptor{}
case apitype.Baidu:
return &baidu.Adaptor{}
case apitype.Gemini:
return &gemini.Adaptor{}
case apitype.OpenAI:
return &openai.Adaptor{}
case apitype.PaLM:
return &palm.Adaptor{}
case apitype.Tencent:
return &tencent.Adaptor{}
case apitype.Xunfei:
return &xunfei.Adaptor{}
case apitype.Zhipu:
return &zhipu.Adaptor{}
case apitype.Ollama:
return &ollama.Adaptor{}
case apitype.Coze:
return &coze.Adaptor{}
case apitype.Cohere:
return &cohere.Adaptor{}
case apitype.Cloudflare:
return &cloudflare.Adaptor{}
case apitype.DeepL:
return &deepl.Adaptor{}
}
return nil
}

View File

@@ -4,27 +4,27 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
meta *meta.Meta
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
@@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {

View File

@@ -1,6 +1,6 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
import "github.com/songquanpeng/one-api/relay/adaptor/openai"
var ModelList = []string{""}

View File

@@ -8,7 +8,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -53,7 +54,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -66,7 +67,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "",
@@ -78,7 +79,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: response.Model,

View File

@@ -0,0 +1,105 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
)
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
meta *meta.Meta
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
fullRequestURL := ""
switch meta.Mode {
case relaymode.Embeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
case relaymode.ImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
default:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
if meta.IsStream {
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("X-DashScope-SSE", "enable")
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.Mode == relaymode.ImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable")
}
if a.meta.Config.Plugin != "" {
req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin)
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case relaymode.Embeddings:
aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
return aliEmbeddingRequest, nil
default:
aliRequest := ConvertRequest(*request)
return aliRequest, nil
}
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aliRequest := ConvertImageRequest(*request)
return aliRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}

View File

@@ -3,4 +3,5 @@ package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}

192
relay/adaptor/ali/image.go Normal file
View File

@@ -0,0 +1,192 @@
package ali
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"time"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse TaskResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
logger.SysError("aliAsyncTask err: " + string(responseBody))
return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}
func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)
var aliResponse TaskResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.SysError("aliAsyncTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response TaskResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
waitSeconds := 2
step := 0
maxStep := 20
var taskResponse TaskResponse
var responseBody []byte
for {
step++
rsp, err, body := asyncTask(taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
imageResponse := openai.ImageResponse{
Created: helper.GetTimestamp(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
// 读取 data.Url 的图片数据并转存到 b64Json
imageData, err := getImageData(data.Url)
if err != nil {
// 处理获取图片数据失败的情况
logger.SysError("getImageData Error getting image data: " + err.Error())
continue
}
// 将图片数据转为 Base64 编码的字符串
b64Json = Base64Encode(imageData)
} else {
// 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, openai.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func getImageData(url string) ([]byte, error) {
response, err := http.Get(url)
if err != nil {
return nil, err
}
defer response.Body.Close()
imageData, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
return imageData, nil
}
func Base64Encode(data []byte) string {
b64Json := base64.StdEncoding.EncodeToString(data)
return b64Json
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
@@ -66,6 +66,17 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque
}
}
func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
var imageRequest ImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)

154
relay/adaptor/ali/model.go Normal file
View File

@@ -0,0 +1,154 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type ImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type TaskResponse struct {
StatusCode int `json:"status_code,omitempty"`
RequestId string `json:"request_id,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Output struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Results []struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
} `json:"results,omitempty"`
TaskMetrics struct {
Total int `json:"TOTAL,omitempty"`
Succeeded int `json:"SUCCEEDED,omitempty"`
Failed int `json:"FAILED,omitempty"`
} `json:"task_metrics,omitempty"`
} `json:"output,omitempty"`
Usage Usage `json:"usage"`
}
type Header struct {
Action string `json:"action,omitempty"`
Streaming string `json:"streaming,omitempty"`
TaskID string `json:"task_id,omitempty"`
Event string `json:"event,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Attributes any `json:"attributes,omitempty"`
}
type Payload struct {
Model string `json:"model,omitempty"`
Task string `json:"task,omitempty"`
TaskGroup string `json:"task_group,omitempty"`
Function string `json:"function,omitempty"`
Parameters struct {
SampleRate int `json:"sample_rate,omitempty"`
Rate float64 `json:"rate,omitempty"`
Format string `json:"format,omitempty"`
} `json:"parameters,omitempty"`
Input struct {
Text string `json:"text,omitempty"`
} `json:"input,omitempty"`
Usage struct {
Characters int `json:"characters,omitempty"`
} `json:"usage,omitempty"`
}
type WSSMessage struct {
Header Header `json:"header,omitempty"`
Payload Payload `json:"payload,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -4,9 +4,9 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -14,16 +14,16 @@ import (
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
@@ -41,11 +41,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {

View File

@@ -4,16 +4,17 @@ import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func stopReasonClaude2OpenAI(reason *string) string {
@@ -91,7 +92,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
}
// https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
@@ -129,7 +130,7 @@ func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
return &openaiResponse, response
}
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
@@ -176,10 +177,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data: ") {
if !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
@@ -192,14 +193,14 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
data = strings.TrimSpace(data)
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
@@ -254,7 +255,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,

View File

@@ -0,0 +1,82 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct {
meta *meta.Meta
awsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
a.awsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, a.awsClient)
} else {
err, usage = Handler(c, a.awsClient, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}

191
relay/adaptor/aws/main.go Normal file
View File

@@ -0,0 +1,191 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
claudeResp := new(anthropic.StreamResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}

View File

@@ -0,0 +1,17 @@
package aws
import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
// Request is the request to AWS Claude
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
}

View File

@@ -3,25 +3,25 @@ package baidu
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(meta.ActualModelName, "Embedding") {
@@ -44,17 +44,25 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K":
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
@@ -81,8 +89,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
@@ -92,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
@@ -101,16 +109,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)

View File

@@ -2,15 +2,15 @@ package baidu
var ModelList = []string{
"ERNIE-4.0-8K",
"ERNIE-Bot-8K-0922",
"ERNIE-3.5-8K",
"ERNIE-Lite-8K-0922",
"ERNIE-Speed-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Lite-8K",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1",

View File

@@ -8,10 +8,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -305,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := util.ImpatientHTTPClient.Do(req)
res, err := client.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,66 @@
package cloudflare
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type Adaptor struct {
meta *meta.Meta
}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertImageRequest implements adaptor.Adaptor.
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "cloudflare"
}

View File

@@ -0,0 +1,36 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}

View File

@@ -0,0 +1,152 @@
package cloudflare
import (
"bufio"
"bytes"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
lastMessage := textRequest.Messages[len(textRequest.Messages)-1]
return &Request{
MaxTokens: textRequest.MaxTokens,
Prompt: lastMessage.StringContent(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cloudflareResponse.Result.Response,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = cloudflareResponse.Response
choice.Delta.Role = "assistant"
openaiResponse := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
Created: helper.GetTimestamp(),
}
return &openaiResponse
}
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
responseModel := c.GetString("original_model")
var responseText string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cloudflareResponse StreamResponse
err := json.Unmarshal([]byte(data), &cloudflareResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
if response == nil {
return true
}
responseText += cloudflareResponse.Response
response.Id = id
response.Model = responseModel
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
return nil, usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cloudflareResponse Response
err = json.Unmarshal(responseBody, &cloudflareResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse)
fullTextResponse.Model = modelName
usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens)
fullTextResponse.Usage = *usage
fullTextResponse.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -0,0 +1,25 @@
package cloudflare
type Request struct {
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
type Result struct {
Response string `json:"response"`
}
type Response struct {
Result Result `json:"result"`
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
}
type StreamResponse struct {
Response string `json:"response"`
}

View File

@@ -0,0 +1,64 @@
package cohere
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type Adaptor struct{}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertImageRequest implements adaptor.Adaptor.
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "Cohere"
}

View File

@@ -0,0 +1,14 @@
package cohere
var ModelList = []string{
"command", "command-nightly",
"command-light", "command-light-nightly",
"command-r", "command-r-plus",
}
func init() {
num := len(ModelList)
for i := 0; i < num; i++ {
ModelList = append(ModelList, ModelList[i]+"-internet")
}
}

View File

@@ -0,0 +1,241 @@
package cohere
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
var (
WebSearchConnector = Connector{ID: "web-search"}
)
func stopReasonCohere2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "COMPLETE":
return "stop"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
cohereRequest := Request{
Model: textRequest.Model,
Message: "",
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
P: textRequest.TopP,
K: textRequest.TopK,
Stream: textRequest.Stream,
FrequencyPenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.FrequencyPenalty,
Seed: int(textRequest.Seed),
}
if cohereRequest.Model == "" {
cohereRequest.Model = "command-r"
}
if strings.HasSuffix(cohereRequest.Model, "-internet") {
cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet")
cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector)
}
for _, message := range textRequest.Messages {
if message.Role == "user" {
cohereRequest.Message = message.Content.(string)
} else {
var role string
if message.Role == "assistant" {
role = "CHATBOT"
} else if message.Role == "system" {
role = "SYSTEM"
} else {
role = "USER"
}
cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{
Role: role,
Message: message.Content.(string),
})
}
}
return &cohereRequest
}
func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var finishReason string
switch cohereResponse.EventType {
case "stream-start":
return nil, nil
case "text-generation":
responseText += cohereResponse.Text
case "stream-end":
usage := cohereResponse.Response.Meta.Tokens
response = &Response{
Meta: Meta{
Tokens: Usage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
},
},
}
finishReason = *cohereResponse.Response.FinishReason
default:
return nil, nil
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
if finishReason != "" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cohereResponse.Text,
Name: nil,
},
FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID),
Model: "model",
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cohereResponse StreamResponse
err := json.Unmarshal([]byte(data), &cohereResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
return true
}
if response == nil {
return true
}
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model")
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResponse Response
err = json.Unmarshal(responseBody, &cohereResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if cohereResponse.ResponseID == "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: cohereResponse.Message,
Type: cohereResponse.Message,
Param: "",
Code: resp.StatusCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := ResponseCohere2OpenAI(&cohereResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: cohereResponse.Meta.Tokens.InputTokens,
CompletionTokens: cohereResponse.Meta.Tokens.OutputTokens,
TotalTokens: cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,147 @@
package cohere
type Request struct {
Message string `json:"message" required:"true"`
Model string `json:"model,omitempty"` // 默认值为"command-r"
Stream bool `json:"stream,omitempty"` // 默认值为false
Preamble string `json:"preamble,omitempty"`
ChatHistory []ChatMessage `json:"chat_history,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
Connectors []Connector `json:"connectors,omitempty"`
Documents []Document `json:"documents,omitempty"`
Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3
MaxTokens int `json:"max_tokens,omitempty"`
MaxInputTokens int `json:"max_input_tokens,omitempty"`
K int `json:"k,omitempty"` // 默认值为0
P float64 `json:"p,omitempty"` // 默认值为0.75
Seed int `json:"seed,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
Tools []Tool `json:"tools,omitempty"`
ToolResults []ToolResult `json:"tool_results,omitempty"`
}
type ChatMessage struct {
Role string `json:"role" required:"true"`
Message string `json:"message" required:"true"`
}
type Tool struct {
Name string `json:"name" required:"true"`
Description string `json:"description" required:"true"`
ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"`
}
type ParameterSpec struct {
Description string `json:"description"`
Type string `json:"type" required:"true"`
Required bool `json:"required"`
}
type ToolResult struct {
Call ToolCall `json:"call"`
Outputs []map[string]interface{} `json:"outputs"`
}
type ToolCall struct {
Name string `json:"name" required:"true"`
Parameters map[string]interface{} `json:"parameters" required:"true"`
}
type StreamResponse struct {
IsFinished bool `json:"is_finished"`
EventType string `json:"event_type"`
GenerationID string `json:"generation_id,omitempty"`
SearchQueries []*SearchQuery `json:"search_queries,omitempty"`
SearchResults []*SearchResult `json:"search_results,omitempty"`
Documents []*Document `json:"documents,omitempty"`
Text string `json:"text,omitempty"`
Citations []*Citation `json:"citations,omitempty"`
Response *Response `json:"response,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
type SearchQuery struct {
Text string `json:"text"`
GenerationID string `json:"generation_id"`
}
type SearchResult struct {
SearchQuery *SearchQuery `json:"search_query"`
DocumentIDs []string `json:"document_ids"`
Connector *Connector `json:"connector"`
}
type Connector struct {
ID string `json:"id"`
}
type Document struct {
ID string `json:"id"`
Snippet string `json:"snippet"`
Timestamp string `json:"timestamp"`
Title string `json:"title"`
URL string `json:"url"`
}
type Citation struct {
Start int `json:"start"`
End int `json:"end"`
Text string `json:"text"`
DocumentIDs []string `json:"document_ids"`
}
type Response struct {
ResponseID string `json:"response_id"`
Text string `json:"text"`
GenerationID string `json:"generation_id"`
ChatHistory []*Message `json:"chat_history"`
FinishReason *string `json:"finish_reason"`
Meta Meta `json:"meta"`
Citations []*Citation `json:"citations"`
Documents []*Document `json:"documents"`
SearchResults []*SearchResult `json:"search_results"`
SearchQueries []*SearchQuery `json:"search_queries"`
Message string `json:"message"`
}
type Message struct {
Role string `json:"role"`
Message string `json:"message"`
}
type Version struct {
Version string `json:"version"`
}
type Units struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type ChatEntry struct {
Role string `json:"role"`
Message string `json:"message"`
}
type Meta struct {
APIVersion APIVersion `json:"api_version"`
BilledUnits BilledUnits `json:"billed_units"`
Tokens Usage `json:"tokens"`
}
type APIVersion struct {
Version string `json:"version"`
}
type BilledUnits struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -1,15 +1,16 @@
package channel
package adaptor
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
)
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela
}
}
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
@@ -38,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,75 @@
package coze
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
type Adaptor struct {
meta *meta.Meta
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/open_api/v2/chat", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
request.User = a.meta.Config.UserID
return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
var responseText *string
if meta.IsStream {
err, responseText = StreamHandler(c, resp)
} else {
err, responseText = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
if responseText != nil {
usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.PromptTokens)
} else {
usage = &model.Usage{}
}
usage.PromptTokens = meta.PromptTokens
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "coze"
}

View File

@@ -0,0 +1,5 @@
package contenttype
const (
Text = "text"
)

View File

@@ -0,0 +1,7 @@
package event
const (
Message = "message"
Done = "done"
Error = "error"
)

View File

@@ -0,0 +1,6 @@
package messagetype
const (
Answer = "answer"
FollowUp = "follow_up"
)

View File

@@ -0,0 +1,3 @@
package coze
var ModelList = []string{}

View File

@@ -0,0 +1,10 @@
package coze
import "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event"
func event2StopReason(e *string) string {
if e == nil || *e == event.Message {
return ""
}
return "stop"
}

215
relay/adaptor/coze/main.go Normal file
View File

@@ -0,0 +1,215 @@
package coze
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
// https://www.coze.com/open
func stopReasonCoze2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
cozeRequest := Request{
Stream: textRequest.Stream,
User: textRequest.User,
BotId: strings.TrimPrefix(textRequest.Model, "bot-"),
}
for i, message := range textRequest.Messages {
if i == len(textRequest.Messages)-1 {
cozeRequest.Query = message.StringContent()
continue
}
cozeMessage := Message{
Role: message.Role,
Content: message.StringContent(),
}
cozeRequest.ChatHistory = append(cozeRequest.ChatHistory, cozeMessage)
}
return &cozeRequest
}
func StreamResponseCoze2OpenAI(cozeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var stopReason string
var choice openai.ChatCompletionsStreamResponseChoice
if cozeResponse.Message != nil {
if cozeResponse.Message.Type != messagetype.Answer {
return nil, nil
}
choice.Delta.Content = cozeResponse.Message.Content
}
choice.Delta.Role = "assistant"
finishReason := stopReasonCoze2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
openaiResponse.Id = cozeResponse.ConversationId
return &openaiResponse, response
}
func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse {
var responseText string
for _, message := range cozeResponse.Messages {
if message.Type == messagetype.Answer {
responseText = message.Content
break
}
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", cozeResponse.ConversationId),
Model: "coze-bot",
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
var responseText string
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 {
continue
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var modelName string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cozeResponse StreamResponse
err := json.Unmarshal([]byte(data), &cozeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
if response == nil {
return true
}
for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cozeResponse Response
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if cozeResponse.Code != 0 {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: cozeResponse.Msg,
Code: cozeResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := ResponseCoze2OpenAI(&cozeResponse)
fullTextResponse.Model = modelName
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
var responseText string
if len(fullTextResponse.Choices) > 0 {
responseText = fullTextResponse.Choices[0].Message.StringContent()
}
return nil, &responseText
}

View File

@@ -0,0 +1,38 @@
package coze
type Message struct {
Role string `json:"role"`
Type string `json:"type"`
Content string `json:"content"`
ContentType string `json:"content_type"`
}
type ErrorInformation struct {
Code int `json:"code"`
Msg string `json:"msg"`
}
type Request struct {
ConversationId string `json:"conversation_id,omitempty"`
BotId string `json:"bot_id"`
User string `json:"user"`
Query string `json:"query"`
ChatHistory []Message `json:"chat_history,omitempty"`
Stream bool `json:"stream"`
}
type Response struct {
ConversationId string `json:"conversation_id,omitempty"`
Messages []Message `json:"messages,omitempty"`
Code int `json:"code,omitempty"`
Msg string `json:"msg,omitempty"`
}
type StreamResponse struct {
Event string `json:"event,omitempty"`
Message *Message `json:"message,omitempty"`
IsFinish bool `json:"is_finish,omitempty"`
Index int `json:"index,omitempty"`
ConversationId string `json:"conversation_id,omitempty"`
ErrorInformation *ErrorInformation `json:"error_information,omitempty"`
}

View File

@@ -0,0 +1,73 @@
package deepl
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
type Adaptor struct {
meta *meta.Meta
promptText string
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v2/translate", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "DeepL-Auth-Key "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
convertedRequest, text := ConvertRequest(*request)
a.promptText = text
return convertedRequest, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err = StreamHandler(c, resp, meta.ActualModelName)
} else {
err = Handler(c, resp, meta.ActualModelName)
}
promptTokens := len(a.promptText)
usage = &model.Usage{
PromptTokens: promptTokens,
TotalTokens: promptTokens,
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "deepl"
}

View File

@@ -0,0 +1,9 @@
package deepl
// https://developers.deepl.com/docs/api-reference/glossaries
var ModelList = []string{
"deepl-zh",
"deepl-en",
"deepl-ja",
}

View File

@@ -0,0 +1,11 @@
package deepl
import "strings"
func parseLangFromModelName(modelName string) string {
parts := strings.Split(modelName, "-")
if len(parts) == 1 {
return "ZH"
}
return parts[1]
}

137
relay/adaptor/deepl/main.go Normal file
View File

@@ -0,0 +1,137 @@
package deepl
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/constant/finishreason"
"github.com/songquanpeng/one-api/relay/constant/role"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
// https://developers.deepl.com/docs/getting-started/your-first-api-request
func ConvertRequest(textRequest model.GeneralOpenAIRequest) (*Request, string) {
var text string
if len(textRequest.Messages) != 0 {
text = textRequest.Messages[len(textRequest.Messages)-1].StringContent()
}
deeplRequest := Request{
TargetLang: parseLangFromModelName(textRequest.Model),
Text: []string{text},
}
return &deeplRequest, text
}
func StreamResponseDeepL2OpenAI(deeplResponse *Response) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
if len(deeplResponse.Translations) != 0 {
choice.Delta.Content = deeplResponse.Translations[0].Text
}
choice.Delta.Role = role.Assistant
choice.FinishReason = &constant.StopFinishReason
openaiResponse := openai.ChatCompletionsStreamResponse{
Object: constant.StreamObject,
Created: helper.GetTimestamp(),
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &openaiResponse
}
func ResponseDeepL2OpenAI(deeplResponse *Response) *openai.TextResponse {
var responseText string
if len(deeplResponse.Translations) != 0 {
responseText = deeplResponse.Translations[0].Text
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: role.Assistant,
Content: responseText,
Name: nil,
},
FinishReason: finishreason.Stop,
}
fullTextResponse := openai.TextResponse{
Object: constant.NonStreamObject,
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
var deeplResponse Response
err = json.Unmarshal(responseBody, &deeplResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
fullTextResponse := StreamResponseDeepL2OpenAI(&deeplResponse)
fullTextResponse.Model = modelName
fullTextResponse.Id = helper.GetResponseID(c)
jsonData, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
if jsonData != nil {
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
jsonData = nil
return true
}
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
})
_ = resp.Body.Close()
return nil
}
func Handler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
var deeplResponse Response
err = json.Unmarshal(responseBody, &deeplResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if deeplResponse.Message != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: deeplResponse.Message,
Code: "deepl_error",
},
StatusCode: resp.StatusCode,
}
}
fullTextResponse := ResponseDeepL2OpenAI(&deeplResponse)
fullTextResponse.Model = modelName
fullTextResponse.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil
}

View File

@@ -0,0 +1,16 @@
package deepl
type Request struct {
Text []string `json:"text"`
TargetLang string `json:"target_lang"`
}
type Translation struct {
DetectedSourceLanguage string `json:"detected_source_language,omitempty"`
Text string `json:"text,omitempty"`
}
type Response struct {
Translations []Translation `json:"translations,omitempty"`
Message string `json:"message,omitempty"`
}

View File

@@ -0,0 +1,6 @@
package deepseek
var ModelList = []string{
"deepseek-chat",
"deepseek-coder",
}

View File

@@ -3,33 +3,35 @@ package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1")
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
action = "streamGenerateContent?alt=sse"
}
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
@@ -42,11 +44,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)

Some files were not shown because too many files have changed in this diff Show More