Compare commits

...

75 Commits

Author SHA1 Message Date
Qiying Wang
f636c50c84 fix: duplicate [DONE] (#1629) 2024-07-09 22:43:59 +08:00
Qiying Wang
720fe2dfeb feat: refactor AwsClaude to Aws to support both llama3 and claude (#1601)
* feat: refactor AwsClaude to Aws to support both llama3 and claude

* fix: aws llama3 ratio
2024-07-06 13:19:41 +08:00
Jason
e090e76c86 feat: add Novita AI as model provider (#1609) 2024-07-06 13:16:46 +08:00
open source
6a941748f8 feat: add initial root access token (#1598)
Signed-off-by: xiaobo <peterwillcn@gmail.com>
2024-07-06 13:15:17 +08:00
open source
46a0773580 fix: update readme docs (#1599)
Signed-off-by: xiaobo <peterwillcn@gmail.com>
2024-07-06 13:14:32 +08:00
zijiren
ffdb0b0c81 fix: use musl libc (#1597) 2024-07-06 13:14:07 +08:00
zijiren
efd30a40b3 feat: cloudflare support native openai api (#1596) 2024-07-06 13:12:30 +08:00
Qiying Wang
d7a78f3397 feat: support test specific model (#1600) 2024-07-05 18:05:16 +08:00
Leo Q
273be55797 feat(ui): show available models for air theme (#1595)
* feat(ui): show available models in the air theme

* chore: switch to full-width parentheses
2024-07-04 08:35:41 +08:00
Leo Q
ec6ad24810 feat: support smtp without auth (#1101) 2024-07-03 22:23:49 +08:00
LinZeliang
c4fe57c165 feat: support one or more log file (#1400)
Co-authored-by: Laisky.Cai <github@laisky.com>
2024-07-03 20:53:29 +08:00
igophper
274fcf3d76 refactor: init db (#1590)
Co-authored-by: 江杭辉 <jianghanghui@k.app>
2024-07-03 20:50:40 +08:00
Mikey
0fc07ea558 feat: add support for Claude 3 tool use (function calling) (#1587)
* feat: add tool support for AWS & Claude

* fix: add {} for openai compatibility in streaming tool_use
2024-07-02 00:12:01 +08:00
Leo Q
1ce1e529ee ci: skip archive, upload directly (#1586) 2024-07-02 00:05:47 +08:00
Darkside
d936817de9 docs: add related projects (#1562)
Co-authored-by: 成达 <chengda.615@bytedance.com>
2024-06-30 19:57:30 +08:00
igophper
fecaece71b fix: fix size not support during image generation (#1564)
Fixes #1224, #1068
2024-06-30 19:52:33 +08:00
Shi Jilin
c135d74f13 feat: support Spark4.0 Ultra (#1575)
* fix: fix SparkDesk Function Call (fixes the issue where Spark Pro/Max function calls only returned a plain chat answer instead of a Function Call answer)

* feat: support Spark4.0 Ultra
2024-06-30 19:38:02 +08:00
lihangfu
d0369b114f feat: support spark4.0 ultra (#1569)
* feat: support Tencent Hunyuan via the latest v3 protocol (#1452)

* feat: support Spark4.0 Ultra

---------

Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-06-30 19:37:07 +08:00
zijiren
b21b3b5b46 refactor: abusing goroutines and channel (#1561)
* refactor: abusing goroutines

* fix: trim data prefix

* refactor: move functions to render package

* refactor: add back trim & flush

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
2024-06-30 18:36:33 +08:00
shaoyun
ae1cd29f94 feat: added support for Claude Sonnet 3.5 (#1567) 2024-06-30 16:25:25 +08:00
dependabot[bot]
f25aaf7752 chore(deps): bump golang.org/x/image from 0.16.0 to 0.18.0 (#1568)
Bumps [golang.org/x/image](https://github.com/golang/image) from 0.16.0 to 0.18.0.
- [Commits](https://github.com/golang/image/compare/v0.16.0...v0.18.0)

---
updated-dependencies:
- dependency-name: golang.org/x/image
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-30 16:21:48 +08:00
JustSong
b70a07e814 fix: fix ci 2024-06-30 16:19:49 +08:00
igophper
34cb147a74 refactor: replace hardcoded string with ctxkey constant (#1579)
Co-authored-by: 江杭辉 <jianghanghui@k.app>
2024-06-30 16:13:43 +08:00
Leo Q
8cc1ee6360 ci: use codecov to upload coverage report (#1583) 2024-06-30 16:12:16 +08:00
Ghostz
5a58426859 fix minimax empty log (#1560) 2024-06-30 16:09:16 +08:00
JustSong
254b9777c0 feat: support load env variables from .env file 2024-06-23 15:37:11 +08:00
JustSong
114c44c6e7 ci: fix ci.yml 2024-06-23 15:17:58 +08:00
JustSong
a3c7e15aed fix: fix ut 2024-06-23 15:14:39 +08:00
JustSong
3777517f64 chore: add ut 2024-06-23 14:28:55 +08:00
JustSong
9fc5f427dc chore: add commit lint and rename yml 2024-06-23 14:01:57 +08:00
JustSong
864a467886 chore: rename go.yml to unit-testing.yml 2024-06-23 13:57:19 +08:00
JustSong
ed78b5340b fix: fix go.yml 2024-06-23 13:53:30 +08:00
JustSong
fee69e7c20 fix: fix ut 2024-06-23 13:48:52 +08:00
JustSong
9d23a44dbf ci: add coverage report 2024-06-23 13:38:43 +08:00
JustSong
6e4cfb20d5 ci: add go.yaml 2024-06-23 13:00:42 +08:00
Shi Jilin
ff196b75a7 fix: fix sparkdesk function call 2024-06-20 22:56:59 +08:00
lihangfu
279caf82dc feat: support tencent v3 api (#1542)
Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-06-20 00:23:08 +08:00
Wei Tingjiang
b1520b308b Try to fix Gemini streaming return being truncated by FinishReason. (#1477)
1
2024-06-14 00:30:47 +08:00
JustSong
ed717211aa chore: adjust default rate limit config 2024-06-13 00:35:37 +08:00
JustSong
6ccf3f3cfc chore: add logger.SysLogf function 2024-06-13 00:28:56 +08:00
jinjianming
f74577141c fix: fix default token not created in some cases (#1510)
* Fix the issue where users registering via git, WeChat, etc. do not get a default token created

* Remove the regular-user registration code

* fix: do not block if error happened

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-06-13 00:20:48 +08:00
Buer
6aafb7a99e fix: channel edit settings key error (#1496) 2024-06-13 00:08:49 +08:00
Zhong Liu
c1971870fa fix: support for Spark Lite model (#1526)
* fix: Support for Spark Lite model

* fix: fix panic

* fix: fix xunfei version config

---------

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-06-13 00:07:26 +08:00
wagxuebing
f83894c83f fix: xunfei interface call 4001 error (#1499)
Co-authored-by: lynnssb <lynntobing@gmail.com>
2024-06-12 23:12:58 +08:00
fxsome
e9981fff36 feat: post all messages for cloudflare (#1515) 2024-06-08 13:34:23 +08:00
取梦为饮
98669d5d48 feat: add support for bytedance's doubao (#1438)
* Add support for the Doubao large model

* chore: update channel options & add prompt

---------

Co-authored-by: 康龙彪 <longbiao.kang@i-tudou.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-06-08 13:26:26 +08:00
Wei Tingjiang
9321427c6e feat: support gemini embeddings (text-embedding-004,embedding-001) (#1475)
* Refactor Gemini Adaptor to Support Embeddings

* Add new models to ModelList
2024-05-29 01:17:32 +08:00
JustSong
ceea4c6d4a feat: support user content download proxy & relay proxy now 2024-05-29 01:14:00 +08:00
carey036
b53e00a9b3 feat: generate default token after register (#1401)
* feat: generate default token after register

* chore: use go routine to create default token for new user

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-05-28 01:44:38 +08:00
Mo
332c8db0b3 fix: add prefixes to image models to solve the problem of duplicate models (#1469)
* Add prefixes to image models to solve the problem of duplicate models

* Fix the issue that response_format is not set, causing the b64_json parameter to be ignored.
2024-05-28 01:32:57 +08:00
fatwang2
3be28da57b Update package.json (#1465) 2024-05-28 01:31:08 +08:00
Ghostz
fa74ba0eaa chore: print user id when relay error happened (#1447)
* add userid when relay error

* chore: update log format

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-05-28 01:30:51 +08:00
Dafei Zhao
a9211d66f6 fix: fix gpt-4o token encoding (#1446) 2024-05-28 01:26:07 +08:00
Buer
07b2fd58d6 feat: berry theme update & bug fix (#1471)
* feat: load channel models from server

* chore: support AWS Claude/Cloudflare/Coze

* fix: Popup message when copying fails

* chore: Optimize tips
2024-05-28 01:22:40 +08:00
JustSong
0acee9a065 fix: fix berry error (close #1445) 2024-05-22 01:00:00 +08:00
JustSong
f965469e8a chore: update dependencies version 2024-05-22 00:52:23 +08:00
JustSong
03ea60532a fix: fix html lang attribute (close #1433) 2024-05-21 01:20:37 +08:00
Qiying Wang
2457d00afb feat: support gpt-4o (#1431) 2024-05-21 01:14:22 +08:00
JustSong
91b80ae879 fix: remove extra space 2024-05-07 23:57:34 +08:00
JustSong
2720e1a358 feat: support minimax's 6.5 models (close #1395) 2024-04-30 02:23:14 +08:00
JustSong
71f4403fd5 feat: add together.ai support (#1298) 2024-04-30 02:16:53 +08:00
JustSong
1f76c80553 fix: fix aws claude panic (#1384) 2024-04-29 22:49:06 +08:00
JustSong
7e027d2bd0 fix: fix minimax prompt & completion tokens is empty (#1391) 2024-04-29 22:35:47 +08:00
JustSong
30f373b623 fix: fix usage is empty (close #1391) 2024-04-29 22:29:13 +08:00
plusye
1c2654320e fix: fix getPreConsumedQuota (#1312) 2024-04-27 16:07:06 +08:00
caixinjiang
6cffb116b7 fix: fix zhipu embedding error when input is array but not string (#1306)
* fix zhipu embedding error when input is array but not string

* fix: only use the first one

---------

Co-authored-by: 蔡新疆 <cxj@icc.link>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-27 16:05:14 +08:00
Qiying Wang
a84c7b38b7 fix: claude stream response parse (#1334) 2024-04-27 15:58:07 +08:00
tylinux
1bd14af47b feat: use mapped model name to test (#1370) 2024-04-27 15:53:20 +08:00
NongMO
6170b91d1c feat: support for the ollama vision model (#1376)
* feat: support for the ollama vision model

`llava` model, pass test

* Update main.go

format code

* chore: remove useless log

---------

Co-authored-by: nongqiqin <nongqiqin@tipdm.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-27 15:47:27 +08:00
JustSong
04b49aa0ec chore: use StringContent() to convert response to text 2024-04-27 15:41:02 +08:00
Wei Tingjiang
ef88497f25 fix: refactor Gemini adaptor to support streaming content generation (#1382) 2024-04-27 15:39:59 +08:00
JustSong
007906216d feat: support DeepL's model (close #1126) 2024-04-27 13:37:22 +08:00
JustSong
e64e7707a0 feat: support cohere's web search 2024-04-27 00:06:43 +08:00
JustSong
ea210b6ed7 chore: update ollama models 2024-04-26 23:12:39 +08:00
JustSong
9026ec7510 feat: support cloudflare now 2024-04-26 23:05:48 +08:00
139 changed files with 3784 additions and 1961 deletions

3
.env.example Normal file
View File

@@ -0,0 +1,3 @@
PORT=3000
DEBUG=false
HTTPS_PROXY=http://localhost:7890
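
This `.env` file is read via the `github.com/joho/godotenv` dependency added to `go.mod` later in this diff (see the "feat: support load env variables from .env file" commit). A minimal sketch of the loading pattern; the fallback log message is illustrative, not the repo's actual wording:

```go
package main

import (
	"log"
	"os"

	"github.com/joho/godotenv"
)

func main() {
	// Load .env if present; a missing file is not fatal, since every
	// variable can still come from the real process environment.
	if err := godotenv.Load(); err != nil {
		log.Println(".env not found, relying on process environment")
	}
	log.Printf("PORT=%s DEBUG=%s", os.Getenv("PORT"), os.Getenv("DEBUG"))
}
```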

47
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,47 @@
name: CI

# This setup assumes that you run the unit tests with code coverage in the same
# workflow that will also print the coverage report as comment to the pull request.
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
# when new code is pushed to the branch of the pull request. In addition, you also
# need to trigger this workflow when new code is pushed to the main branch because
# we need to upload the code coverage results as artifact for the main branch as
# well since it will be the baseline code coverage.
#
# We do not want to trigger the workflow for pushes to *any* branch because this
# would trigger our jobs twice on pull requests (once from "push" event and once
# from "pull_request->synchronize")

on:
  pull_request:
    types: [opened, reopened, synchronize]
  push:
    branches:
      - 'main'

jobs:
  unit_tests:
    name: "Unit tests"
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Setup Go
        uses: actions/setup-go@v4
        with:
          go-version: ^1.22

      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
      # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
      # in the next step as well as the next job.
      - name: Test
        run: go test -cover -coverprofile=coverage.txt ./...

      - uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

  commit_lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: wagoid/commitlint-github-action@v6

3
.gitignore vendored
View File

@@ -8,4 +8,5 @@ build
 logs
 data
 /web/node_modules
 cmd.md
+.env

View File

@@ -16,7 +16,9 @@ WORKDIR /web/air
 RUN npm install
 RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build

-FROM golang AS builder2
+FROM golang:alpine AS builder2
+
+RUN apk add --no-cache g++

 ENV GO111MODULE=on \
     CGO_ENABLED=1 \
@@ -27,7 +29,7 @@ ADD go.mod go.sum ./
 RUN go mod download
 COPY . .
 COPY --from=builder /web/build ./web/build
-RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
+RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api

 FROM alpine

View File

@@ -101,7 +101,7 @@ Nginx reference configuration:
 ```
 server{
    server_name openai.justsong.cn;  # Modify your domain name accordingly

    location / {
          client_max_body_size  64m;
          proxy_http_version 1.1;
@@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`.
 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source:
    ```shell
    git clone https://github.com/songquanpeng/one-api.git

    # Build the frontend
    cd one-api/web/default
    npm install
    npm run build

    # Build the backend
    cd ../..
    go mod download
@@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the
    + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
    + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
-6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+6. `MEMORY_CACHE_ENABLED`: Enabling memory caching can cause a certain delay in updating user quotas. The optional values are `true` and `false`; if not set, it defaults to `false`.
+7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
    + Example: `SYNC_FREQUENCY=60`
-7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
    + Example: `NODE_TYPE=slave`
-8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
    + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
-9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
    + Example: `CHANNEL_TEST_FREQUENCY=1440`
-10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
    + Example: `POLLING_INTERVAL=5`
+12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are `true` and `false`; if not set, it defaults to `false`.
+   + Example: `BATCH_UPDATE_ENABLED=true`
+   + If you encounter an issue with too many database connections, you can try enabling this option.
+13. `BATCH_UPDATE_INTERVAL=5`: The time interval for aggregating batch updates, in seconds, defaults to `5`.
+   + Example: `BATCH_UPDATE_INTERVAL=5`
+14. Request rate limits:
+   + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, defaults to `180`.
+   + `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests within three minutes per IP, defaults to `60`.
+15. Encoder cache settings:
+   + `TIKTOKEN_CACHE_DIR`: By default, the program downloads the token encodings of some common models (e.g. `gpt-3.5-turbo`) at startup. In unstable network environments or offline deployments this can break startup; configure this directory to cache the data, which can then be migrated to an offline environment.
+   + `DATA_GYM_CACHE_DIR`: Currently serves the same purpose as `TIKTOKEN_CACHE_DIR`, but with lower priority.
+16. `RELAY_TIMEOUT`: Relay timeout, in seconds; no timeout is set by default.
+17. `RELAY_PROXY`: When set, this proxy is used to request the upstream APIs.
+18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
+19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to fetch user-uploaded content, such as images.
+20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds, defaults to `3000`.
+21. `GEMINI_SAFETY_SETTING`: Gemini safety setting, defaults to `BLOCK_NONE`.
+22. `GEMINI_VERSION`: The Gemini version used by One API, defaults to `v1`.
+23. `THEME`: The system's theme, defaults to `default`; see [here](./web/README.md) for the available values.
+24. `ENABLE_METRIC`: Whether to disable channels based on request success rate; not enabled by default. The optional values are `true` and `false`.
+25. `METRIC_QUEUE_SIZE`: Size of the request success rate statistics queue, defaults to `10`.
+26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, defaults to `0.8`.
+27. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable will be created automatically when the system starts for the first time.
+28. `INITIAL_ROOT_ACCESS_TOKEN`: If set, a system management (access) token with the value of this environment variable will be created automatically for the root user when the system starts for the first time.
 ### Command Line Parameters
 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.
@@ -287,7 +312,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
    + Double-check that your interface address and API Key are correct.

 ## Related Projects
-[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
+* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
+* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller.
+* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization.

 ## Note
 This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
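
Nearly all of the variables documented above follow one pattern: read the environment, fall back to a default when the value is unset or unparsable. The repo's real helpers live in `common/env` (used as `env.Int`/`env.String`/`env.Bool` in the `common/config` diff further below); this standalone `envInt` is a hypothetical stand-in sketching the same idea:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// envInt mirrors the env.Int pattern: return the parsed value of the named
// environment variable, or defaultValue when it is unset or invalid.
func envInt(name string, defaultValue int) int {
	v := os.Getenv(name)
	if v == "" {
		return defaultValue
	}
	n, err := strconv.Atoi(v)
	if err != nil {
		return defaultValue
	}
	return n
}

func main() {
	// e.g. run as: GLOBAL_API_RATE_LIMIT=240 ./demo
	fmt.Println("api rate limit:", envInt("GLOBAL_API_RATE_LIMIT", 180))
}
```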

View File

@@ -53,7 +53,7 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box ✨_
 > [!NOTE]
 > This is an open-source project. Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**; it must not be used for illegal purposes.
 >
 > As required by the [Interim Measures for the Management of Generative AI Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI services to the public in mainland China.

 > [!WARNING]
@@ -68,6 +68,7 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box ✨_
 + [x] [Anthropic Claude series](https://anthropic.com) (AWS Claude supported)
 + [x] [Google PaLM2/Gemini series](https://developers.generativeai.google)
 + [x] [Mistral series](https://mistral.ai/)
++ [x] [ByteDance Doubao models](https://console.volcengine.com/ark/region:ark+cn-beijing/model)
 + [x] [Baidu Wenxin Yiyan series](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
 + [x] [Alibaba Tongyi Qianwen series](https://help.aliyun.com/document_detail/2400395.html)
 + [x] [iFlytek Spark series](https://www.xfyun.cn/doc/spark/Web.html)
@@ -76,7 +77,6 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box ✨_
 + [x] [Tencent Hunyuan](https://cloud.tencent.com/document/product/1729)
 + [x] [Moonshot AI](https://platform.moonshot.cn/)
 + [x] [Baichuan models](https://platform.baichuan-ai.com)
-+ [ ] [ByteDance Yunque model](https://www.volcengine.com/product/ark) (WIP)
 + [x] [MINIMAX](https://api.minimax.chat/)
 + [x] [Groq](https://wow.groq.com/)
 + [x] [Ollama](https://github.com/ollama/ollama)
@@ -85,6 +85,10 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box ✨_
 + [x] [Coze](https://www.coze.com/)
 + [x] [Cohere](https://cohere.com/)
 + [x] [DeepSeek](https://www.deepseek.com/)
++ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
++ [x] [DeepL](https://www.deepl.com/)
++ [x] [together.ai](https://www.together.ai/)
++ [x] [novita.ai](https://www.novita.ai/)
 2. Supports configuring mirror sites and many [third-party proxy services](https://iamazing.cn/page/openai-api-third-party-services).
 3. Supports accessing multiple channels through **load balancing**.
 4. Supports **stream mode** for a typewriter-like effect via streaming responses.
@@ -141,7 +145,7 @@ Nginx reference configuration:
 ```
 server{
    server_name openai.justsong.cn;  # Modify your domain name accordingly

    location / {
          client_max_body_size  64m;
          proxy_http_version 1.1;
@@ -186,12 +190,12 @@ docker-compose ps
 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source:
    ```shell
    git clone https://github.com/songquanpeng/one-api.git

    # Build the frontend
    cd one-api/web/default
    npm install
    npm run build

    # Build the backend
    cd ../..
    go mod download
@@ -318,7 +322,7 @@ Render can deploy the docker image directly without forking the repository: https://dashbo
 For example, for OpenAI's official library:
 ```bash
 OPENAI_API_KEY="sk-xxxxxx"
 OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
 ```
 ```mermaid
@@ -337,6 +341,7 @@ graph LR
 If not specified, multiple channels will be used in a load-balanced way.

 ### Environment Variables
+> One API supports reading environment variables from a `.env` file. See the `.env.example` file for reference, and rename it to `.env` before use.
 1. `REDIS_CONN_STRING`: When set, Redis will be used as the cache.
    + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
    + If database access latency is very low, there is no need to enable Redis; enabling it can actually cause data lag.
@@ -365,30 +370,34 @@ graph LR
    + Example: `NODE_TYPE=slave`
 9. `CHANNEL_UPDATE_FREQUENCY`: When set, channel balances will be updated periodically, in minutes; if not set, no updates happen.
    + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
 10. `CHANNEL_TEST_FREQUENCY`: When set, channels will be tested periodically, in minutes; if not set, no tests happen.
-11. Example: `CHANNEL_TEST_FREQUENCY=1440`
+   + Example: `CHANNEL_TEST_FREQUENCY=1440`
-12. `POLLING_INTERVAL`: The request interval when batch-updating channel balances and testing availability, in seconds; no interval by default.
+11. `POLLING_INTERVAL`: The request interval when batch-updating channel balances and testing availability, in seconds; no interval by default.
    + Example: `POLLING_INTERVAL=5`
-13. `BATCH_UPDATE_ENABLED`: Enables batched aggregation of database updates, which delays user quota updates somewhat; optional values are `true` and `false`, defaulting to `false` if unset.
+12. `BATCH_UPDATE_ENABLED`: Enables batched aggregation of database updates, which delays user quota updates somewhat; optional values are `true` and `false`, defaulting to `false` if unset.
    + Example: `BATCH_UPDATE_ENABLED=true`
    + If you run into too many database connections, try enabling this option.
-14. `BATCH_UPDATE_INTERVAL=5`: The interval for aggregated batch updates, in seconds, defaults to `5`.
+13. `BATCH_UPDATE_INTERVAL=5`: The interval for aggregated batch updates, in seconds, defaults to `5`.
    + Example: `BATCH_UPDATE_INTERVAL=5`
-15. Request rate limits:
+14. Request rate limits:
    + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests); maximum requests per IP within three minutes, defaults to `180`.
    + `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit; maximum requests per IP within three minutes, defaults to `60`.
-16. Encoder cache settings:
+15. Encoder cache settings:
    + `TIKTOKEN_CACHE_DIR`: By default the program downloads the token encodings of some common models (e.g. `gpt-3.5-turbo`) at startup; in unstable or offline network environments this may break startup. Configure this directory to cache the data, which can then be migrated to an offline environment.
    + `DATA_GYM_CACHE_DIR`: Currently has the same function as `TIKTOKEN_CACHE_DIR`, but with lower priority.
-17. `RELAY_TIMEOUT`: Relay timeout, in seconds; no timeout by default.
-18. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds, defaults to `3000`.
-19. `GEMINI_SAFETY_SETTING`: Gemini safety setting, defaults to `BLOCK_NONE`.
-20. `GEMINI_VERSION`: The Gemini version used by One API, defaults to `v1`.
-21. `THEME`: The system theme, defaults to `default`; see [here](./web/README.md) for the available values.
-22. `ENABLE_METRIC`: Whether to disable channels based on request success rate; not enabled by default. Optional values are `true` and `false`.
-23. `METRIC_QUEUE_SIZE`: Size of the request success rate statistics queue, defaults to `10`.
-24. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, defaults to `0.8`.
-25. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable is created automatically on first startup.
+16. `RELAY_TIMEOUT`: Relay timeout, in seconds; no timeout by default.
+17. `RELAY_PROXY`: When set, this proxy is used to request the upstream APIs.
+18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
+19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to fetch user-uploaded content, such as images.
+20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds, defaults to `3000`.
+21. `GEMINI_SAFETY_SETTING`: Gemini safety setting, defaults to `BLOCK_NONE`.
+22. `GEMINI_VERSION`: The Gemini version used by One API, defaults to `v1`.
+23. `THEME`: The system theme, defaults to `default`; see [here](./web/README.md) for the available values.
+24. `ENABLE_METRIC`: Whether to disable channels based on request success rate; not enabled by default. Optional values are `true` and `false`.
+25. `METRIC_QUEUE_SIZE`: Size of the request success rate statistics queue, defaults to `10`.
+26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, defaults to `0.8`.
+27. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable is created automatically on first startup.
+28. `INITIAL_ROOT_ACCESS_TOKEN`: If set, a system management (access) token for the root user with the value of this environment variable is created automatically on first startup.
 ### Command Line Parameters
 1. `--port <port_number>`: Specifies the port the server listens on; defaults to `3000`.
@@ -441,6 +450,8 @@ https://openai.justsong.cn
 ## Related Projects
 * [FastGPT](https://github.com/labring/FastGPT): Knowledge-base question answering system based on LLMs
 * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): One-click deployment of your own cross-platform ChatGPT application
+* [VChart](https://github.com/VisActor/VChart): Not just an out-of-the-box multi-platform charting library, but also an expressive data storyteller.
+* [VMind](https://github.com/VisActor/VMind): Not only automatic, but also intelligent. An open-source intelligent visualization solution.
 ## Note

60
common/client/init.go Normal file
View File

@@ -0,0 +1,60 @@
package client

import (
	"fmt"
	"net/http"
	"net/url"
	"time"

	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
)

var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
var UserContentRequestHTTPClient *http.Client

func Init() {
	if config.UserContentRequestProxy != "" {
		logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
		proxyURL, err := url.Parse(config.UserContentRequestProxy)
		if err != nil {
			logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
		}
		transport := &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		}
		UserContentRequestHTTPClient = &http.Client{
			Transport: transport,
			Timeout:   time.Second * time.Duration(config.UserContentRequestTimeout),
		}
	} else {
		UserContentRequestHTTPClient = &http.Client{}
	}
	var transport http.RoundTripper
	if config.RelayProxy != "" {
		logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
		proxyURL, err := url.Parse(config.RelayProxy)
		if err != nil {
			// The committed code reused the USER_CONTENT_REQUEST_PROXY message here;
			// the message should refer to RELAY_PROXY.
			logger.FatalLog(fmt.Sprintf("RELAY_PROXY set but invalid: %s", config.RelayProxy))
		}
		transport = &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		}
	}

	if config.RelayTimeout == 0 {
		HTTPClient = &http.Client{
			Transport: transport,
		}
	} else {
		HTTPClient = &http.Client{
			Timeout:   time.Duration(config.RelayTimeout) * time.Second,
			Transport: transport,
		}
	}

	ImpatientHTTPClient = &http.Client{
		Timeout:   5 * time.Second,
		Transport: transport,
	}
}
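
A usage sketch for these clients, assuming `client.Init()` has already run after configuration is loaded (the URL is a placeholder). `common/image` below switches its `Head`/`Get` calls to `UserContentRequestHTTPClient` in exactly this way:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/client"
)

func main() {
	// Builds the three clients from RELAY_PROXY / USER_CONTENT_REQUEST_*
	// settings; must run before any of them is used.
	client.Init()

	// Fetch user-supplied content through the proxy-aware client.
	resp, err := client.UserContentRequestHTTPClient.Get("https://example.com/cat.png")
	if err != nil {
		fmt.Println("fetch failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```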

View File

@@ -117,10 +117,10 @@ var ValidThemes = map[string]bool{
 // All duration's unit is seconds
 // Shouldn't be larger than RateLimitKeyExpirationDuration
 var (
-	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 180)
+	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 240)
 	GlobalApiRateLimitDuration int64 = 3 * 60
-	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
+	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 120)
 	GlobalWebRateLimitDuration int64 = 3 * 60
 	UploadRateLimitNum = 10
@@ -143,4 +143,13 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
 var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
+var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
 var GeminiVersion = env.String("GEMINI_VERSION", "v1")
+var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
+var RelayProxy = env.String("RELAY_PROXY", "")
+var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
+var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

View File

@@ -1,13 +0,0 @@
package ctxkey
const (
	ConfigPrefix = "cfg_"

	ConfigAPIVersion = ConfigPrefix + "api_version"
	ConfigLibraryID  = ConfigPrefix + "library_id"
	ConfigPlugin     = ConfigPrefix + "plugin"
	ConfigSK         = ConfigPrefix + "sk"
	ConfigAK         = ConfigPrefix + "ak"
	ConfigRegion     = ConfigPrefix + "region"
	ConfigUserID     = ConfigPrefix + "user_id"
)

View File

@@ -1,6 +1,7 @@
 package ctxkey

 const (
+	Config = "config"
 	Id       = "id"
 	Username = "username"
 	Role     = "role"
@@ -18,4 +19,5 @@ const (
 	TokenName       = "token_name"
 	BaseURL         = "base_url"
 	AvailableModels = "available_models"
+	KeyRequestBody  = "key_request_body"
 )

View File

@@ -4,14 +4,13 @@ import (
 	"bytes"
 	"encoding/json"
 	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common/ctxkey"
 	"io"
 	"strings"
 )

-const KeyRequestBody = "key_request_body"
-
 func GetRequestBody(c *gin.Context) ([]byte, error) {
-	requestBody, _ := c.Get(KeyRequestBody)
+	requestBody, _ := c.Get(ctxkey.KeyRequestBody)
 	if requestBody != nil {
 		return requestBody.([]byte), nil
 	}
@@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
 		return nil, err
 	}
 	_ = c.Request.Body.Close()
-	c.Set(KeyRequestBody, requestBody)
+	c.Set(ctxkey.KeyRequestBody, requestBody)
 	return requestBody.([]byte), nil
 }
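
The point of caching the body under `ctxkey.KeyRequestBody` is that `c.Request.Body` is a one-shot reader: the first `GetRequestBody` call drains and stores it, and every later call is served from the context. A hypothetical middleware sketch, assuming `GetRequestBody` lives in package `common` as the surrounding paths suggest:

```go
package main

import (
	"log"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
)

// logBody is a hypothetical middleware: it reads the request body for
// logging, yet downstream handlers can still read the same bytes.
func logBody(c *gin.Context) {
	body, err := common.GetRequestBody(c) // first call: drains Body, caches bytes
	if err == nil {
		log.Printf("request body: %s", body)
	}
	c.Next() // later GetRequestBody calls hit the ctxkey.KeyRequestBody cache
}

func main() {
	r := gin.Default()
	r.Use(logBody)
	_ = r.Run(":3000")
}
```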

View File

@@ -2,6 +2,7 @@ package helper

 import (
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common/random"
 	"html/template"
 	"log"
@@ -105,6 +106,11 @@ func GenRequestID() string {
 	return GetTimeString() + random.GetRandomNumberString(8)
 }

+func GetResponseID(c *gin.Context) string {
+	logID := c.GetString(RequestIdKey)
+	return fmt.Sprintf("chatcmpl-%s", logID)
+}
+
 func Max(a int, b int) int {
 	if a >= b {
 		return a

5
common/helper/key.go Normal file
View File

@@ -0,0 +1,5 @@
package helper
const (
	RequestIdKey = "X-Oneapi-Request-Id"
)

View File

@@ -3,6 +3,7 @@ package image

 import (
 	"bytes"
 	"encoding/base64"
+	"github.com/songquanpeng/one-api/common/client"
 	"image"
 	_ "image/gif"
 	_ "image/jpeg"
@@ -19,7 +20,7 @@ import (
 var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)

 func IsImageUrl(url string) (bool, error) {
-	resp, err := http.Head(url)
+	resp, err := client.UserContentRequestHTTPClient.Head(url)
 	if err != nil {
 		return false, err
 	}
@@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
 	if !isImage {
 		return
 	}
-	resp, err := http.Get(url)
+	resp, err := client.UserContentRequestHTTPClient.Get(url)
 	if err != nil {
 		return
 	}

View File

@@ -2,6 +2,7 @@ package image_test

 import (
 	"encoding/base64"
+	"github.com/songquanpeng/one-api/common/client"
 	"image"
 	_ "image/gif"
 	_ "image/jpeg"
@@ -44,6 +45,11 @@ var (
 	}
 )

+func TestMain(m *testing.M) {
+	client.Init()
+	m.Run()
+}
+
 func TestDecode(t *testing.T) {
 	// Bytes read: varies sometimes
 	// jpeg: 1063892

View File

@@ -24,7 +24,7 @@ func printHelp() {
 	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
 }

-func init() {
+func Init() {
 	flag.Parse()

 	if *PrintVersion {

View File

@@ -1,7 +1,3 @@
 package logger

-const (
-	RequestIdKey = "X-Oneapi-Request-Id"
-)
-
 var LogDir string

View File

@@ -27,7 +27,12 @@ var setupLogOnce sync.Once
 func SetupLogger() {
 	setupLogOnce.Do(func() {
 		if LogDir != "" {
-			logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+			var logPath string
+			if config.OnlyOneLogFile {
+				logPath = filepath.Join(LogDir, "oneapi.log")
+			} else {
+				logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+			}
 			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 			if err != nil {
 				log.Fatal("failed to open log file")
@@ -43,11 +48,19 @@ func SysLog(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }

+func SysLogf(format string, a ...any) {
+	SysLog(fmt.Sprintf(format, a...))
+}
+
 func SysError(s string) {
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }

+func SysErrorf(format string, a ...any) {
+	SysError(fmt.Sprintf(format, a...))
+}
+
 func Debug(ctx context.Context, msg string) {
 	if config.DebugEnabled {
 		logHelper(ctx, loggerDEBUG, msg)
@@ -87,7 +100,7 @@ func logHelper(ctx context.Context, level string, msg string) {
 	if level == loggerINFO {
 		writer = gin.DefaultWriter
 	}
-	id := ctx.Value(RequestIdKey)
+	id := ctx.Value(helper.RequestIdKey)
 	if id == nil {
 		id = helper.GenRequestID()
 	}
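
The log-path selection added to `SetupLogger` reduces to a small pure function; a sketch (the helper name is mine, the two path formats are taken from the diff above, with `ONLY_ONE_LOG_FILE` feeding `config.OnlyOneLogFile`):

```go
package main

import (
	"fmt"
	"path/filepath"
	"time"
)

// pickLogPath mirrors the branch added to SetupLogger: one stable file
// when ONLY_ONE_LOG_FILE is enabled, otherwise a per-day file.
func pickLogPath(logDir string, onlyOneLogFile bool) string {
	if onlyOneLogFile {
		return filepath.Join(logDir, "oneapi.log")
	}
	return filepath.Join(logDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
}

func main() {
	fmt.Println(pickLogPath("./logs", true))  // ./logs/oneapi.log
	fmt.Println(pickLogPath("./logs", false)) // ./logs/oneapi-YYYYMMDD.log
}
```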

View File

@@ -6,11 +6,16 @@ import (
 	"encoding/base64"
 	"fmt"
 	"github.com/songquanpeng/one-api/common/config"
+	"net"
 	"net/smtp"
 	"strings"
 	"time"
 )

+func shouldAuth() bool {
+	return config.SMTPAccount != "" || config.SMTPToken != ""
+}
+
 func SendEmail(subject string, receiver string, content string) error {
 	if receiver == "" {
 		return fmt.Errorf("receiver is empty")
@@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error {
 		"Date: %s\r\n"+
 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
 	to := strings.Split(receiver, ";")
-	if config.SMTPPort == 465 {
-		tlsConfig := &tls.Config{
-			InsecureSkipVerify: true,
-			ServerName:         config.SMTPServer,
-		}
-		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
+	if config.SMTPPort == 465 || !shouldAuth() {
+		// need advanced client
+		var conn net.Conn
+		var err error
+		if config.SMTPPort == 465 {
+			tlsConfig := &tls.Config{
+				InsecureSkipVerify: true,
+				ServerName:         config.SMTPServer,
+			}
+			conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
+		} else {
+			conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
+		}
 		if err != nil {
 			return err
 		}
@@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error {
 		return err
 	}
 	defer client.Close()
-	if err = client.Auth(auth); err != nil {
-		return err
+	if shouldAuth() {
+		if err = client.Auth(auth); err != nil {
+			return err
+		}
 	}
 	if err = client.Mail(config.SMTPFrom); err != nil {
 		return err
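
The behavioral change in one line: `AUTH` is only sent when credentials are configured. A self-contained sketch of the predicate with assumed example values:

```go
package main

import "fmt"

// shouldAuth mirrors the predicate added above: authenticate only when an
// SMTP account or token is configured.
func shouldAuth(account, token string) bool {
	return account != "" || token != ""
}

func main() {
	// An internal, unauthenticated relay (e.g. plain port 25):
	fmt.Println(shouldAuth("", "")) // false -> the AUTH command is skipped
	// A typical authenticated provider:
	fmt.Println(shouldAuth("user@example.com", "app-token")) // true
}
```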

29
common/render/render.go Normal file
View File

@@ -0,0 +1,29 @@
package render
import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
)

func StringData(c *gin.Context, str string) {
	str = strings.TrimPrefix(str, "data: ")
	str = strings.TrimSuffix(str, "\r")
	c.Render(-1, common.CustomEvent{Data: "data: " + str})
	c.Writer.Flush()
}

func ObjectData(c *gin.Context, object interface{}) error {
	jsonData, err := json.Marshal(object)
	if err != nil {
		return fmt.Errorf("error marshalling object: %w", err)
	}
	StringData(c, string(jsonData))
	return nil
}

func Done(c *gin.Context) {
	StringData(c, "[DONE]")
}
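
A sketch of how these helpers compose in a streaming handler; the route, payload shape, and handler name are hypothetical, while `ObjectData` and `Done` behave as defined above (this is also the combination the "fix: duplicate [DONE]" commit polices, since `Done` must be emitted exactly once):

```go
package main

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/render"
)

// streamDemo is a hypothetical SSE handler: ObjectData marshals each chunk
// into a "data: {...}" event, and Done emits the trailing "data: [DONE]".
func streamDemo(c *gin.Context) {
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	for _, chunk := range []map[string]string{{"content": "Hello"}, {"content": " world"}} {
		if err := render.ObjectData(c, chunk); err != nil {
			logger.SysError(err.Error())
			break
		}
	}
	render.Done(c)
}

func main() {
	r := gin.Default()
	r.GET("/demo/stream", streamDemo)
	_ = r.Run(":3000")
}
```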

View File

@@ -4,12 +4,12 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/songquanpeng/one-api/common/client"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/channeltype"
-	"github.com/songquanpeng/one-api/relay/client"
 	"io"
 	"net/http"
 	"strconv"

View File

@@ -5,6 +5,16 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/logger"
@@ -18,23 +28,15 @@ import (
 	"github.com/songquanpeng/one-api/relay/meta"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/relaymode"
-	"io"
-	"net/http"
-	"net/http/httptest"
-	"net/url"
-	"strconv"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/gin-gonic/gin"
 )

-func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
+func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
+	if model == "" {
+		model = "gpt-3.5-turbo"
+	}
 	testRequest := &relaymodel.GeneralOpenAIRequest{
 		MaxTokens: 2,
-		Stream:    false,
-		Model:     "gpt-3.5-turbo",
+		Model:     model,
 	}
 	testMessage := relaymodel.Message{
 		Role:    "user",
@@ -44,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
 	return testRequest
 }

-func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
+func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
 	c.Request = &http.Request{
@@ -57,6 +59,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	c.Request.Header.Set("Content-Type", "application/json")
 	c.Set(ctxkey.Channel, channel.Type)
 	c.Set(ctxkey.BaseURL, channel.GetBaseURL())
+	cfg, _ := channel.LoadConfig()
+	c.Set(ctxkey.Config, cfg)
 	middleware.SetupContextForSelectedChannel(c, channel, "")
 	meta := meta.GetByContext(c)
 	apiType := channeltype.ToAPIType(channel.Type)
@@ -65,20 +69,19 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}
 	adaptor.Init(meta)
-	var modelName string
-	modelList := adaptor.GetModelList()
-	if len(modelList) != 0 {
-		modelName = modelList[0]
-	}
+	modelName := request.Model
+	modelMap := channel.GetModelMapping()
 	if modelName == "" || !strings.Contains(channel.Models, modelName) {
 		modelNames := strings.Split(channel.Models, ",")
 		if len(modelNames) > 0 {
 			modelName = modelNames[0]
 		}
+		if modelMap != nil && modelMap[modelName] != "" {
+			modelName = modelMap[modelName]
+		}
 	}
-	request := buildTestRequest()
+	meta.OriginModelName, meta.ActualModelName = request.Model, modelName
 	request.Model = modelName
-	meta.OriginModelName, meta.ActualModelName = modelName, modelName
 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
 	if err != nil {
 		return err, nil
@@ -132,10 +135,15 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
+	model := c.Query("model")
+	testRequest := buildTestRequest(model)
 	tik := time.Now()
-	err, _ = testChannel(channel)
+	err, _ = testChannel(channel, testRequest)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
+	if err != nil {
+		milliseconds = 0
+	}
 	go channel.UpdateResponseTime(milliseconds)
 	consumedTime := float64(milliseconds) / 1000.0
 	if err != nil {
@@ -143,6 +151,7 @@ func TestChannel(c *gin.Context) {
 			"success": false,
 			"message": err.Error(),
 			"time":    consumedTime,
+			"model":   model,
 		})
 		return
 	}
@@ -150,6 +159,7 @@ func TestChannel(c *gin.Context) {
 		"success": true,
 		"message": "",
 		"time":    consumedTime,
+		"model":   model,
 	})
 	return
 }
@@ -180,11 +190,12 @@ func testChannels(notify bool, scope string) error {
 	for _, channel := range channels {
 		isChannelEnabled := channel.Status == model.ChannelStatusEnabled
 		tik := time.Now()
-		err, openaiErr := testChannel(channel)
+		testRequest := buildTestRequest("")
+		err, openaiErr := testChannel(channel, testRequest)
 		tok := time.Now()
 		milliseconds := tok.Sub(tik).Milliseconds()
 		if isChannelEnabled && milliseconds > disableThreshold {
-			err = errors.New(fmt.Sprintf("response time %.2fs exceeds threshold %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+			err = fmt.Errorf("response time %.2fs exceeds threshold %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
 			if config.AutomaticDisableChannelEnabled {
 				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 			} else {
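
With `TestChannel` now reading `c.Query("model")`, a specific model can be exercised per request. A client-side sketch; the route and token are assumptions for illustration, not confirmed by this diff:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed admin channel-test endpoint with the new optional ?model=
	// query; when omitted, the handler falls back to the channel's models.
	url := "http://localhost:3000/api/channel/test/1?model=gpt-4o"
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer <admin_access_token>") // hypothetical token
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	// Expected shape per the handler: {"success":...,"message":...,"time":...,"model":"gpt-4o"}
	fmt.Println(string(body))
}
```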

View File

@@ -4,6 +4,9 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"io"
+	"net/http"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
@@ -16,8 +19,6 @@ import (
 	"github.com/songquanpeng/one-api/relay/controller"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/relaymode"
-	"io"
-	"net/http"
 )

 // https://platform.openai.com/docs/api-reference/chat
@@ -47,6 +48,7 @@ func Relay(c *gin.Context) {
 		logger.Debugf(ctx, "request body: %s", string(requestBody))
 	}
 	channelId := c.GetInt(ctxkey.ChannelId)
+	userId := c.GetInt(ctxkey.Id)
 	bizErr := relayHelper(c, relayMode)
 	if bizErr == nil {
 		monitor.Emit(channelId, true)
@@ -56,8 +58,8 @@ func Relay(c *gin.Context) {
 	channelName := c.GetString(ctxkey.ChannelName)
 	group := c.GetString(ctxkey.Group)
 	originalModel := c.GetString(ctxkey.OriginalModel)
-	go processChannelRelayError(ctx, channelId, channelName, bizErr)
-	requestId := c.GetString(logger.RequestIdKey)
+	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+	requestId := c.GetString(helper.RequestIdKey)
 	retryTimes := config.RetryTimes
 	if !shouldRetry(c, bizErr.StatusCode) {
 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
@@ -83,7 +85,7 @@ func Relay(c *gin.Context) {
 		channelId := c.GetInt(ctxkey.ChannelId)
 		lastFailedChannelId = channelId
 		channelName := c.GetString(ctxkey.ChannelName)
-		go processChannelRelayError(ctx, channelId, channelName, bizErr)
+		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
 	}
 	if bizErr != nil {
 		if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -115,8 +117,8 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
 	return true
 }

-func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
-	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
+func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
 	// https://platform.openai.com/docs/guides/error-codes/api-errors
 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
 		monitor.DisableChannel(channelId, channelName, err.Message)

View File

@@ -173,6 +173,7 @@ func Register(c *gin.Context) {
 		})
 		return
 	}

 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",

51
go.mod
View File

@@ -4,42 +4,43 @@ module github.com/songquanpeng/one-api
 go 1.20

 require (
-	github.com/aws/aws-sdk-go-v2 v1.26.1
-	github.com/aws/aws-sdk-go-v2/credentials v1.17.11
-	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
-	github.com/gin-contrib/cors v1.7.1
-	github.com/gin-contrib/gzip v1.0.0
-	github.com/gin-contrib/sessions v1.0.0
-	github.com/gin-contrib/static v1.1.1
-	github.com/gin-gonic/gin v1.9.1
-	github.com/go-playground/validator/v10 v10.19.0
+	github.com/aws/aws-sdk-go-v2 v1.27.0
+	github.com/aws/aws-sdk-go-v2/credentials v1.17.15
+	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
+	github.com/gin-contrib/cors v1.7.2
+	github.com/gin-contrib/gzip v1.0.1
+	github.com/gin-contrib/sessions v1.0.1
+	github.com/gin-contrib/static v1.1.2
+	github.com/gin-gonic/gin v1.10.0
+	github.com/go-playground/validator/v10 v10.20.0
 	github.com/go-redis/redis/v8 v8.11.5
 	github.com/golang-jwt/jwt v3.2.2+incompatible
 	github.com/google/uuid v1.6.0
 	github.com/gorilla/websocket v1.5.1
 	github.com/jinzhu/copier v0.4.0
+	github.com/joho/godotenv v1.5.1
 	github.com/pkg/errors v0.9.1
-	github.com/pkoukk/tiktoken-go v0.1.6
+	github.com/pkoukk/tiktoken-go v0.1.7
 	github.com/smartystreets/goconvey v1.8.1
 	github.com/stretchr/testify v1.9.0
-	golang.org/x/crypto v0.22.0
-	golang.org/x/image v0.15.0
+	golang.org/x/crypto v0.23.0
+	golang.org/x/image v0.18.0
 	gorm.io/driver/mysql v1.5.6
 	gorm.io/driver/postgres v1.5.7
 	gorm.io/driver/sqlite v1.5.5
-	gorm.io/gorm v1.25.9
+	gorm.io/gorm v1.25.10
 )

 require (
 	filippo.io/edwards25519 v1.1.0 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect
 	github.com/aws/smithy-go v1.20.2 // indirect
-	github.com/bytedance/sonic v1.11.5 // indirect
+	github.com/bytedance/sonic v1.11.6 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
-	github.com/cloudwego/base64x v0.1.3 // indirect
+	github.com/cloudwego/base64x v0.1.4 // indirect
 	github.com/cloudwego/iasm v0.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
@@ -50,7 +51,7 @@ require (
 	github.com/go-playground/locales v0.14.1 // indirect
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/go-sql-driver/mysql v1.8.1 // indirect
-	github.com/goccy/go-json v0.10.2 // indirect
+	github.com/goccy/go-json v0.10.3 // indirect
 	github.com/gopherjs/gopherjs v1.17.2 // indirect
 	github.com/gorilla/context v1.1.2 // indirect
 	github.com/gorilla/securecookie v1.1.2 // indirect
@@ -67,19 +68,19 @@ require (
 	github.com/kr/text v0.2.0 // indirect
 	github.com/leodido/go-urn v1.4.0 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect
-	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
+	github.com/mattn/go-sqlite3 v1.14.22 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
-	github.com/pelletier/go-toml/v2 v2.2.1 // indirect
+	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/smarty/assertions v1.15.0 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
-	golang.org/x/arch v0.7.0 // indirect
-	golang.org/x/net v0.24.0 // indirect
+	golang.org/x/arch v0.8.0 // indirect
+	golang.org/x/net v0.25.0 // indirect
 	golang.org/x/sync v0.7.0 // indirect
-	golang.org/x/sys v0.19.0 // indirect
-	golang.org/x/text v0.14.0 // indirect
-	google.golang.org/protobuf v1.33.0 // indirect
+	golang.org/x/sys v0.20.0 // indirect
+	golang.org/x/text v0.16.0 // indirect
+	google.golang.org/protobuf v1.34.1 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

go.sum

@@ -1,28 +1,27 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
@@ -37,32 +36,32 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs= github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps= github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk= github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE=
github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE= github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4=
github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y= github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI=
github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI= github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k= github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4=
github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74= github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
@@ -94,6 +93,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
@@ -109,8 +110,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -119,12 +120,12 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
@@ -149,25 +150,25 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
@@ -182,7 +183,7 @@ gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4c
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8= gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

main.go

@@ -6,7 +6,9 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
_ "github.com/joho/godotenv/autoload"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
@@ -22,29 +24,22 @@ import (
var buildFS embed.FS var buildFS embed.FS
func main() {
	logger.SetupLogger()
	logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
	if os.Getenv("GIN_MODE") != "debug" {
		gin.SetMode(gin.ReleaseMode)
	}
	if config.DebugEnabled {
		logger.SysLog("running in debug mode")
	}
	var err error
	// Initialize SQL Database
	model.DB, err = model.InitDB("SQL_DSN")
	if err != nil {
		logger.FatalLog("failed to initialize database: " + err.Error())
	}
	if os.Getenv("LOG_SQL_DSN") != "" {
		logger.SysLog("using secondary database for table logs")
		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
		if err != nil {
			logger.FatalLog("failed to initialize secondary database: " + err.Error())
		}
	} else {
		model.LOG_DB = model.DB
	}
	err = model.CreateRootAccountIfNeed()
	if err != nil {
		logger.FatalLog("database init error: " + err.Error())

func main() {
	common.Init()
	logger.SetupLogger()
	logger.SysLogf("One API %s started", common.Version)

	if os.Getenv("GIN_MODE") != gin.DebugMode {
		gin.SetMode(gin.ReleaseMode)
	}
	if config.DebugEnabled {
		logger.SysLog("running in debug mode")
	}

	// Initialize SQL Database
	model.InitDB()
	model.InitLogDB()

	var err error
	err = model.CreateRootAccountIfNeed()
	if err != nil {
		logger.FatalLog("database init error: " + err.Error())
@@ -94,6 +89,7 @@ func main() {
logger.SysLog("metric enabled, will disable channel if too much request failed") logger.SysLog("metric enabled, will disable channel if too much request failed")
} }
openai.InitTokenEncoders() openai.InitTokenEncoders()
client.Init()
// Initialize HTTP server // Initialize HTTP server
server := gin.New() server := gin.New()
@@ -111,6 +107,7 @@ func main() {
if port == "" { if port == "" {
port = strconv.Itoa(*common.Port) port = strconv.Itoa(*common.Port)
} }
logger.SysLogf("server started on http://localhost:%s", port)
err = server.Run(":" + port) err = server.Run(":" + port)
if err != nil { if err != nil {
logger.FatalLog("failed to start HTTP server: " + err.Error()) logger.FatalLog("failed to start HTTP server: " + err.Error())


@@ -65,21 +65,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.OriginalModel, modelName) // for retry c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set(ctxkey.BaseURL, channel.GetBaseURL()) c.Set(ctxkey.BaseURL, channel.GetBaseURL())
	// this is for backward compatibility
	switch channel.Type {
	case channeltype.Azure:
		c.Set(ctxkey.ConfigAPIVersion, channel.Other)
	case channeltype.Xunfei:
		c.Set(ctxkey.ConfigAPIVersion, channel.Other)
	case channeltype.Gemini:
		c.Set(ctxkey.ConfigAPIVersion, channel.Other)
	case channeltype.AIProxyLibrary:
		c.Set(ctxkey.ConfigLibraryID, channel.Other)
	case channeltype.Ali:
		c.Set(ctxkey.ConfigPlugin, channel.Other)
	}
	cfg, _ := channel.LoadConfig()
	for k, v := range cfg {
		c.Set(ctxkey.ConfigPrefix+k, v)
	}

	cfg, _ := channel.LoadConfig()
	// this is for backward compatibility
	if channel.Other != nil {
		switch channel.Type {
		case channeltype.Azure:
			if cfg.APIVersion == "" {
				cfg.APIVersion = *channel.Other
			}
		case channeltype.Xunfei:
			if cfg.APIVersion == "" {
				cfg.APIVersion = *channel.Other
			}
		case channeltype.Gemini:
			if cfg.APIVersion == "" {
				cfg.APIVersion = *channel.Other
			}
		case channeltype.AIProxyLibrary:
			if cfg.LibraryID == "" {
				cfg.LibraryID = *channel.Other
			}
		case channeltype.Ali:
			if cfg.Plugin == "" {
				cfg.Plugin = *channel.Other
			}
		}
	}
	c.Set(ctxkey.Config, cfg)
} }
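In other words, the typed config parsed from channel.Config now wins, and the deprecated Other column only back-fills fields that are still empty. A hedged sketch of the precedence rule for one case (Azure-style API version):

	// Sketch only; cfg comes from channel.LoadConfig() as in the diff above.
	cfg, _ := channel.LoadConfig()
	if channel.Other != nil && cfg.APIVersion == "" {
		cfg.APIVersion = *channel.Other // legacy value used only as a fallback
	}
	c.Set(ctxkey.Config, cfg)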


@@ -3,14 +3,14 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/helper"
) )
func SetUpLogger(server *gin.Engine) { func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string var requestID string
if param.Keys != nil { if param.Keys != nil {
requestID = param.Keys[logger.RequestIdKey].(string) requestID = param.Keys[helper.RequestIdKey].(string)
} }
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"), param.TimeStamp.Format("2006/01/02 - 15:04:05"),


@@ -4,16 +4,15 @@ import (
"context" "context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
) )
func RequestId() func(c *gin.Context) { func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := helper.GenRequestID() id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id) c.Set(helper.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(logger.RequestIdKey, id) c.Header(helper.RequestIdKey, id)
c.Next() c.Next()
} }
} }
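With the key moved from the logger package to helper, downstream code reads the ID back through the same constant; a short sketch of both access paths (gin context and request context):

	// Inside any handler running after RequestId():
	requestID := c.GetString(helper.RequestIdKey)                        // from the gin context
	ctxID, _ := c.Request.Context().Value(helper.RequestIdKey).(string) // from the request context
	_ = requestID
	_ = ctxID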


@@ -12,7 +12,7 @@ import (
func abortWithMessage(c *gin.Context, statusCode int, message string) { func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), "message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)),
"type": "one_api_error", "type": "one_api_error",
}, },
}) })


@@ -27,7 +27,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` // DEPRECATED: please save config to field Config Other *string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`
@@ -38,6 +38,16 @@ type Channel struct {
Config string `json:"config"` Config string `json:"config"`
} }
type ChannelConfig struct {
Region string `json:"region,omitempty"`
SK string `json:"sk,omitempty"`
AK string `json:"ak,omitempty"`
UserID string `json:"user_id,omitempty"`
APIVersion string `json:"api_version,omitempty"`
LibraryID string `json:"library_id,omitempty"`
Plugin string `json:"plugin,omitempty"`
}
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
var channels []*Channel var channels []*Channel
var err error var err error
@@ -161,14 +171,14 @@ func (channel *Channel) Delete() error {
return err return err
} }
func (channel *Channel) LoadConfig() (map[string]string, error) {
	if channel.Config == "" {
		return nil, nil
	}
	cfg := make(map[string]string)
	err := json.Unmarshal([]byte(channel.Config), &cfg)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}

func (channel *Channel) LoadConfig() (ChannelConfig, error) {
	var cfg ChannelConfig
	if channel.Config == "" {
		return cfg, nil
	}
	err := json.Unmarshal([]byte(channel.Config), &cfg)
	if err != nil {
		return cfg, err
	}
	return cfg, nil
}
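The Config column therefore holds a JSON object whose keys match the struct tags on ChannelConfig above; a hypothetical stored value and its typed result:

	// Hypothetical Config value for an AWS-style channel; assumes encoding/json
	// and fmt are imported in the same package that defines ChannelConfig.
	raw := `{"region":"us-east-1","ak":"AKIAEXAMPLE","sk":"secret","api_version":"2024-02-01"}`
	var cfg ChannelConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err) // illustration only
	}
	fmt.Println(cfg.Region, cfg.APIVersion) // us-east-1 2024-02-01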


@@ -1,6 +1,7 @@
package model package model
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
@@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error {
if err != nil { if err != nil {
return err return err
} }
accessToken := random.GetUUID()
if config.InitialRootAccessToken != "" {
accessToken = config.InitialRootAccessToken
}
rootUser := User{ rootUser := User{
Username: "root", Username: "root",
Password: hashedPassword, Password: hashedPassword,
Role: RoleRootUser, Role: RoleRootUser,
Status: UserStatusEnabled, Status: UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: random.GetUUID(), AccessToken: accessToken,
Quota: 500000000000000, Quota: 500000000000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@@ -60,90 +65,156 @@ func CreateRootAccountIfNeed() error {
} }
func chooseDB(envName string) (*gorm.DB, error) {
	if os.Getenv(envName) != "" {
		dsn := os.Getenv(envName)
		if strings.HasPrefix(dsn, "postgres://") {
			// Use PostgreSQL
			logger.SysLog("using PostgreSQL as database")
			common.UsingPostgreSQL = true
			return gorm.Open(postgres.New(postgres.Config{
				DSN:                  dsn,
				PreferSimpleProtocol: true, // disables implicit prepared statement usage
			}), &gorm.Config{
				PrepareStmt: true, // precompile SQL
			})
		}
		// Use MySQL
		logger.SysLog("using MySQL as database")
		common.UsingMySQL = true
		return gorm.Open(mysql.Open(dsn), &gorm.Config{
			PrepareStmt: true, // precompile SQL
		})
	}
	// Use SQLite
	logger.SysLog("SQL_DSN not set, using SQLite as database")
	common.UsingSQLite = true
	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
		PrepareStmt: true, // precompile SQL
	})
}

func InitDB(envName string) (db *gorm.DB, err error) {
	db, err = chooseDB(envName)
	if err == nil {
		if config.DebugSQLEnabled {
			db = db.Debug()
		}
		sqlDB, err := db.DB()
		if err != nil {
			return nil, err
		}
		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
		if !config.IsMasterNode {
			return db, err
		}
		if common.UsingMySQL {
			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
		}
		logger.SysLog("database migration started")
		err = db.AutoMigrate(&Channel{})
		if err != nil {
			return nil, err
		}
		err = db.AutoMigrate(&Token{})
		if err != nil {
			return nil, err
		}
		err = db.AutoMigrate(&User{})
		if err != nil {
			return nil, err
		}
		err = db.AutoMigrate(&Option{})
		if err != nil {
			return nil, err
		}
		err = db.AutoMigrate(&Redemption{})
		if err != nil {
			return nil, err
		}
		err = db.AutoMigrate(&Ability{})
		if err != nil {
			return nil, err
		}
		err = db.AutoMigrate(&Log{})
		if err != nil {
			return nil, err
		}
		logger.SysLog("database migrated")
		return db, err
	} else {
		logger.FatalLog(err)
	}
	return db, err
}

func chooseDB(envName string) (*gorm.DB, error) {
	dsn := os.Getenv(envName)

	switch {
	case strings.HasPrefix(dsn, "postgres://"):
		// Use PostgreSQL
		return openPostgreSQL(dsn)
	case dsn != "":
		// Use MySQL
		return openMySQL(dsn)
	default:
		// Use SQLite
		return openSQLite()
	}
}

func openPostgreSQL(dsn string) (*gorm.DB, error) {
	logger.SysLog("using PostgreSQL as database")
	common.UsingPostgreSQL = true
	return gorm.Open(postgres.New(postgres.Config{
		DSN:                  dsn,
		PreferSimpleProtocol: true, // disables implicit prepared statement usage
	}), &gorm.Config{
		PrepareStmt: true, // precompile SQL
	})
}

func openMySQL(dsn string) (*gorm.DB, error) {
	logger.SysLog("using MySQL as database")
	common.UsingMySQL = true
	return gorm.Open(mysql.Open(dsn), &gorm.Config{
		PrepareStmt: true, // precompile SQL
	})
}

func openSQLite() (*gorm.DB, error) {
	logger.SysLog("SQL_DSN not set, using SQLite as database")
	common.UsingSQLite = true
	dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
	return gorm.Open(sqlite.Open(dsn), &gorm.Config{
		PrepareStmt: true, // precompile SQL
	})
}

func InitDB() {
	var err error
	DB, err = chooseDB("SQL_DSN")
	if err != nil {
		logger.FatalLog("failed to initialize database: " + err.Error())
		return
	}

	sqlDB := setDBConns(DB)
	if !config.IsMasterNode {
		return
	}
	if common.UsingMySQL {
		_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
	}
	logger.SysLog("database migration started")
	if err = migrateDB(); err != nil {
		logger.FatalLog("failed to migrate database: " + err.Error())
		return
	}
	logger.SysLog("database migrated")
}
func migrateDB() error {
var err error
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Token{}); err != nil {
return err
}
if err = DB.AutoMigrate(&User{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Option{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Redemption{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Ability{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Log{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
return nil
}
func InitLogDB() {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
logger.SysLog("using secondary database for table logs")
var err error
LOG_DB, err = chooseDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
return
}
setDBConns(LOG_DB)
if !config.IsMasterNode {
return
}
logger.SysLog("secondary database migration started")
err = migrateLOGDB()
if err != nil {
logger.FatalLog("failed to migrate secondary database: " + err.Error())
return
}
logger.SysLog("secondary database migrated")
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
logger.FatalLog("failed to connect database: " + err.Error())
return nil
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
return sqlDB
} }
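setDBConns is also where the pool knobs live: the defaults above (100 idle, 1000 open, 60s lifetime) can be overridden through the SQL_MAX_IDLE_CONNS, SQL_MAX_OPEN_CONNS and SQL_MAX_LIFETIME variables read by env.Int. The helper itself is not shown in this diff; a plausible shape, stated as an assumption:

	// Assumed shape of the env.Int helper used above; not part of this diff.
	func Int(name string, defaultValue int) int {
		v := os.Getenv(name)
		if v == "" {
			return defaultValue
		}
		n, err := strconv.Atoi(v)
		if err != nil {
			return defaultValue
		}
		return n
	}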
func closeDB(db *gorm.DB) error { func closeDB(db *gorm.DB) error {


@@ -6,6 +6,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm" "gorm.io/gorm"
@@ -140,6 +141,22 @@ func (user *User) Insert(inviterId int) error {
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
// create default token
cleanToken := Token{
UserId: user.Id,
Name: "default",
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: -1,
UnlimitedQuota: true,
}
result.Error = cleanToken.Insert()
if result.Error != nil {
// do not block
logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
}
return nil return nil
} }


@@ -7,8 +7,10 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws" "github.com/songquanpeng/one-api/relay/adaptor/aws"
"github.com/songquanpeng/one-api/relay/adaptor/baidu" "github.com/songquanpeng/one-api/relay/adaptor/baidu"
"github.com/songquanpeng/one-api/relay/adaptor/cloudflare"
"github.com/songquanpeng/one-api/relay/adaptor/cohere" "github.com/songquanpeng/one-api/relay/adaptor/cohere"
"github.com/songquanpeng/one-api/relay/adaptor/coze" "github.com/songquanpeng/one-api/relay/adaptor/coze"
"github.com/songquanpeng/one-api/relay/adaptor/deepl"
"github.com/songquanpeng/one-api/relay/adaptor/gemini" "github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama" "github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
@@ -49,6 +51,10 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &coze.Adaptor{} return &coze.Adaptor{}
case apitype.Cohere: case apitype.Cohere:
return &cohere.Adaptor{} return &cohere.Adaptor{}
case apitype.Cloudflare:
return &cloudflare.Adaptor{}
case apitype.DeepL:
return &deepl.Adaptor{}
} }
return nil return nil
} }
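Callers dispatch purely on the apitype constant, so wiring in a new provider is the two lines above plus its adaptor package. A hedged usage sketch (error handling illustrative, not taken from the diff):

	a := GetAdaptor(apitype.Cloudflare)
	if a == nil {
		return nil, fmt.Errorf("invalid api type: %d", apitype.Cloudflare)
	}
	a.Init(meta) // meta *meta.Meta carries the channel config, as elsewhere in this diff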


@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
@@ -13,10 +12,11 @@ import (
) )
type Adaptor struct { type Adaptor struct {
meta *meta.Meta
} }
func (a *Adaptor) Init(meta *meta.Meta) { func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
} }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
@@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
aiProxyLibraryRequest := ConvertRequest(*request) aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(ctxkey.ConfigLibraryID) aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID
return aiProxyLibraryRequest, nil return aiProxyLibraryRequest, nil
} }


@@ -4,6 +4,12 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@@ -12,10 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
) )
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
@@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage var usage model.Usage
var documents []LibraryDocument
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
	common.SetEventStreamHeaders(c)

	var documents []LibraryDocument
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var AIProxyLibraryResponse LibraryStreamResponse
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if len(AIProxyLibraryResponse.Documents) != 0 {
				documents = AIProxyLibraryResponse.Documents
			}
			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			response := documentsAIProxyLibrary(documents)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < 5 || data[:5] != "data:" {
			continue
		}
		data = data[5:]

		var AIProxyLibraryResponse LibraryStreamResponse
		err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}
		if len(AIProxyLibraryResponse.Documents) != 0 {
			documents = AIProxyLibraryResponse.Documents
		}
		response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	response := documentsAIProxyLibrary(documents)
	err := render.ObjectData(c, response)
	if err != nil {
		logger.SysError(err.Error())
	}
	render.Done(c)

	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &usage
}
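The goroutine-plus-channel pump is gone: the handler now writes SSE frames inline from the scanner loop via the new common/render helpers, and emits exactly one [DONE] at the end (the duplicate-[DONE] fix in this range). The helpers' bodies are not shown in this diff; a plausible reconstruction, inferred only from the call sites and the old c.Render pattern:

	// Assumed shape of the render helpers; hedged, not taken from this diff.
	package render

	import (
		"encoding/json"
		"fmt"

		"github.com/gin-gonic/gin"
		"github.com/songquanpeng/one-api/common"
	)

	// ObjectData marshals v and writes it as one SSE "data:" frame.
	func ObjectData(c *gin.Context, v any) error {
		b, err := json.Marshal(v)
		if err != nil {
			return fmt.Errorf("error marshalling object: %w", err)
		}
		c.Render(-1, common.CustomEvent{Data: "data: " + string(b)})
		return nil
	}

	// Done terminates the stream the way OpenAI clients expect.
	func Done(c *gin.Context) {
		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
	}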


@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
@@ -16,10 +15,11 @@ import (
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct { type Adaptor struct {
meta *meta.Meta
} }
func (a *Adaptor) Init(meta *meta.Meta) { func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
} }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
@@ -47,8 +47,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
if meta.Mode == relaymode.ImagesGenerations { if meta.Mode == relaymode.ImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable") req.Header.Set("X-DashScope-Async", "enable")
} }
if c.GetString(ctxkey.ConfigPlugin) != "" { if a.meta.Config.Plugin != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(ctxkey.ConfigPlugin)) req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin)
} }
return nil return nil
} }


@@ -3,15 +3,17 @@ package ali
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
	common.SetEventStreamHeaders(c)

	//lastResponseText := ""
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var aliResponse ChatResponse
			err := json.Unmarshal([]byte(data), &aliResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if aliResponse.Usage.OutputTokens != 0 {
				usage.PromptTokens = aliResponse.Usage.InputTokens
				usage.CompletionTokens = aliResponse.Usage.OutputTokens
				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
			}
			response := streamResponseAli2OpenAI(&aliResponse)
			if response == nil {
				return true
			}
			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
			//lastResponseText = aliResponse.Output.Text
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < 5 || data[:5] != "data:" {
			continue
		}
		data = data[5:]

		var aliResponse ChatResponse
		err := json.Unmarshal([]byte(data), &aliResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}
		if aliResponse.Usage.OutputTokens != 0 {
			usage.PromptTokens = aliResponse.Usage.InputTokens
			usage.CompletionTokens = aliResponse.Usage.OutputTokens
			usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
		}
		response := streamResponseAli2OpenAI(&aliResponse)
		if response == nil {
			continue
		}
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil


@@ -5,4 +5,5 @@ var ModelList = []string{
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
} }


@@ -4,6 +4,11 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@@ -11,9 +16,6 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
func stopReasonClaude2OpenAI(reason *string) string { func stopReasonClaude2OpenAI(reason *string) string {
@@ -27,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
return "stop" return "stop"
case "max_tokens": case "max_tokens":
return "length" return "length"
case "tool_use":
return "tool_calls"
default: default:
return *reason return *reason
} }
} }
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := Request{ claudeRequest := Request{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
@@ -40,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
TopP: textRequest.TopP, TopP: textRequest.TopP,
TopK: textRequest.TopK, TopK: textRequest.TopK,
Stream: textRequest.Stream, Stream: textRequest.Stream,
Tools: claudeTools,
}
if len(claudeTools) > 0 {
claudeToolChoice := struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
if function, ok := choice["function"].(map[string]any); ok {
claudeToolChoice.Type = "tool"
claudeToolChoice.Name = function["name"].(string)
}
} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
if toolChoiceType == "any" {
claudeToolChoice.Type = toolChoiceType
}
}
claudeRequest.ToolChoice = claudeToolChoice
} }
if claudeRequest.MaxTokens == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096 claudeRequest.MaxTokens = 4096
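The tool_choice translation above compresses to a small mapping: OpenAI's object form selects one specific tool, the string "any" forces some tool, and everything else falls back to Claude's documented default "auto". Isolated as a function for clarity (hedged; this function does not exist in the diff):

	func mapToolChoice(openaiToolChoice any) (typ string, name string) {
		switch v := openaiToolChoice.(type) {
		case map[string]any:
			if fn, ok := v["function"].(map[string]any); ok {
				if n, ok := fn["name"].(string); ok {
					return "tool", n // force this specific tool
				}
			}
		case string:
			if v == "any" {
				return "any", "" // force some tool; the model picks which
			}
		}
		return "auto", "" // Claude's default per the Anthropic docs linked above
	}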
@@ -62,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
if message.IsStringContent() { if message.IsStringContent() {
content.Type = "text" content.Type = "text"
content.Text = message.StringContent() content.Text = message.StringContent()
if message.Role == "tool" {
claudeMessage.Role = "user"
content.Type = "tool_result"
content.Content = content.Text
content.Text = ""
content.ToolUseId = message.ToolCallId
}
claudeMessage.Content = append(claudeMessage.Content, content) claudeMessage.Content = append(claudeMessage.Content, content)
for i := range message.ToolCalls {
inputParam := make(map[string]any)
_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
claudeMessage.Content = append(claudeMessage.Content, Content{
Type: "tool_use",
Id: message.ToolCalls[i].Id,
Name: message.ToolCalls[i].Function.Name,
Input: inputParam,
})
}
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue continue
} }
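Concretely, an OpenAI-style tool result is re-homed onto a user message, since Claude has no tool role, and assistant-side tool calls become tool_use content blocks. A worked example of the rewrite above (values hypothetical; JSON field names assumed from the Anthropic API):

	OpenAI in:  {"role": "tool", "tool_call_id": "toolu_01A", "content": "18°C, cloudy"}
	Claude out: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_01A", "content": "18°C, cloudy"}]}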
@@ -95,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
var response *Response var response *Response
var responseText string var responseText string
var stopReason string var stopReason string
tools := make([]model.Tool, 0)
switch claudeResponse.Type { switch claudeResponse.Type {
case "message_start": case "message_start":
return nil, claudeResponse.Message return nil, claudeResponse.Message
case "content_block_start": case "content_block_start":
if claudeResponse.ContentBlock != nil { if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text responseText = claudeResponse.ContentBlock.Text
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, model.Tool{
Id: claudeResponse.ContentBlock.Id,
Type: "function",
Function: model.Function{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} }
case "content_block_delta": case "content_block_delta":
if claudeResponse.Delta != nil { if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text responseText = claudeResponse.Delta.Text
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, model.Tool{
Function: model.Function{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
} }
case "message_delta": case "message_delta":
if claudeResponse.Usage != nil { if claudeResponse.Usage != nil {
@@ -118,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
} }
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText choice.Delta.Content = responseText
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
@@ -134,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]model.Tool, 0)
for _, v := range claudeResponse.Content {
if v.Type == "tool_use" {
args, _ := json.Marshal(v.Input)
tools = append(tools, model.Tool{
Id: v.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: model.Function{
Name: v.Name,
Arguments: string(args),
},
})
}
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
ToolCalls: tools,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
@@ -168,64 +261,77 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
continue
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
}
if response == nil {
continue
}
response.Id = id
response.Model = modelName
response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
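The streaming refactor above drops the dataChan/stopChan goroutine in favor of a plain scanner loop that emits each chunk through render.ObjectData and closes with render.Done. The render package itself is not shown in this diff; a minimal sketch of what such helpers presumably look like (names and behavior assumed from the call sites, not from the source):

package render

import (
    "encoding/json"
    "fmt"

    "github.com/gin-gonic/gin"
)

// ObjectData marshals v and writes it as a single SSE "data:" event.
func ObjectData(c *gin.Context, v any) error {
    b, err := json.Marshal(v)
    if err != nil {
        return fmt.Errorf("error marshalling object: %w", err)
    }
    _, err = fmt.Fprintf(c.Writer, "data: %s\n\n", b)
    c.Writer.Flush() // push the event to the client immediately
    return err
}

// Done writes the OpenAI-style stream terminator.
func Done(c *gin.Context) {
    fmt.Fprint(c.Writer, "data: [DONE]\n\n")
    c.Writer.Flush()
}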


@@ -16,6 +16,12 @@ type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
}
type Message struct {
@@ -23,6 +29,18 @@ type Message struct {
Content []Content `json:"content"`
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -33,6 +51,8 @@ type Request struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
//Metadata `json:"metadata,omitempty"`
}
@@ -61,6 +81,7 @@ type Response struct {
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
PartialJson string `json:"partial_json,omitempty"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
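The Tool and InputSchema types added above mirror Anthropic's tool-use schema, in which an OpenAI-style function definition's "parameters" object becomes "input_schema". A self-contained sketch of the mapping (the weather tool is a made-up example; the types are restated locally for illustration):

package main

import (
    "encoding/json"
    "fmt"
)

// Local copies of the types added in this diff.
type InputSchema struct {
    Type       string `json:"type"`
    Properties any    `json:"properties,omitempty"`
    Required   any    `json:"required,omitempty"`
}

type Tool struct {
    Name        string      `json:"name"`
    Description string      `json:"description,omitempty"`
    InputSchema InputSchema `json:"input_schema"`
}

func main() {
    // An OpenAI function tool's "parameters" maps directly onto "input_schema".
    tool := Tool{
        Name:        "get_weather",
        Description: "Get the current weather for a city",
        InputSchema: InputSchema{
            Type:       "object",
            Properties: map[string]any{"city": map[string]any{"type": "string"}},
            Required:   []string{"city"},
        },
    }
    b, _ := json.MarshalIndent(tool, "", "  ")
    fmt.Println(string(b))
}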


@@ -0,0 +1,84 @@
package aws
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct {
awsAdapter utils.AwsAdapter
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
a.awsAdapter = adaptor
return adaptor.ConvertRequest(c, relayMode, request)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if a.awsAdapter == nil {
return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
}
return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}
func (a *Adaptor) GetModelList() (models []string) {
for model := range adaptors {
models = append(models, model)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}


@@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}


@@ -5,72 +5,48 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"io" "io"
"net/http" "net/http"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic" "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) {
ak := c.GetString(ctxkey.ConfigAK)
sk := c.GetString(ctxkey.ConfigSK)
region := c.GetString(ctxkey.ConfigRegion)
client := bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{
var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
@@ -81,30 +57,30 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
return utils.WrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
@@ -120,16 +96,11 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
return nil, &usage
}
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsCli, err := newAwsClient(c)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
@@ -140,7 +111,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
return utils.WrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
@@ -148,16 +119,16 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
@@ -165,6 +136,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
@@ -185,8 +158,19 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
}
if response == nil {
return true
@@ -194,6 +178,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
response.Id = id
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())


@@ -9,9 +9,12 @@ type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}
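System, Tools and ToolChoice are added here because the handler fills this struct via copier.Copy from anthropic.Request, which matches fields by name: anything absent on the destination is silently dropped. A standalone sketch of that behavior (the trimmed-down types below are illustrative only, not the real request structs):

package main

import (
    "fmt"

    "github.com/jinzhu/copier"
)

type anthropicReq struct {
    System    string
    MaxTokens int
}

type awsReq struct {
    AnthropicVersion string
    System           string
    MaxTokens        int
}

func main() {
    src := anthropicReq{System: "be brief", MaxTokens: 4096}
    dst := awsReq{AnthropicVersion: "bedrock-2023-05-31"}
    // copier.Copy matches fields by name; a field missing on the
    // destination (as System was before this change) is silently dropped.
    _ = copier.Copy(&dst, &src)
    fmt.Printf("%+v\n", dst)
}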


@@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
llamaReq := ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, llamaReq)
return llamaReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}


@@ -0,0 +1,231 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"text/template"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/random"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// Only the llama3-8b and llama3-70b instruct models are supported
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
"llama3-8b-8192": "meta.llama3-8b-instruct-v1:0",
"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
// promptTemplate renders the llama3 chat prompt by ranging over the messages
const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|>
`
var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate))
func RenderPrompt(messages []relaymodel.Message) string {
var buf bytes.Buffer
err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages})
if err != nil {
logger.SysError("error rendering prompt messages: " + err.Error())
}
return buf.String()
}
func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
llamaRequest := Request{
MaxGenLen: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
}
if llamaRequest.MaxGenLen == 0 {
llamaRequest.MaxGenLen = 2048
}
prompt := RenderPrompt(textRequest.Messages)
llamaRequest.Prompt = prompt
return &llamaRequest
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
var llamaResponse Response
err = json.Unmarshal(awsResp.Body, &llamaResponse)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := ResponseLlama2OpenAI(&llamaResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: llamaResponse.PromptTokenCount,
CompletionTokens: llamaResponse.GenerationTokenCount,
TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
var responseText string
if len(llamaResponse.Generation) > 0 {
responseText = llamaResponse.Generation
}
choice := openai.TextResponseChoice{
Index: 0,
Message: relaymodel.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: llamaResponse.StopReason,
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
var llamaResp StreamResponse
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
if llamaResp.PromptTokenCount > 0 {
usage.PromptTokens = llamaResp.PromptTokenCount
}
if llamaResp.StopReason == "stop" {
usage.CompletionTokens = llamaResp.GenerationTokenCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
response := StreamResponseLlama2OpenAI(&llamaResp)
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}
func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = llamaResponse.Generation
choice.Delta.Role = "assistant"
finishReason := llamaResponse.StopReason
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse
}


@@ -0,0 +1,45 @@
package aws_test
import (
"testing"
aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/stretchr/testify/assert"
)
func TestRenderPrompt(t *testing.T) {
messages := []relaymodel.Message{
{
Role: "user",
Content: "What's your name?",
},
}
prompt := aws.RenderPrompt(messages)
expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
messages = []relaymodel.Message{
{
Role: "system",
Content: "Your name is Kat. You are a detective.",
},
{
Role: "user",
Content: "What's your name?",
},
{
Role: "assistant",
Content: "Kat",
},
{
Role: "user",
Content: "What's your job?",
},
}
prompt = aws.RenderPrompt(messages)
expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
}


@@ -0,0 +1,29 @@
package aws
// Request is the request to AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
}
// Response is the response from AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Response struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}
// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
type StreamResponse struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}


@@ -0,0 +1,39 @@
package aws
import (
claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)
type AwsModelType int
const (
AwsClaude AwsModelType = iota + 1
AwsLlama3
)
var (
adaptors = map[string]AwsModelType{}
)
func init() {
for model := range claude.AwsModelIDMap {
adaptors[model] = AwsClaude
}
for model := range llama3.AwsModelIDMap {
adaptors[model] = AwsLlama3
}
}
func GetAdaptor(model string) utils.AwsAdapter {
adaptorType := adaptors[model]
switch adaptorType {
case AwsClaude:
return &claude.Adaptor{}
case AwsLlama3:
return &llama3.Adaptor{}
default:
return nil
}
}
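Dispatch is keyed purely on the public model name: each family's entries in AwsModelIDMap are registered in init above, and GetAdaptor falls through to nil for anything unknown. A hypothetical call site:

package main

import (
    "fmt"

    awsadaptor "github.com/songquanpeng/one-api/relay/adaptor/aws"
)

func main() {
    for _, m := range []string{"claude-3-haiku-20240307", "llama3-8b-8192", "gpt-4"} {
        // gpt-4 is registered by neither family, so GetAdaptor returns nil.
        fmt.Println(m, "->", awsadaptor.GetAdaptor(m) != nil)
    }
}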


@@ -0,0 +1,51 @@
package utils
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type AwsAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
type Adaptor struct {
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}


@@ -0,0 +1,16 @@
package utils
import (
"net/http"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: err.Error(),
},
}
}


@@ -1,15 +0,0 @@
package azure
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
)
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString(ctxkey.ConfigAPIVersion)
}
return apiVersion
}


@@ -5,18 +5,20 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
) )
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
data = data[6:]
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil


@@ -0,0 +1,81 @@
package cloudflare
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
meta *meta.Meta
}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// Init implements adaptor.Adaptor.
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case relaymode.Completions:
return ConvertCompletionsRequest(*request), nil
case relaymode.ChatCompletions, relaymode.Embeddings:
return request, nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "cloudflare"
}
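For a sense of the routing in GetRequestURL above, these are the URLs it would produce under placeholder values (BaseURL https://api.cloudflare.com, account UserID acc123, and a model such as @cf/meta/llama-3-8b-instruct; all three values are illustrative):

// relaymode.ChatCompletions -> https://api.cloudflare.com/client/v4/accounts/acc123/ai/v1/chat/completions
// relaymode.Embeddings      -> https://api.cloudflare.com/client/v4/accounts/acc123/ai/v1/embeddings
// any other relay mode      -> https://api.cloudflare.com/client/v4/accounts/acc123/ai/run/@cf/meta/llama-3-8b-instruct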


@@ -0,0 +1,36 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}


@@ -0,0 +1,115 @@
package cloudflare
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
p, _ := textRequest.Prompt.(string)
return &Request{
Prompt: p,
MaxTokens: textRequest.MaxTokens,
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
responseModel := c.GetString(ctxkey.OriginalModel)
var responseText string
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response openai.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, v := range response.Choices {
v.Delta.Role = "assistant"
responseText += v.Delta.StringContent()
}
response.Id = id
response.Model = modelName
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
return nil, usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response openai.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = modelName
var responseText string
for _, v := range response.Choices {
responseText += v.Message.Content.(string)
}
usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}


@@ -0,0 +1,13 @@
package cloudflare
import "github.com/songquanpeng/one-api/relay/model"
type Request struct {
Messages []model.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}


@@ -5,3 +5,10 @@ var ModelList = []string{
"command-light", "command-light-nightly", "command-light", "command-light-nightly",
"command-r", "command-r-plus", "command-r", "command-r-plus",
} }
func init() {
num := len(ModelList)
for i := 0; i < num; i++ {
ModelList = append(ModelList, ModelList[i]+"-internet")
}
}
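The init hook above derives an "-internet" variant for every base model. Snapshotting len(ModelList) into num before the loop is what keeps the append from re-visiting the entries it just added; a standalone sketch of the same pattern:

package main

import "fmt"

func main() {
    models := []string{"command", "command-r"}
    num := len(models) // snapshot the original length first
    for i := 0; i < num; i++ {
        models = append(models, models[i]+"-internet")
    }
    fmt.Println(models) // [command command-r command-internet command-r-internet]
}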


@@ -2,9 +2,9 @@ package cohere
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@@ -17,6 +17,10 @@ import (
"github.com/songquanpeng/one-api/relay/model"
)
var (
WebSearchConnector = Connector{ID: "web-search"}
)
func stopReasonCohere2OpenAI(reason *string) string {
if reason == nil {
return ""
@@ -45,6 +49,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
if cohereRequest.Model == "" {
cohereRequest.Model = "command-r"
}
if strings.HasSuffix(cohereRequest.Model, "-internet") {
cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet")
cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector)
}
for _, message := range textRequest.Messages {
if message.Role == "user" {
cohereRequest.Message = message.Content.(string)
@@ -126,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cohereResponse StreamResponse
err := json.Unmarshal([]byte(data), &cohereResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
return true
}
if response == nil {
return true
}
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model")
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSuffix(data, "\r")
var cohereResponse StreamResponse
err := json.Unmarshal([]byte(data), &cohereResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
continue
}
if response == nil {
continue
}
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model")
response.Created = createdTime
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}


@@ -4,7 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/client" "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"io" "io"
"net/http" "net/http"


@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
@@ -14,10 +13,11 @@ import (
)
type Adaptor struct {
meta *meta.Meta
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
@@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
request.User = c.GetString(ctxkey.ConfigUserID)
request.User = a.meta.Config.UserID
return ConvertRequest(*request), nil
}


@@ -4,6 +4,11 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@@ -12,9 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
// https://www.coze.com/open
@@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
var responseText string
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 {
continue
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
}()
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
var modelName string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var cozeResponse StreamResponse
err := json.Unmarshal([]byte(data), &cozeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
if response == nil {
return true
}
for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSuffix(data, "\r")
var cozeResponse StreamResponse
err := json.Unmarshal([]byte(data), &cozeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
if response == nil {
continue
}
for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
response.Model = modelName
response.Created = createdTime
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &responseText
}


@@ -1,32 +1,32 @@
-package aws
+package deepl

 import (
-    "github.com/songquanpeng/one-api/common/ctxkey"
-    "io"
-    "net/http"
+    "errors"
+    "fmt"

     "github.com/gin-gonic/gin"
-    "github.com/pkg/errors"
     "github.com/songquanpeng/one-api/relay/adaptor"
-    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
+    "io"
+    "net/http"
 )

-var _ adaptor.Adaptor = new(Adaptor)
-
 type Adaptor struct {
+    meta       *meta.Meta
+    promptText string
 }

 func (a *Adaptor) Init(meta *meta.Meta) {
+    a.meta = meta
 }

 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-    return "", nil
+    return fmt.Sprintf("%s/v2/translate", meta.BaseURL), nil
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+    adaptor.SetupCommonRequestHeader(c, req, meta)
+    req.Header.Set("Authorization", "DeepL-Auth-Key "+meta.APIKey)
     return nil
 }
@@ -34,11 +34,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
     if request == nil {
         return nil, errors.New("request is nil")
     }
-    claudeReq := anthropic.ConvertRequest(*request)
-    c.Set(ctxkey.RequestModel, request.Model)
-    c.Set(ctxkey.ConvertedRequest, claudeReq)
-    return claudeReq, nil
+    convertedRequest, text := ConvertRequest(*request)
+    a.promptText = text
+    return convertedRequest, nil
 }

 func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -49,26 +47,27 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
 }

 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
-    return nil, nil
+    return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
     if meta.IsStream {
-        err, usage = StreamHandler(c, resp)
+        err = StreamHandler(c, resp, meta.ActualModelName)
     } else {
-        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+        err = Handler(c, resp, meta.ActualModelName)
+    }
+    promptTokens := len(a.promptText)
+    usage = &model.Usage{
+        PromptTokens: promptTokens,
+        TotalTokens:  promptTokens,
     }
     return
 }

-func (a *Adaptor) GetModelList() (models []string) {
-    for n := range awsModelIDMap {
-        models = append(models, n)
-    }
-    return
+func (a *Adaptor) GetModelList() []string {
+    return ModelList
 }

 func (a *Adaptor) GetChannelName() string {
-    return "aws"
+    return "deepl"
 }
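Two details worth noting in this adaptor: the DeepL key is sent as "Authorization: DeepL-Auth-Key <key>" rather than a Bearer token, and usage is approximated as len(promptText), a character count, since DeepL does not report model tokens. A standalone sketch of the call the adaptor builds, with locally redeclared types (the free-tier endpoint and the key are placeholders, not values from this PR):

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

type deeplRequest struct {
    Text       []string `json:"text"`
    TargetLang string   `json:"target_lang"`
}

type deeplResponse struct {
    Translations []struct {
        DetectedSourceLanguage string `json:"detected_source_language"`
        Text                   string `json:"text"`
    } `json:"translations"`
}

func main() {
    body, _ := json.Marshal(deeplRequest{Text: []string{"Hello, world"}, TargetLang: "ZH"})
    req, _ := http.NewRequest(http.MethodPost, "https://api-free.deepl.com/v2/translate", bytes.NewReader(body))
    req.Header.Set("Authorization", "DeepL-Auth-Key "+"YOUR_KEY") // same header scheme the adaptor sets
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    var out deeplResponse
    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
        panic(err)
    }
    for _, t := range out.Translations {
        fmt.Printf("%s -> %s\n", t.DetectedSourceLanguage, t.Text)
    }
}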

View File

@@ -0,0 +1,9 @@
+package deepl
+
+// https://developers.deepl.com/docs/api-reference/glossaries
+
+var ModelList = []string{
+    "deepl-zh",
+    "deepl-en",
+    "deepl-ja",
+}

View File

@@ -0,0 +1,11 @@
+package deepl
+
+import "strings"
+
+func parseLangFromModelName(modelName string) string {
+    parts := strings.Split(modelName, "-")
+    if len(parts) == 1 {
+        return "ZH"
+    }
+    return parts[1]
+}
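The target language is simply the suffix of the model name, with a bare "deepl" falling back to "ZH". Note the asymmetry: the fallback is uppercase while suffixes pass through as-is, lowercase (DeepL appears to accept either case). A quick check with a local copy of the helper:

package main

import (
    "fmt"
    "strings"
)

// local copy of the helper above, for illustration only
func parseLangFromModelName(modelName string) string {
    parts := strings.Split(modelName, "-")
    if len(parts) == 1 {
        return "ZH"
    }
    return parts[1]
}

func main() {
    for _, m := range []string{"deepl-zh", "deepl-en", "deepl-ja", "deepl"} {
        fmt.Printf("%s -> %s\n", m, parseLangFromModelName(m))
    }
    // Output:
    // deepl-zh -> zh
    // deepl-en -> en
    // deepl-ja -> ja
    // deepl -> ZH
}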

relay/adaptor/deepl/main.go (new file, 137 lines)
View File

@@ -0,0 +1,137 @@
+package deepl
+
+import (
+    "encoding/json"
+    "github.com/gin-gonic/gin"
+    "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/helper"
+    "github.com/songquanpeng/one-api/relay/adaptor/openai"
+    "github.com/songquanpeng/one-api/relay/constant"
+    "github.com/songquanpeng/one-api/relay/constant/finishreason"
+    "github.com/songquanpeng/one-api/relay/constant/role"
+    "github.com/songquanpeng/one-api/relay/model"
+    "io"
+    "net/http"
+)
+
+// https://developers.deepl.com/docs/getting-started/your-first-api-request
+
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) (*Request, string) {
+    var text string
+    if len(textRequest.Messages) != 0 {
+        text = textRequest.Messages[len(textRequest.Messages)-1].StringContent()
+    }
+    deeplRequest := Request{
+        TargetLang: parseLangFromModelName(textRequest.Model),
+        Text:       []string{text},
+    }
+    return &deeplRequest, text
+}
+
+func StreamResponseDeepL2OpenAI(deeplResponse *Response) *openai.ChatCompletionsStreamResponse {
+    var choice openai.ChatCompletionsStreamResponseChoice
+    if len(deeplResponse.Translations) != 0 {
+        choice.Delta.Content = deeplResponse.Translations[0].Text
+    }
+    choice.Delta.Role = role.Assistant
+    choice.FinishReason = &constant.StopFinishReason
+    openaiResponse := openai.ChatCompletionsStreamResponse{
+        Object:  constant.StreamObject,
+        Created: helper.GetTimestamp(),
+        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
+    }
+    return &openaiResponse
+}
+
+func ResponseDeepL2OpenAI(deeplResponse *Response) *openai.TextResponse {
+    var responseText string
+    if len(deeplResponse.Translations) != 0 {
+        responseText = deeplResponse.Translations[0].Text
+    }
+    choice := openai.TextResponseChoice{
+        Index: 0,
+        Message: model.Message{
+            Role:    role.Assistant,
+            Content: responseText,
+            Name:    nil,
+        },
+        FinishReason: finishreason.Stop,
+    }
+    fullTextResponse := openai.TextResponse{
+        Object:  constant.NonStreamObject,
+        Created: helper.GetTimestamp(),
+        Choices: []openai.TextResponseChoice{choice},
+    }
+    return &fullTextResponse
+}
+
+func StreamHandler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+    }
+    var deeplResponse Response
+    err = json.Unmarshal(responseBody, &deeplResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+    }
+    fullTextResponse := StreamResponseDeepL2OpenAI(&deeplResponse)
+    fullTextResponse.Model = modelName
+    fullTextResponse.Id = helper.GetResponseID(c)
+    jsonData, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+    }
+    common.SetEventStreamHeaders(c)
+    c.Stream(func(w io.Writer) bool {
+        if jsonData != nil {
+            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
+            jsonData = nil
+            return true
+        }
+        c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+        return false
+    })
+    _ = resp.Body.Close()
+    return nil
+}
+
+func Handler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+    }
+    var deeplResponse Response
+    err = json.Unmarshal(responseBody, &deeplResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+    }
+    if deeplResponse.Message != "" {
+        return &model.ErrorWithStatusCode{
+            Error: model.Error{
+                Message: deeplResponse.Message,
+                Code:    "deepl_error",
+            },
+            StatusCode: resp.StatusCode,
+        }
+    }
+    fullTextResponse := ResponseDeepL2OpenAI(&deeplResponse)
+    fullTextResponse.Model = modelName
+    fullTextResponse.Id = helper.GetResponseID(c)
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+    }
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = c.Writer.Write(jsonResponse)
+    return nil
+}
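StreamHandler here fakes streaming: DeepL answers in a single response, so the handler emits exactly one chunk and then the terminator. On the wire the client sees roughly this (payload abridged, values illustrative):

data: {"object":"chat.completion.chunk","created":1719800000,"model":"deepl-zh","choices":[{"delta":{"role":"assistant","content":"..."},"finish_reason":"stop"}]}
data: [DONE]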

View File

@@ -0,0 +1,16 @@
+package deepl
+
+type Request struct {
+    Text       []string `json:"text"`
+    TargetLang string   `json:"target_lang"`
+}
+
+type Translation struct {
+    DetectedSourceLanguage string `json:"detected_source_language,omitempty"`
+    Text                   string `json:"text,omitempty"`
+}
+
+type Response struct {
+    Translations []Translation `json:"translations,omitempty"`
+    Message      string        `json:"message,omitempty"`
+}

View File

@@ -0,0 +1,13 @@
+package doubao
+
+// https://console.volcengine.com/ark/region:ark+cn-beijing/model
+
+var ModelList = []string{
+    "Doubao-pro-128k",
+    "Doubao-pro-32k",
+    "Doubao-pro-4k",
+    "Doubao-lite-128k",
+    "Doubao-lite-32k",
+    "Doubao-lite-4k",
+    "Doubao-embedding",
+}

View File

@@ -0,0 +1,14 @@
+package doubao
+
+import (
+    "fmt"
+    "github.com/songquanpeng/one-api/relay/meta"
+    "github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+func GetRequestURL(meta *meta.Meta) (string, error) {
+    if meta.Mode == relaymode.ChatCompletions {
+        return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
+    }
+    return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
+}

View File

@@ -3,6 +3,9 @@ package gemini
 import (
     "errors"
     "fmt"
+    "io"
+    "net/http"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/helper"
@@ -10,8 +13,7 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
+    "github.com/songquanpeng/one-api/relay/relaymode"
 )

 type Adaptor struct {
@@ -22,10 +24,17 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 }

 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-    version := helper.AssignOrDefault(meta.APIVersion, config.GeminiVersion)
-    action := "generateContent"
+    version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
+    action := ""
+    switch meta.Mode {
+    case relaymode.Embeddings:
+        action = "batchEmbedContents"
+    default:
+        action = "generateContent"
+    }
     if meta.IsStream {
-        action = "streamGenerateContent"
+        action = "streamGenerateContent?alt=sse"
     }
     return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
 }
@@ -40,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
     if request == nil {
         return nil, errors.New("request is nil")
     }
-    return ConvertRequest(*request), nil
+    switch relayMode {
+    case relaymode.Embeddings:
+        geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
+        return geminiEmbeddingRequest, nil
+    default:
+        geminiRequest := ConvertRequest(*request)
+        return geminiRequest, nil
+    }
 }

 func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -60,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
         err, responseText = StreamHandler(c, resp)
         usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
     } else {
-        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+        switch meta.Mode {
+        case relaymode.Embeddings:
+            err, usage = EmbeddingHandler(c, resp)
+        default:
+            err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+        }
     }
     return
 }
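With this change the adaptor emits three URL shapes: non-stream chat, streaming chat (now with alt=sse so Gemini returns proper SSE frames), and batch embeddings. A tiny sketch of the resulting URLs, assuming the public endpoint and v1beta as the version (both are placeholders here):

package main

import "fmt"

func main() {
    base, version, model := "https://generativelanguage.googleapis.com", "v1beta", "gemini-1.5-pro"
    // mirrors the fmt.Sprintf in GetRequestURL above
    for _, action := range []string{"generateContent", "streamGenerateContent?alt=sse", "batchEmbedContents"} {
        fmt.Printf("%s/%s/models/%s:%s\n", base, version, model, action)
    }
}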

View File

@@ -4,5 +4,5 @@
 var ModelList = []string{
     "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
-    "gemini-pro-vision", "gemini-1.0-pro-vision-001",
+    "gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
 }

View File

@@ -4,6 +4,7 @@ import (
     "bufio"
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"
@@ -134,6 +135,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
     return &geminiRequest
 }

+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
+    inputs := request.ParseInput()
+    requests := make([]EmbeddingRequest, len(inputs))
+    model := fmt.Sprintf("models/%s", request.Model)
+    for i, input := range inputs {
+        requests[i] = EmbeddingRequest{
+            Model: model,
+            Content: ChatContent{
+                Parts: []Part{
+                    {
+                        Text: input,
+                    },
+                },
+            },
+        }
+    }
+    return &BatchEmbeddingRequest{
+        Requests: requests,
+    }
+}
+
 type ChatResponse struct {
     Candidates     []ChatCandidate    `json:"candidates"`
     PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -222,81 +246,80 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
     var choice openai.ChatCompletionsStreamResponseChoice
     choice.Delta.Content = geminiResponse.GetResponseText()
-    choice.FinishReason = &constant.StopFinishReason
+    //choice.FinishReason = &constant.StopFinishReason
     var response openai.ChatCompletionsStreamResponse
+    response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
+    response.Created = helper.GetTimestamp()
     response.Object = "chat.completion.chunk"
     response.Model = "gemini"
    response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
     return &response
 }

+func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+    openAIEmbeddingResponse := openai.EmbeddingResponse{
+        Object: "list",
+        Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
+        Model:  "gemini-embedding",
+        Usage:  model.Usage{TotalTokens: 0},
+    }
+    for _, item := range response.Embeddings {
+        openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+            Object:    `embedding`,
+            Index:     0,
+            Embedding: item.Values,
+        })
+    }
+    return &openAIEmbeddingResponse
+}
+
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
     responseText := ""
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            data = strings.TrimSpace(data)
-            if !strings.HasPrefix(data, "\"text\": \"") {
-                continue
-            }
-            data = strings.TrimPrefix(data, "\"text\": \"")
-            data = strings.TrimSuffix(data, "\"")
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
+    scanner.Split(bufio.ScanLines)
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            // this is used to prevent annoying \ related format bug
-            data = fmt.Sprintf("{\"content\": \"%s\"}", data)
-            type dummyStruct struct {
-                Content string `json:"content"`
-            }
-            var dummy dummyStruct
-            err := json.Unmarshal([]byte(data), &dummy)
-            responseText += dummy.Content
-            var choice openai.ChatCompletionsStreamResponseChoice
-            choice.Delta.Content = dummy.Content
-            response := openai.ChatCompletionsStreamResponse{
-                Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
-                Object:  "chat.completion.chunk",
-                Created: helper.GetTimestamp(),
-                Model:   "gemini-pro",
-                Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
-            }
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                logger.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
-        }
-    })
+
+    for scanner.Scan() {
+        data := scanner.Text()
+        data = strings.TrimSpace(data)
+        if !strings.HasPrefix(data, "data: ") {
+            continue
+        }
+        data = strings.TrimPrefix(data, "data: ")
+        data = strings.TrimSuffix(data, "\"")
+
+        var geminiResponse ChatResponse
+        err := json.Unmarshal([]byte(data), &geminiResponse)
+        if err != nil {
+            logger.SysError("error unmarshalling stream response: " + err.Error())
+            continue
+        }
+
+        response := streamResponseGeminiChat2OpenAI(&geminiResponse)
+        if response == nil {
+            continue
+        }
+
+        responseText += response.Choices[0].Delta.StringContent()
+
+        err = render.ObjectData(c, response)
+        if err != nil {
+            logger.SysError(err.Error())
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    render.Done(c)
+
     err := resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
     return nil, responseText
 }
@@ -343,3 +366,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
     _, err = c.Writer.Write(jsonResponse)
     return nil, &usage
 }
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+    var geminiEmbeddingResponse EmbeddingResponse
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    if geminiEmbeddingResponse.Error != nil {
+        return &model.ErrorWithStatusCode{
+            Error: model.Error{
+                Message: geminiEmbeddingResponse.Error.Message,
+                Type:    "gemini_error",
+                Param:   "",
+                Code:    geminiEmbeddingResponse.Error.Code,
+            },
+            StatusCode: resp.StatusCode,
+        }, nil
+    }
+    fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = c.Writer.Write(jsonResponse)
+    return nil, &fullTextResponse.Usage
+}
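For reference, this is the shape of the batchEmbedContents payload that ConvertEmbeddingRequest produces; the types are redeclared locally so the sketch runs on its own, and the model name is just an example:

package main

import (
    "encoding/json"
    "fmt"
)

type part struct {
    Text string `json:"text"`
}
type content struct {
    Parts []part `json:"parts"`
}
type embeddingRequest struct {
    Model   string  `json:"model"`
    Content content `json:"content"`
}
type batchEmbeddingRequest struct {
    Requests []embeddingRequest `json:"requests"`
}

func main() {
    inputs := []string{"hello", "world"}
    reqs := make([]embeddingRequest, len(inputs))
    for i, in := range inputs {
        reqs[i] = embeddingRequest{
            Model:   "models/text-embedding-004",
            Content: content{Parts: []part{{Text: in}}},
        }
    }
    out, _ := json.MarshalIndent(batchEmbeddingRequest{Requests: reqs}, "", "  ")
    fmt.Println(string(out)) // one request object per input, wrapped in "requests"
}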

View File

@@ -7,6 +7,33 @@ type ChatRequest struct {
     Tools []ChatTools `json:"tools,omitempty"`
 }

+type EmbeddingRequest struct {
+    Model                string      `json:"model"`
+    Content              ChatContent `json:"content"`
+    TaskType             string      `json:"taskType,omitempty"`
+    Title                string      `json:"title,omitempty"`
+    OutputDimensionality int         `json:"outputDimensionality,omitempty"`
+}
+
+type BatchEmbeddingRequest struct {
+    Requests []EmbeddingRequest `json:"requests"`
+}
+
+type EmbeddingData struct {
+    Values []float64 `json:"values"`
+}
+
+type EmbeddingResponse struct {
+    Embeddings []EmbeddingData `json:"embeddings"`
+    Error      *Error          `json:"error,omitempty"`
+}
+
+type Error struct {
+    Code    int    `json:"code,omitempty"`
+    Message string `json:"message,omitempty"`
+    Status  string `json:"status,omitempty"`
+}
+
 type InlineData struct {
     MimeType string `json:"mimeType"`
     Data     string `json:"data"`

View File

@@ -1,7 +1,11 @@
 package minimax

+// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
+
 var ModelList = []string{
-    "abab5.5s-chat",
-    "abab5.5-chat",
+    "abab6.5-chat",
+    "abab6.5s-chat",
     "abab6-chat",
+    "abab5.5-chat",
+    "abab5.5s-chat",
 }

View File

@@ -0,0 +1,19 @@
+package novita
+
+// https://novita.ai/llm-api
+
+var ModelList = []string{
+    "meta-llama/llama-3-8b-instruct",
+    "meta-llama/llama-3-70b-instruct",
+    "nousresearch/hermes-2-pro-llama-3-8b",
+    "nousresearch/nous-hermes-llama2-13b",
+    "mistralai/mistral-7b-instruct",
+    "cognitivecomputations/dolphin-mixtral-8x22b",
+    "sao10k/l3-70b-euryale-v2.1",
+    "sophosympatheia/midnight-rose-70b",
+    "gryphe/mythomax-l2-13b",
+    "Nous-Hermes-2-Mixtral-8x7B-DPO",
+    "lzlv_70b",
+    "teknium/openhermes-2.5-mistral-7b",
+    "microsoft/wizardlm-2-8x22b",
+}

View File

@@ -0,0 +1,15 @@
+package novita
+
+import (
+    "fmt"
+
+    "github.com/songquanpeng/one-api/relay/meta"
+    "github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+func GetRequestURL(meta *meta.Meta) (string, error) {
+    if meta.Mode == relaymode.ChatCompletions {
+        return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
+    }
+    return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
+}

View File

@@ -1,5 +1,11 @@
 package ollama

 var ModelList = []string{
+    "codellama:7b-instruct",
+    "llama2:7b",
+    "llama2:latest",
+    "llama3:latest",
+    "phi3:latest",
     "qwen:0.5b-chat",
+    "qwen:7b",
 }

View File

@@ -5,14 +5,17 @@ import (
     "context"
     "encoding/json"
     "fmt"
-    "github.com/songquanpeng/one-api/common/helper"
-    "github.com/songquanpeng/one-api/common/random"
+    "github.com/songquanpeng/one-api/common/render"
     "io"
     "net/http"
     "strings"

+    "github.com/songquanpeng/one-api/common/helper"
+    "github.com/songquanpeng/one-api/common/random"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/image"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/constant"
@@ -32,9 +35,22 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
         Stream: request.Stream,
     }
     for _, message := range request.Messages {
+        openaiContent := message.ParseContent()
+        var imageUrls []string
+        var contentText string
+        for _, part := range openaiContent {
+            switch part.Type {
+            case model.ContentTypeText:
+                contentText = part.Text
+            case model.ContentTypeImageURL:
+                _, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
+                imageUrls = append(imageUrls, data)
+            }
+        }
         ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
             Role:    message.Role,
-            Content: message.StringContent(),
+            Content: contentText,
+            Images:  imageUrls,
         })
     }
     return &ollamaRequest
@@ -91,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
         return 0, nil, nil
     }
     if i := strings.Index(string(data), "}\n"); i >= 0 {
-        return i + 2, data[0:i], nil
+        return i + 2, data[0 : i+1], nil
     }
     if atEOF {
         return len(data), data, nil
     }
     return 0, nil, nil
 })
-dataChan := make(chan string)
-stopChan := make(chan bool)
-go func() {
-    for scanner.Scan() {
-        data := strings.TrimPrefix(scanner.Text(), "}")
-        dataChan <- data + "}"
-    }
-    stopChan <- true
-}()
+
 common.SetEventStreamHeaders(c)
-c.Stream(func(w io.Writer) bool {
-    select {
-    case data := <-dataChan:
-        var ollamaResponse ChatResponse
-        err := json.Unmarshal([]byte(data), &ollamaResponse)
-        if err != nil {
-            logger.SysError("error unmarshalling stream response: " + err.Error())
-            return true
-        }
-        if ollamaResponse.EvalCount != 0 {
-            usage.PromptTokens = ollamaResponse.PromptEvalCount
-            usage.CompletionTokens = ollamaResponse.EvalCount
-            usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
-        }
-        response := streamResponseOllama2OpenAI(&ollamaResponse)
-        jsonResponse, err := json.Marshal(response)
-        if err != nil {
-            logger.SysError("error marshalling stream response: " + err.Error())
-            return true
-        }
-        c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-        return true
-    case <-stopChan:
-        c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-        return false
-    }
-})
+
+for scanner.Scan() {
+    data := strings.TrimPrefix(scanner.Text(), "}")
+    data = data + "}"
+
+    var ollamaResponse ChatResponse
+    err := json.Unmarshal([]byte(data), &ollamaResponse)
+    if err != nil {
+        logger.SysError("error unmarshalling stream response: " + err.Error())
+        continue
+    }
+
+    if ollamaResponse.EvalCount != 0 {
+        usage.PromptTokens = ollamaResponse.PromptEvalCount
+        usage.CompletionTokens = ollamaResponse.EvalCount
+        usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
+    }
+
+    response := streamResponseOllama2OpenAI(&ollamaResponse)
+    err = render.ObjectData(c, response)
+    if err != nil {
+        logger.SysError(err.Error())
+    }
+}
+
+if err := scanner.Err(); err != nil {
+    logger.SysError("error reading stream: " + err.Error())
+}
+
+render.Done(c)
+
 err := resp.Body.Close()
 if err != nil {
     return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 }
 return nil, &usage
 }
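The ConvertRequest change above splits each OpenAI message into a text part and zero or more image parts, since Ollama expects images as a separate base64 array rather than inline content. A reduced sketch with local types (the real adaptor downloads and base64-encodes each URL via image.GetImageFromUrl; that step is stubbed here):

package main

import "fmt"

type openAIPart struct {
    Type string // "text" or "image_url"
    Text string
    URL  string
}

type ollamaMessage struct {
    Role    string
    Content string
    Images  []string // base64 payloads in the real adaptor
}

func convert(role string, parts []openAIPart) ollamaMessage {
    msg := ollamaMessage{Role: role}
    for _, p := range parts {
        switch p.Type {
        case "text":
            msg.Content = p.Text // note: a later text part overwrites an earlier one, as in the diff
        case "image_url":
            msg.Images = append(msg.Images, "<base64 of "+p.URL+">") // stub for download + encode
        }
    }
    return msg
}

func main() {
    fmt.Printf("%+v\n", convert("user", []openAIPart{
        {Type: "text", Text: "What is in this picture?"},
        {Type: "image_url", URL: "https://example.com/cat.png"},
    }))
}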

View File

@@ -3,16 +3,19 @@ package openai
 import (
     "errors"
     "fmt"
+    "io"
+    "net/http"
+    "strings"

     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/relay/adaptor"
+    "github.com/songquanpeng/one-api/relay/adaptor/doubao"
     "github.com/songquanpeng/one-api/relay/adaptor/minimax"
+    "github.com/songquanpeng/one-api/relay/adaptor/novita"
     "github.com/songquanpeng/one-api/relay/channeltype"
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
     "github.com/songquanpeng/one-api/relay/relaymode"
-    "io"
-    "net/http"
-    "strings"
 )

 type Adaptor struct {
@@ -29,13 +32,13 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
     if meta.Mode == relaymode.ImagesGenerations {
         // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
         // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
-        fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
+        fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
         return fullRequestURL, nil
     }

     // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
     requestURL := strings.Split(meta.RequestURLPath, "?")[0]
-    requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
+    requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
     task := strings.TrimPrefix(requestURL, "/v1/")
     model_ := meta.ActualModelName
     model_ = strings.Replace(model_, ".", "", -1)
@@ -45,6 +48,10 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
         return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
     case channeltype.Minimax:
         return minimax.GetRequestURL(meta)
+    case channeltype.Doubao:
+        return doubao.GetRequestURL(meta)
+    case channeltype.Novita:
+        return novita.GetRequestURL(meta)
     default:
         return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
     }
@@ -86,9 +93,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
     if meta.IsStream {
         var responseText string
         err, responseText, usage = StreamHandler(c, resp, meta.Mode)
-        if usage == nil {
+        if usage == nil || usage.TotalTokens == 0 {
             usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
         }
+        if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens
+            usage.PromptTokens = meta.PromptTokens
+            usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens
+        }
     } else {
         switch meta.Mode {
         case relaymode.ImagesGenerations:
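The streaming branch now has two fallbacks: recompute usage from the relayed text when the channel reports nothing, and, when only TotalTokens is present, derive the prompt/completion split from the request's own prompt count. Reduced to plain arithmetic (a sketch, with illustrative numbers):

package main

import "fmt"

type usage struct{ Prompt, Completion, Total int }

// normalize mirrors the second fallback: fill in the split when only a total is known
func normalize(u usage, promptTokens int) usage {
    if u.Total != 0 && u.Prompt == 0 { // some channels only return a total
        u.Prompt = promptTokens
        u.Completion = u.Total - promptTokens
    }
    return u
}

func main() {
    // upstream reported TotalTokens=90, nothing else; the request had 25 prompt tokens
    fmt.Println(normalize(usage{Total: 90}, 25)) // {25 65 90}
}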

View File

@@ -4,12 +4,15 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/ai360"
     "github.com/songquanpeng/one-api/relay/adaptor/baichuan"
     "github.com/songquanpeng/one-api/relay/adaptor/deepseek"
+    "github.com/songquanpeng/one-api/relay/adaptor/doubao"
     "github.com/songquanpeng/one-api/relay/adaptor/groq"
     "github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
     "github.com/songquanpeng/one-api/relay/adaptor/minimax"
     "github.com/songquanpeng/one-api/relay/adaptor/mistral"
     "github.com/songquanpeng/one-api/relay/adaptor/moonshot"
+    "github.com/songquanpeng/one-api/relay/adaptor/novita"
     "github.com/songquanpeng/one-api/relay/adaptor/stepfun"
+    "github.com/songquanpeng/one-api/relay/adaptor/togetherai"
     "github.com/songquanpeng/one-api/relay/channeltype"
 )

@@ -19,11 +22,14 @@ var CompatibleChannels = []int{
     channeltype.Moonshot,
     channeltype.Baichuan,
     channeltype.Minimax,
+    channeltype.Doubao,
     channeltype.Mistral,
     channeltype.Groq,
     channeltype.LingYiWanWu,
     channeltype.StepFun,
     channeltype.DeepSeek,
+    channeltype.TogetherAI,
+    channeltype.Novita,
 }

 func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -48,6 +54,12 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
     case channeltype.StepFun:
         return "stepfun", stepfun.ModelList
     case channeltype.DeepSeek:
         return "deepseek", deepseek.ModelList
+    case channeltype.TogetherAI:
+        return "together.ai", togetherai.ModelList
+    case channeltype.Doubao:
+        return "doubao", doubao.ModelList
+    case channeltype.Novita:
+        return "novita", novita.ModelList
     default:
         return "openai", ModelList
     }

View File

@@ -7,6 +7,7 @@ var ModelList = []string{
     "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
     "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
     "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
+    "gpt-4o", "gpt-4o-2024-05-13",
     "gpt-4-vision-preview",
     "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
     "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

View File

@@ -4,15 +4,18 @@ import (
     "bufio"
     "bytes"
     "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+
+    "github.com/songquanpeng/one-api/common/render"

     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/conv"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/model"
     "github.com/songquanpeng/one-api/relay/relaymode"
-    "io"
-    "net/http"
-    "strings"
 )

 const (
@@ -24,88 +27,72 @@ const (
 func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
     responseText := ""
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
+    scanner.Split(bufio.ScanLines)
     var usage *model.Usage
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < dataPrefixLength { // ignore blank line or wrong format
-                continue
-            }
-            if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
-                continue
-            }
-            if strings.HasPrefix(data[dataPrefixLength:], done) {
-                dataChan <- data
-                continue
-            }
-            switch relayMode {
-            case relaymode.ChatCompletions:
-                var streamResponse ChatCompletionsStreamResponse
-                err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
-                if err != nil {
-                    logger.SysError("error unmarshalling stream response: " + err.Error())
-                    dataChan <- data // if error happened, pass the data to client
-                    continue         // just ignore the error
-                }
-                if len(streamResponse.Choices) == 0 {
-                    // but for empty choice, we should not pass it to client, this is for azure
-                    continue // just ignore empty choice
-                }
-                dataChan <- data
-                for _, choice := range streamResponse.Choices {
-                    responseText += conv.AsString(choice.Delta.Content)
-                }
-                if streamResponse.Usage != nil {
-                    usage = streamResponse.Usage
-                }
-            case relaymode.Completions:
-                dataChan <- data
-                var streamResponse CompletionsStreamResponse
-                err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
-                if err != nil {
-                    logger.SysError("error unmarshalling stream response: " + err.Error())
-                    continue
-                }
-                for _, choice := range streamResponse.Choices {
-                    responseText += choice.Text
-                }
-            }
-        }
-        stopChan <- true
-    }()
+
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            if strings.HasPrefix(data, "data: [DONE]") {
-                data = data[:12]
-            }
-            // some implementations may add \r at the end of data
-            data = strings.TrimSuffix(data, "\r")
-            c.Render(-1, common.CustomEvent{Data: data})
-            return true
-        case <-stopChan:
-            return false
-        }
-    })
+
+    doneRendered := false
+    for scanner.Scan() {
+        data := scanner.Text()
+        if len(data) < dataPrefixLength { // ignore blank line or wrong format
+            continue
+        }
+        if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
+            continue
+        }
+        if strings.HasPrefix(data[dataPrefixLength:], done) {
+            render.StringData(c, data)
+            doneRendered = true
+            continue
+        }
+        switch relayMode {
+        case relaymode.ChatCompletions:
+            var streamResponse ChatCompletionsStreamResponse
+            err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
+            if err != nil {
+                logger.SysError("error unmarshalling stream response: " + err.Error())
+                render.StringData(c, data) // if error happened, pass the data to client
+                continue                   // just ignore the error
+            }
+            if len(streamResponse.Choices) == 0 {
+                // but for empty choice, we should not pass it to client, this is for azure
+                continue // just ignore empty choice
+            }
+            render.StringData(c, data)
+            for _, choice := range streamResponse.Choices {
+                responseText += conv.AsString(choice.Delta.Content)
+            }
+            if streamResponse.Usage != nil {
+                usage = streamResponse.Usage
+            }
+        case relaymode.Completions:
+            render.StringData(c, data)
+            var streamResponse CompletionsStreamResponse
+            err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
+            if err != nil {
+                logger.SysError("error unmarshalling stream response: " + err.Error())
+                continue
+            }
+            for _, choice := range streamResponse.Choices {
+                responseText += choice.Text
+            }
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        logger.SysError("error reading stream: " + err.Error())
+    }
+
+    if !doneRendered {
+        render.Done(c)
+    }
+
     err := resp.Body.Close()
     if err != nil {
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
     }
     return nil, responseText, usage
 }
@@ -149,7 +136,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }

-    if textResponse.Usage.TotalTokens == 0 {
+    if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
         completionTokens := 0
         for _, choice := range textResponse.Choices {
             completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
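The doneRendered flag is the subtle part of this rewrite: if the upstream already sent its own "data: [DONE]", unconditionally appending another would hand clients a duplicate terminator, so the handler only synthesizes one when the upstream omitted it. A reduced sketch of that guard:

package main

import "fmt"

func relay(lines []string) {
    doneRendered := false
    for _, l := range lines {
        fmt.Println(l) // pass every event through
        if l == "data: [DONE]" {
            doneRendered = true
        }
    }
    if !doneRendered { // only synthesize the terminator if upstream omitted it
        fmt.Println("data: [DONE]")
    }
}

func main() {
    relay([]string{`data: {"choices":[]}`, "data: [DONE]"}) // prints [DONE] exactly once
}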

View File

@@ -134,7 +134,7 @@ type ChatCompletionsStreamResponse struct {
     Created int64                                 `json:"created"`
     Model   string                                `json:"model"`
     Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
-    Usage   *model.Usage                          `json:"usage"`
+    Usage   *model.Usage                          `json:"usage,omitempty"`
 }

 type CompletionsStreamResponse struct {
type CompletionsStreamResponse struct { type CompletionsStreamResponse struct {

View File

@@ -24,6 +24,10 @@ func InitTokenEncoders() {
         logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
     }
     defaultTokenEncoder = gpt35TokenEncoder
+    gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
+    if err != nil {
+        logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
+    }
     gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
     if err != nil {
         logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
@@ -31,6 +35,8 @@ func InitTokenEncoders() {
     for model := range billingratio.ModelRatio {
         if strings.HasPrefix(model, "gpt-3.5") {
             tokenEncoderMap[model] = gpt35TokenEncoder
+        } else if strings.HasPrefix(model, "gpt-4o") {
+            tokenEncoderMap[model] = gpt4oTokenEncoder
         } else if strings.HasPrefix(model, "gpt-4") {
             tokenEncoderMap[model] = gpt4TokenEncoder
         } else {
@@ -206,3 +212,7 @@ func CountTokenText(text string, model string) int {
     tokenEncoder := getTokenEncoder(model)
     return getTokenNum(tokenEncoder, text)
 }
+
+func CountToken(text string) int {
+    return CountTokenInput(text, "gpt-3.5-turbo")
+}
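Order matters in those prefix checks: gpt-4o models use a different tiktoken encoding (o200k_base) than gpt-4 (cl100k_base), and "gpt-4o" also matches the "gpt-4" prefix, so the new branch must come first or every 4o model would be counted with the wrong tokenizer. A sketch of the routing, with encoder names as plain strings:

package main

import (
    "fmt"
    "strings"
)

func encoderFor(model string) string {
    switch {
    case strings.HasPrefix(model, "gpt-3.5"):
        return "cl100k_base (gpt-3.5-turbo)"
    case strings.HasPrefix(model, "gpt-4o"): // deliberately before the gpt-4 check
        return "o200k_base (gpt-4o)"
    case strings.HasPrefix(model, "gpt-4"):
        return "cl100k_base (gpt-4)"
    default:
        return "default encoder"
    }
}

func main() {
    for _, m := range []string{"gpt-3.5-turbo", "gpt-4o-2024-05-13", "gpt-4-turbo", "claude-3"} {
        fmt.Println(m, "->", encoderFor(m))
    }
}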

View File

@@ -3,6 +3,10 @@ package palm
 import (
     "encoding/json"
     "fmt"
+    "github.com/songquanpeng/one-api/common/render"
+    "io"
+    "net/http"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/helper"
@@ -11,8 +15,6 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
 )

 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
     responseText := ""
     responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
     createdTime := helper.GetTimestamp()
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        responseBody, err := io.ReadAll(resp.Body)
-        if err != nil {
-            logger.SysError("error reading stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        err = resp.Body.Close()
-        if err != nil {
-            logger.SysError("error closing stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        var palmResponse ChatResponse
-        err = json.Unmarshal(responseBody, &palmResponse)
-        if err != nil {
-            logger.SysError("error unmarshalling stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
-        fullTextResponse.Id = responseId
-        fullTextResponse.Created = createdTime
-        if len(palmResponse.Candidates) > 0 {
-            responseText = palmResponse.Candidates[0].Content
-        }
-        jsonResponse, err := json.Marshal(fullTextResponse)
-        if err != nil {
-            logger.SysError("error marshalling stream response: " + err.Error())
-            stopChan <- true
-            return
-        }
-        dataChan <- string(jsonResponse)
-        stopChan <- true
-    }()
     common.SetEventStreamHeaders(c)
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            c.Render(-1, common.CustomEvent{Data: "data: " + data})
-            return true
-        case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
-        }
-    })
-    err := resp.Body.Close()
+
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        logger.SysError("error reading stream response: " + err.Error())
+        err := resp.Body.Close()
+        if err != nil {
+            return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+        }
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
+    }
+    err = resp.Body.Close()
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
     }
+    var palmResponse ChatResponse
+    err = json.Unmarshal(responseBody, &palmResponse)
+    if err != nil {
+        logger.SysError("error unmarshalling stream response: " + err.Error())
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
+    }
+    fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
+    fullTextResponse.Id = responseId
+    fullTextResponse.Created = createdTime
+    if len(palmResponse.Candidates) > 0 {
+        responseText = palmResponse.Candidates[0].Content
+    }
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        logger.SysError("error marshalling stream response: " + err.Error())
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
+    }
+    err = render.ObjectData(c, string(jsonResponse))
+    if err != nil {
+        logger.SysError(err.Error())
+    }
+    render.Done(c)
     return nil, responseText
 }

View File

@@ -2,35 +2,43 @@ package tencent

 import (
     "errors"
-    "fmt"
     "github.com/gin-gonic/gin"
+    "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/relay/adaptor"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
     "io"
     "net/http"
+    "strconv"
     "strings"
 )

 // https://cloud.tencent.com/document/api/1729/101837

 type Adaptor struct {
-    Sign string
+    Sign      string
+    Action    string
+    Version   string
+    Timestamp int64
 }

 func (a *Adaptor) Init(meta *meta.Meta) {
+    a.Action = "ChatCompletions"
+    a.Version = "2023-09-01"
+    a.Timestamp = helper.GetTimestamp()
 }

 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-    return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
+    return meta.BaseURL + "/", nil
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
     adaptor.SetupCommonRequestHeader(c, req, meta)
     req.Header.Set("Authorization", a.Sign)
-    req.Header.Set("X-TC-Action", meta.ActualModelName)
+    req.Header.Set("X-TC-Action", a.Action)
+    req.Header.Set("X-TC-Version", a.Version)
+    req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
     return nil
 }
@@ -40,15 +48,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
     }
     apiKey := c.Request.Header.Get("Authorization")
     apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-    appId, secretId, secretKey, err := ParseConfig(apiKey)
+    _, secretId, secretKey, err := ParseConfig(apiKey)
     if err != nil {
         return nil, err
     }
     tencentRequest := ConvertRequest(*request)
-    tencentRequest.AppId = appId
-    tencentRequest.SecretId = secretId
     // we have to calculate the sign here
-    a.Sign = GetSign(*tencentRequest, secretKey)
+    a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
     return tencentRequest, nil
 }

View File

@@ -1,7 +1,8 @@
 package tencent

 var ModelList = []string{
-    "ChatPro",
-    "ChatStd",
-    "hunyuan",
+    "hunyuan-lite",
+    "hunyuan-standard",
+    "hunyuan-standard-256K",
+    "hunyuan-pro",
 }

View File

@@ -3,11 +3,18 @@ package tencent
import ( import (
"bufio" "bufio"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha256"
"encoding/base64" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@@ -17,36 +24,23 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"sort"
"strconv"
"strings"
) )
// https://cloud.tencent.com/document/product/1729/97732
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages)) messages := make([]*Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ { for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i] message := request.Messages[i]
messages = append(messages, Message{ messages = append(messages, &Message{
Content: message.StringContent(), Content: message.StringContent(),
Role: message.Role, Role: message.Role,
}) })
} }
stream := 0
if request.Stream {
stream = 1
}
return &ChatRequest{ return &ChatRequest{
Timestamp: helper.GetTimestamp(), Model: &request.Model,
Expired: helper.GetTimestamp() + 24*60*60, Stream: &request.Stream,
QueryID: random.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages, Messages: messages,
TopP: &request.TopP,
Temperature: &request.Temperature,
} }
} }
@@ -54,7 +48,11 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Usage: response.Usage, Usage: model.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
} }
if len(response.Choices) > 0 { if len(response.Choices) > 0 {
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
@@ -91,69 +89,52 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string var responseText string
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
var TencentResponse ChatResponse if len(data) < 5 || !strings.HasPrefix(data, "data:") {
err := json.Unmarshal([]byte(data), &TencentResponse) continue
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = strings.TrimPrefix(data, "data:")
var tencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }

 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var TencentResponse ChatResponse
+	var responseP ChatResponseP
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -162,10 +143,11 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &TencentResponse)
+	err = json.Unmarshal(responseBody, &responseP)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+	TencentResponse = responseP.Response
 	if TencentResponse.Error.Code != 0 {
 		return &model.ErrorWithStatusCode{
 			Error: model.Error{
@@ -202,29 +184,62 @@ func ParseConfig(config string) (appId int64, secretId string, secretKey string,
 	return
 }

-func GetSign(req ChatRequest, secretKey string) string {
-	params := make([]string, 0)
-	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
-	params = append(params, "secret_id="+req.SecretId)
-	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
-	params = append(params, "query_id="+req.QueryID)
-	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
-	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
-	params = append(params, "stream="+strconv.Itoa(req.Stream))
-	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
-
-	var messageStr string
-	for _, msg := range req.Messages {
-		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
-	}
-	messageStr = strings.TrimSuffix(messageStr, ",")
-	params = append(params, "messages=["+messageStr+"]")
-
-	sort.Strings(params)
-	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
-	mac := hmac.New(sha1.New, []byte(secretKey))
-	signURL := url
-	mac.Write([]byte(signURL))
-	sign := mac.Sum([]byte(nil))
-	return base64.StdEncoding.EncodeToString(sign)
+func sha256hex(s string) string {
+	b := sha256.Sum256([]byte(s))
+	return hex.EncodeToString(b[:])
+}
+
+func hmacSha256(s, key string) string {
+	hashed := hmac.New(sha256.New, []byte(key))
+	hashed.Write([]byte(s))
+	return string(hashed.Sum(nil))
+}
+
+func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
+	// build canonical request string
+	host := "hunyuan.tencentcloudapi.com"
+	httpRequestMethod := "POST"
+	canonicalURI := "/"
+	canonicalQueryString := ""
+	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
+		"application/json", host, strings.ToLower(adaptor.Action))
+	signedHeaders := "content-type;host;x-tc-action"
+	payload, _ := json.Marshal(req)
+	hashedRequestPayload := sha256hex(string(payload))
+	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+		httpRequestMethod,
+		canonicalURI,
+		canonicalQueryString,
+		canonicalHeaders,
+		signedHeaders,
+		hashedRequestPayload)
+
+	// build string to sign
+	algorithm := "TC3-HMAC-SHA256"
+	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
+	timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
+	t := time.Unix(timestamp, 0).UTC()
+	// must be the format 2006-01-02, ref to package time for more info
+	date := t.Format("2006-01-02")
+	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
+	hashedCanonicalRequest := sha256hex(canonicalRequest)
+	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
+		algorithm,
+		requestTimestamp,
+		credentialScope,
+		hashedCanonicalRequest)
+
+	// sign string
+	secretDate := hmacSha256(date, "TC3"+secKey)
+	secretService := hmacSha256("hunyuan", secretDate)
+	secretKey := hmacSha256("tc3_request", secretService)
+	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
+
+	// build authorization
+	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+		algorithm,
+		secId,
+		credentialScope,
+		signedHeaders,
+		signature)
+	return authorization
 }
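The new GetSign implements Tencent Cloud's TC3-HMAC-SHA256 scheme: hash the canonical request, build a string-to-sign scoped to the date and service, then derive the signing key through a chain of HMACs. A minimal, self-contained sketch of that derivation chain, with placeholder inputs (the secret, payload, and timestamp below are invented for illustration):

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"
)

func hmacSHA256(key []byte, msg string) []byte {
	h := hmac.New(sha256.New, key)
	h.Write([]byte(msg))
	return h.Sum(nil)
}

func sha256Hex(s string) string {
	sum := sha256.Sum256([]byte(s))
	return hex.EncodeToString(sum[:])
}

func main() {
	// Placeholder credentials and request, for illustration only.
	secretKey := "example-secret"
	payload := `{"Model":"hunyuan-lite"}`
	ts := time.Unix(1719700000, 0).UTC()
	date := ts.Format("2006-01-02")

	canonicalHeaders := "content-type:application/json\nhost:hunyuan.tencentcloudapi.com\nx-tc-action:chatcompletions\n"
	canonicalRequest := fmt.Sprintf("POST\n/\n\n%s\n%s\n%s",
		canonicalHeaders, "content-type;host;x-tc-action", sha256Hex(payload))

	stringToSign := fmt.Sprintf("TC3-HMAC-SHA256\n%d\n%s/hunyuan/tc3_request\n%s",
		ts.Unix(), date, sha256Hex(canonicalRequest))

	// Key derivation chain: secret -> date key -> service key -> signing key.
	kDate := hmacSHA256([]byte("TC3"+secretKey), date)
	kService := hmacSHA256(kDate, "hunyuan")
	kSigning := hmacSHA256(kService, "tc3_request")

	fmt.Println("signature:", hex.EncodeToString(hmacSHA256(kSigning, stringToSign)))
}

Note the argument order: the diff's hmacSha256(s, key) takes the message first, so its hmacSha256(date, "TC3"+secKey) computes the same HMAC as hmacSHA256([]byte("TC3"+secretKey), date) here.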


@@ -1,63 +1,75 @@
 package tencent

-import (
-	"github.com/songquanpeng/one-api/relay/model"
-)
-
 type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+	Role    string `json:"Role"`
+	Content string `json:"Content"`
 }

 type ChatRequest struct {
-	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
-	SecretId string `json:"secret_id"` // SecretId from the console
-	// Timestamp: the current UNIX timestamp in seconds, recording when the API request was made.
-	// e.g. 1529223702; a value too far from the current time triggers a signature-expired error.
-	Timestamp int64 `json:"timestamp"`
-	// Expired: validity period of the signature, a UNIX Epoch timestamp in seconds.
-	// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days.
-	Expired int64  `json:"expired"`
-	QueryID string `json:"query_id"` // request Id, used for troubleshooting
-	// Temperature: higher values make the output more random, lower values more focused and deterministic.
-	// Default 1.0, range [0.0, 2.0]; not recommended unless necessary, as unreasonable values hurt quality.
-	// Prefer changing only one of this parameter and top_p, not both.
-	Temperature float64 `json:"temperature"`
-	// TopP: affects output diversity; the larger the value, the more diverse the generated text.
-	// Default 1.0, range [0.0, 1.0]; not recommended unless necessary, as unreasonable values hurt quality.
-	// Prefer changing only one of this parameter and temperature, not both.
-	TopP float64 `json:"top_p"`
-	// Stream: 0 for synchronous, 1 for streaming (SSE protocol by default).
-	// Synchronous requests time out after 60s; use streaming for longer content.
-	Stream int `json:"stream"`
-	// Messages: conversation content, at most 40 entries, ordered from oldest to newest.
-	// Total input content supports at most 3000 tokens.
-	Messages []Message `json:"messages"`
+	// Model name; valid values include hunyuan-lite, hunyuan-standard, hunyuan-standard-256K and hunyuan-pro.
+	// See the product overview (https://cloud.tencent.com/document/product/1729/104753) for a description of each model.
+	//
+	// Note:
+	// the models are billed differently; call them as needed per the purchase guide (https://cloud.tencent.com/document/product/1729/97731).
+	Model *string `json:"Model"`
+	// Chat context.
+	// Notes:
+	// 1. At most 40 entries, ordered from oldest to newest.
+	// 2. Message.Role is one of: system, user, assistant.
+	//    The system role is optional; if present it must come first. user and assistant must alternate (one question, one answer), starting and ending with user, and Content must not be empty. Example Role order: [system (optional) user assistant user assistant user ...].
+	// 3. The total Content length in Messages must not exceed the model's input limit (see the product overview); when exceeded, the earliest content is truncated and only the tail is kept.
+	Messages []*Message `json:"Messages"`
+	// Streaming switch.
+	// Notes:
+	// 1. Defaults to a non-streaming call (false) when unset.
+	// 2. When streaming, results are returned incrementally over SSE (read Choices[n].Delta and concatenate the increments for the full result).
+	// 3. When not streaming:
+	//    the call behaves like an ordinary HTTP request;
+	//    the response takes longer to return, so set this to true if you need lower latency;
+	//    the final result is returned once (read Choices[n].Message).
+	//
+	// Note:
+	// when calling through an SDK, streaming and non-streaming results are read in different ways; see the comments or examples in the SDK (the examples/hunyuan/v20230901/ directory of each language's SDK repo).
+	Stream *bool `json:"Stream"`
+	// Notes:
+	// 1. Affects output diversity; the larger the value, the more diverse the generated text.
+	// 2. Range [0.0, 1.0]; the model's recommended value is used when unset.
+	// 3. Not recommended unless necessary, as unreasonable values hurt quality.
+	TopP *float64 `json:"TopP"`
+	// Notes:
+	// 1. Higher values make the output more random, lower values more focused and deterministic.
+	// 2. Range [0.0, 2.0]; the model's recommended value is used when unset.
+	// 3. Not recommended unless necessary, as unreasonable values hurt quality.
+	Temperature *float64 `json:"Temperature"`
 }

 type Error struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
+	Code    int    `json:"Code"`
+	Message string `json:"Message"`
 }

 type Usage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
+	PromptTokens     int `json:"PromptTokens"`
+	CompletionTokens int `json:"CompletionTokens"`
+	TotalTokens      int `json:"TotalTokens"`
 }

 type ResponseChoices struct {
-	FinishReason string  `json:"finish_reason,omitempty"` // streaming stop flag; "stop" marks the final packet
-	Messages     Message `json:"messages,omitempty"`      // content returned in synchronous mode, null when streaming; output supports at most 1024 tokens
-	Delta        Message `json:"delta,omitempty"`         // content returned when streaming, null in synchronous mode; output supports at most 1024 tokens
+	FinishReason string  `json:"FinishReason,omitempty"` // streaming stop flag; "stop" marks the final packet
+	Messages     Message `json:"Message,omitempty"`      // content returned in synchronous mode, null when streaming; output supports at most 1024 tokens
+	Delta        Message `json:"Delta,omitempty"`        // content returned when streaming, null in synchronous mode; output supports at most 1024 tokens
 }

 type ChatResponse struct {
-	Choices []ResponseChoices `json:"choices,omitempty"` // results
-	Created string            `json:"created,omitempty"` // UNIX timestamp string
-	Id      string            `json:"id,omitempty"`      // conversation id
-	Usage   model.Usage       `json:"usage,omitempty"`   // token counts
-	Error   Error             `json:"error,omitempty"`   // error message; note: may be null, meaning no valid value was available
-	Note    string            `json:"note,omitempty"`    // note
-	ReqID   string            `json:"req_id,omitempty"`  // unique request Id, returned with every request; quote it when reporting issues
+	Choices []ResponseChoices `json:"Choices,omitempty"` // results
+	Created int64             `json:"Created,omitempty"` // UNIX timestamp
+	Id      string            `json:"Id,omitempty"`      // conversation id
+	Usage   Usage             `json:"Usage,omitempty"`   // token counts
+	Error   Error             `json:"Error,omitempty"`   // error message; note: may be null, meaning no valid value was available
+	Note    string            `json:"Note,omitempty"`    // note
+	ReqID   string            `json:"Req_id,omitempty"`  // unique request Id, returned with every request; quote it when reporting issues
+}
+
+type ChatResponseP struct {
+	Response ChatResponse `json:"Response,omitempty"`
 }
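All tunable fields in the new request shape are pointers, so an unset value can be told apart from an explicit zero; and since the tags carry no omitempty, nil fields serialize as null and Hunyuan falls back to its per-model defaults. A small sketch of that behavior (the types are local stand-ins mirroring the structs above):

package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-ins mirroring the tencent package types above.
type Message struct {
	Role    string `json:"Role"`
	Content string `json:"Content"`
}

type ChatRequest struct {
	Model       *string    `json:"Model"`
	Messages    []*Message `json:"Messages"`
	Stream      *bool      `json:"Stream"`
	TopP        *float64   `json:"TopP"`
	Temperature *float64   `json:"Temperature"`
}

func main() {
	model := "hunyuan-lite"
	stream := true
	req := ChatRequest{
		Model:    &model,
		Messages: []*Message{{Role: "user", Content: "hello"}},
		Stream:   &stream,
		// TopP and Temperature stay nil: they marshal as null and the API
		// applies each model's recommended value.
	}
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
	// {"Model":"hunyuan-lite","Messages":[{"Role":"user","Content":"hello"}],"Stream":true,"TopP":null,"Temperature":null}
}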


@@ -0,0 +1,10 @@
+package togetherai
+
+// https://docs.together.ai/docs/inference-models
+var ModelList = []string{
+	"meta-llama/Llama-3-70b-chat-hf",
+	"deepseek-ai/deepseek-coder-33b-instruct",
+	"mistralai/Mixtral-8x22B-Instruct-v0.1",
+	"Qwen/Qwen1.5-72B-Chat",
+}


@@ -14,10 +14,11 @@ import (
 type Adaptor struct {
 	request *model.GeneralOpenAIRequest
+	meta    *meta.Meta
 }

 func (a *Adaptor) Init(meta *meta.Meta) {
+	a.meta = meta
 }

 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
@@ -60,10 +61,18 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 	if a.request == nil {
 		return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
 	}
+	version := parseAPIVersionByModelName(meta.ActualModelName)
+	if version == "" {
+		version = a.meta.Config.APIVersion
+	}
+	if version == "" {
+		version = "v1.1"
+	}
+	a.meta.Config.APIVersion = version
 	if meta.IsStream {
-		err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+		err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2])
 	} else {
-		err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
+		err, usage = Handler(c, meta, *a.request, splits[0], splits[1], splits[2])
 	}
 	return
 }
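DoResponse now establishes a three-step precedence for the Spark API version: a version suffix on the model name wins, then the channel's configured APIVersion, then the v1.1 default. A condensed sketch of that resolution (resolveSparkVersion is a hypothetical name, not part of the adaptor):

package main

import (
	"fmt"
	"strings"
)

// resolveSparkVersion is a hypothetical helper mirroring the precedence in
// DoResponse: model-name suffix, then channel config, then the v1.1 default.
func resolveSparkVersion(actualModelName, configuredVersion string) string {
	version := ""
	if parts := strings.Split(actualModelName, "-"); len(parts) == 2 {
		version = parts[1] // "SparkDesk-v4.0" -> "v4.0"
	}
	if version == "" {
		version = configuredVersion
	}
	if version == "" {
		version = "v1.1"
	}
	return version
}

func main() {
	fmt.Println(resolveSparkVersion("SparkDesk-v4.0", "v3.1")) // v4.0: model name wins
	fmt.Println(resolveSparkVersion("SparkDesk", "v3.1"))      // v3.1: falls back to channel config
	fmt.Println(resolveSparkVersion("SparkDesk", ""))          // v1.1: final default
}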


@@ -6,4 +6,5 @@ var ModelList = []string{
 	"SparkDesk-v2.1",
 	"SparkDesk-v3.1",
 	"SparkDesk-v3.5",
+	"SparkDesk-v4.0",
 }


@@ -5,22 +5,24 @@ import (
 	"crypto/sha256"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
-	"github.com/songquanpeng/one-api/common"
-	"github.com/songquanpeng/one-api/common/ctxkey"
-	"github.com/songquanpeng/one-api/common/helper"
-	"github.com/songquanpeng/one-api/common/logger"
-	"github.com/songquanpeng/one-api/common/random"
-	"github.com/songquanpeng/one-api/relay/adaptor/openai"
-	"github.com/songquanpeng/one-api/relay/constant"
-	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"net/url"
 	"strings"
 	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/random"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
 )

 // https://console.xfyun.cn/services/cbm
@@ -28,11 +30,7 @@ import (
 func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
 	messages := make([]Message, 0, len(request.Messages))
-	var lastToolCalls []model.Tool
 	for _, message := range request.Messages {
-		if message.ToolCalls != nil {
-			lastToolCalls = message.ToolCalls
-		}
 		messages = append(messages, Message{
 			Role:    message.Role,
 			Content: message.StringContent(),
@@ -45,9 +43,14 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
 	xunfeiRequest.Parameter.Chat.TopK = request.N
 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 	xunfeiRequest.Payload.Message.Text = messages
-	if len(lastToolCalls) != 0 {
-		for _, toolCall := range lastToolCalls {
-			xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
-		}
-	}
+
+	if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
+		functions := make([]model.Function, len(request.Tools))
+		for i, tool := range request.Tools {
+			functions[i] = tool.Function
+		}
+		xunfeiRequest.Payload.Functions = &Functions{
+			Text: functions,
+		}
+	}
@@ -149,8 +152,8 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 	return callUrl
 }

-func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
-	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
+func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
+	domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
 		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
@@ -179,8 +182,8 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId
 	return nil, &usage
 }

-func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
-	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
+func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
+	domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
 		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
@@ -203,7 +206,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin
 		}
 	}
 	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-		return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil
+		return openai.ErrorWrapper(errors.New("xunfei empty response detected"), "xunfei_empty_response_detected", http.StatusInternalServerError), nil
 	}
 	xunfeiResponse.Payload.Choices.Text[0].Content = content
@@ -268,25 +271,12 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
 	return dataChan, stopChan, nil
 }

-func getAPIVersion(c *gin.Context, modelName string) string {
-	query := c.Request.URL.Query()
-	apiVersion := query.Get("api-version")
-	if apiVersion != "" {
-		return apiVersion
-	}
+func parseAPIVersionByModelName(modelName string) string {
 	parts := strings.Split(modelName, "-")
 	if len(parts) == 2 {
-		apiVersion = parts[1]
-		return apiVersion
+		return parts[1]
 	}
-	apiVersion = c.GetString(ctxkey.ConfigAPIVersion)
-	if apiVersion != "" {
-		return apiVersion
-	}
-	apiVersion = "v1.1"
-	logger.SysLog("api_version not found, using default: " + apiVersion)
-	return apiVersion
+	return ""
 }

 // https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
@@ -300,12 +290,13 @@ func apiVersion2domain(apiVersion string) string {
 		return "generalv3"
 	case "v3.5":
 		return "generalv3.5"
+	case "v4.0":
+		return "4.0Ultra"
 	}
 	return "general" + apiVersion
 }

-func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
-	apiVersion := getAPIVersion(c, modelName)
+func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
 	domain := apiVersion2domain(apiVersion)
 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 	return domain, authUrl
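With the domain resolved from the API version, requestOpenAI2Xunfei only attaches a functions payload where Spark supports it: domains starting with generalv3, plus the new 4.0Ultra domain. The gate, isolated for clarity (supportsFunctions is an illustrative name; the adaptor inlines this condition):

package main

import (
	"fmt"
	"strings"
)

// supportsFunctions isolates the condition used in requestOpenAI2Xunfei above.
func supportsFunctions(domain string) bool {
	return strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra"
}

func main() {
	for _, d := range []string{"general", "generalv2", "generalv3", "generalv3.5", "4.0Ultra"} {
		fmt.Printf("%-12s -> %v\n", d, supportsFunctions(d))
	}
	// Only generalv3, generalv3.5 and 4.0Ultra receive a functions payload.
}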


@@ -9,6 +9,10 @@ type Message struct {
 	Content string `json:"content"`
 }

+type Functions struct {
+	Text []model.Function `json:"text,omitempty"`
+}
+
 type ChatRequest struct {
 	Header struct {
 		AppId string `json:"app_id"`
@@ -26,9 +30,7 @@ type ChatRequest struct {
 		Message struct {
 			Text []Message `json:"text"`
 		} `json:"message"`
-		Functions struct {
-			Text []model.Function `json:"text,omitempty"`
-		} `json:"functions,omitempty"`
+		Functions *Functions `json:"functions,omitempty"`
 	} `json:"payload"`
 }
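Extracting Functions into a named type and making the field a pointer matters for serialization: omitempty never fires for a non-pointer struct value, so the old shape always emitted "functions":{} even when no tools were supplied, which can confuse the upstream API. A minimal demonstration of the difference (with []string standing in for []model.Function to keep it self-contained):

package main

import (
	"encoding/json"
	"fmt"
)

type Functions struct {
	Text []string `json:"text,omitempty"`
}

type withStruct struct {
	Functions Functions `json:"functions,omitempty"` // omitempty is a no-op for struct values
}

type withPointer struct {
	Functions *Functions `json:"functions,omitempty"` // a nil pointer really is omitted
}

func main() {
	a, _ := json.Marshal(withStruct{})
	b, _ := json.Marshal(withPointer{})
	fmt.Println(string(a)) // {"functions":{}}
	fmt.Println(string(b)) // {}
}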


@@ -62,8 +62,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 	switch relayMode {
 	case relaymode.Embeddings:
-		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
-		return baiduEmbeddingRequest, nil
+		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
+		return baiduEmbeddingRequest, err
 	default:
 		// TopP (0.0, 1.0)
 		request.TopP = math.Min(0.99, request.TopP)
@@ -129,11 +129,15 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 	return
 }

-func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
-	return &EmbeddingRequest{
-		Model: "embedding-2",
-		Input: request.Input.(string),
-	}
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) (*EmbeddingRequest, error) {
+	inputs := request.ParseInput()
+	if len(inputs) != 1 {
+		return nil, errors.New("invalid input length, zhipu only support one input")
+	}
+	return &EmbeddingRequest{
+		Model: request.Model,
+		Input: inputs[0],
+	}, nil
 }

 func (a *Adaptor) GetModelList() []string {
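ConvertEmbeddingRequest now goes through request.ParseInput() and passes the requested model through instead of hardcoding embedding-2, and it rejects multi-input requests up front rather than panicking on a failed type assertion. A stand-in with ParseInput's usual semantics (a bare string becomes a one-element slice, a string array is copied; assumed behavior, not the helper's actual source):

package main

import "fmt"

// parseInput is a stand-in for request.ParseInput.
func parseInput(input any) []string {
	switch v := input.(type) {
	case string:
		return []string{v}
	case []any:
		out := make([]string, 0, len(v))
		for _, item := range v {
			if s, ok := item.(string); ok {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}

func main() {
	fmt.Println(parseInput("hello"))         // [hello] -> length 1, accepted
	fmt.Println(parseInput([]any{"a", "b"})) // [a b]   -> length 2, rejected by the zhipu check above
}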


@@ -3,6 +3,13 @@ package zhipu
 import (
 	"bufio"
 	"encoding/json"
+	"github.com/songquanpeng/one-api/common/render"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
 	"github.com/gin-gonic/gin"
 	"github.com/golang-jwt/jwt"
 	"github.com/songquanpeng/one-api/common"
@@ -11,11 +18,6 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
-	"io"
-	"net/http"
-	"strings"
-	"sync"
-	"time"
 )

 // https://open.bigmodel.cn/doc/api#chatglm_std
@@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 		}
 		return 0, nil, nil
 	})
-	dataChan := make(chan string)
-	metaChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			lines := strings.Split(data, "\n")
-			for i, line := range lines {
-				if len(line) < 5 {
-					continue
-				}
-				if line[:5] == "data:" {
-					dataChan <- line[5:]
-					if i != len(lines)-1 {
-						dataChan <- "\n"
-					}
-				} else if line[:5] == "meta:" {
-					metaChan <- line[5:]
-				}
-			}
-		}
-		stopChan <- true
-	}()
+
 	common.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			response := streamResponseZhipu2OpenAI(data)
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				logger.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case data := <-metaChan:
-			var zhipuResponse StreamMetaResponse
-			err := json.Unmarshal([]byte(data), &zhipuResponse)
-			if err != nil {
-				logger.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				logger.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			usage = zhipuUsage
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		lines := strings.Split(data, "\n")
+		for i, line := range lines {
+			if len(line) < 5 {
+				continue
+			}
+			if strings.HasPrefix(line, "data:") {
+				dataSegment := line[5:]
+				if i != len(lines)-1 {
+					dataSegment += "\n"
+				}
+				response := streamResponseZhipu2OpenAI(dataSegment)
+				err := render.ObjectData(c, response)
+				if err != nil {
+					logger.SysError("error marshalling stream response: " + err.Error())
+				}
+			} else if strings.HasPrefix(line, "meta:") {
+				metaSegment := line[5:]
+				var zhipuResponse StreamMetaResponse
+				err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
+				if err != nil {
+					logger.SysError("error unmarshalling stream response: " + err.Error())
+					continue
+				}
+				response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
+				err = render.ObjectData(c, response)
+				if err != nil {
+					logger.SysError("error marshalling stream response: " + err.Error())
+				}
+				usage = zhipuUsage
+			}
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		logger.SysError("error reading stream: " + err.Error())
+	}
+
+	render.Done(c)
+
 	err := resp.Body.Close()
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, usage
 }
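The refactor drops the goroutine-plus-channels relay (part of the broader goroutine-abuse cleanup in this compare) and writes SSE frames inline as the scanner produces them, which also removes the path that could emit a duplicate [DONE]. The render helpers are assumed to behave roughly as follows (a sketch of the expected behavior, not the package's actual source):

package render

import (
	"encoding/json"
	"fmt"

	"github.com/gin-gonic/gin"
)

// ObjectData marshals v, frames it as an SSE "data:" event, and flushes,
// so each chunk reaches the client immediately.
func ObjectData(c *gin.Context, v any) error {
	b, err := json.Marshal(v)
	if err != nil {
		return fmt.Errorf("error marshalling object: %w", err)
	}
	c.Writer.Write([]byte("data: " + string(b) + "\n\n"))
	c.Writer.Flush()
	return nil
}

// Done emits the terminating [DONE] event exactly once, at the call site.
func Done(c *gin.Context) {
	c.Writer.Write([]byte("data: [DONE]\n\n"))
	c.Writer.Flush()
}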

relay/adaptor_test.go (new file)

@@ -0,0 +1,16 @@
+package relay
+
+import (
+	. "github.com/smartystreets/goconvey/convey"
+	"github.com/songquanpeng/one-api/relay/apitype"
+	"testing"
+)
+
+func TestGetAdaptor(t *testing.T) {
+	Convey("get adaptor", t, func() {
+		for i := 0; i < apitype.Dummy; i++ {
+			a := GetAdaptor(i)
+			So(a, ShouldNotBeNil)
+		}
+	})
+}


@@ -15,6 +15,8 @@ const (
 	AwsClaude
 	Coze
 	Cohere
+	Cloudflare
+	DeepL
 	Dummy // this one is only for count, do not add any channel after this
 )
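The trailing Dummy sentinel is what makes the new adaptor_test.go above work: because the constants are consecutive iota values, looping up to Dummy visits every registered API type, so inserting a constant such as Cloudflare or DeepL without wiring up its adaptor fails the test. A reduced sketch of the pattern (only a few of the real constants are shown):

package main

import "fmt"

// Consecutive iota constants with a trailing count sentinel.
const (
	Coze = iota
	Cohere
	Cloudflare
	DeepL
	Dummy // count sentinel: keep it last, never add channels after it
)

func main() {
	for i := 0; i < Dummy; i++ {
		fmt.Println("api type", i) // a full test would assert GetAdaptor(i) != nil
	}
}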

Some files were not shown because too many files have changed in this diff.