Compare commits

...

133 Commits

Author SHA1 Message Date
JustSong
2dcef85285 feat: support ollama now (close #870) 2024-03-14 01:02:47 +08:00
JustSong
79d0cd378a fix: fix baidu system prompt (close #1079) 2024-03-13 22:56:54 +08:00
JustSong
e99150bdb9 fix: make quota int64 2024-03-13 20:00:51 +08:00
JustSong
a72e5fcc9e fix: when cached quota is too low, force refresh it 2024-03-13 19:38:44 +08:00
JustSong
0710f8cd66 fix: when cached quota is too low, force refresh it 2024-03-13 19:26:24 +08:00
JustSong
49cad7d4a5 feat: update func ShouldDisableChannel for claude 2024-03-13 19:11:30 +08:00
JustSong
a90161cf00 chore: drop idx_channels_key on start 2024-03-11 02:24:58 +08:00
sparanoid
a45fc7d736 fix: model name typo (#1109) 2024-03-11 00:44:49 +08:00
JustSong
45940dcb12 chore: add more info for panic fix 2024-03-10 23:59:35 +08:00
JustSong
969042b001 chore: only use one log file (close #1116) 2024-03-10 23:44:48 +08:00
JustSong
7e7369dbc4 fix: only disable channel when allowed 2024-03-10 23:41:16 +08:00
JustSong
e54e647170 chore: remove useless code 2024-03-10 23:36:29 +08:00
JustSong
358920c858 fix: remove index idx_channels_key (close #644) 2024-03-10 23:27:22 +08:00
JustSong
1ea598c773 feat: check claude's error response 2024-03-10 20:39:55 +08:00
JustSong
796be42487 feat: update ratio config if missing 2024-03-10 19:29:42 +08:00
JustSong
5b50eb94e5 feat: able to send alert message via message pusher (close #993) 2024-03-10 19:16:06 +08:00
JustSong
71c61365eb feat: able to only test disabled channels (#1090) 2024-03-10 18:34:57 +08:00
JustSong
b09f979b80 fix: add missing turnstile setup (close #1015) 2024-03-10 18:15:24 +08:00
JustSong
12440874b0 feat: able to disable channel by success rate 2024-03-10 17:57:47 +08:00
JustSong
6ebc99460e fix: add user to blacklist when it's banned or deleted, and make deletion soft (close #473, close #791) 2024-03-10 15:56:19 +08:00
JustSong
27ad8bfb98 feat: able to search channel type now 2024-03-10 15:00:33 +08:00
JustSong
8388aa537f chore: able to search channel now 2024-03-10 14:59:57 +08:00
JustSong
2346bf70af fix: check response type when expect stream response 2024-03-10 14:59:40 +08:00
JustSong
f05b403ca5 feat: use real system prompt now (close #1079) 2024-03-10 14:32:30 +08:00
JustSong
b33616df44 feat: support groq now (close #1087) 2024-03-10 14:09:44 +08:00
JustSong
cf16f44970 feat: load channel models from server 2024-03-09 02:28:23 +08:00
JustSong
bf2e26a48f feat: support claude-3 (close #1080, close #1094) 2024-03-09 01:12:47 +08:00
momomobinx
4fb22ad4ce feat: support third part models of baidu (#1046)
百度千帆平台上的第三方大模型调用
2024-03-03 23:50:28 +08:00
JustSong
95cfb8e8c9 fix: using the first available model if default model is not found (close #1021) 2024-03-03 22:58:41 +08:00
JustSong
c6ace985c2 fix: set missing ali parameters (close #1028) 2024-03-03 22:51:01 +08:00
JustSong
10a926b8f3 feat: only use the top priority when first retry (#1048) 2024-03-03 22:16:34 +08:00
JustSong
2df877a352 feat: switch priority when retry (close #1048) 2024-03-03 22:14:07 +08:00
JustSong
9d8967f7d3 feat: support Mistral's models now (close #1051) 2024-03-03 21:46:45 +08:00
JustSong
b35f3523d3 feat: add gemini model alias (close #1064) 2024-03-03 21:03:04 +08:00
JustSong
82e916b5ff fix: fix azure test (close #1069) 2024-03-03 20:51:28 +08:00
JustSong
de18d6fe16 refactor: refactor image relay (close #1068) 2024-03-03 19:30:11 +08:00
JustSong
1d0b7fb5ae feat: support chatglm-4 (close #1045, close #952, close #952, close #943) 2024-03-02 03:05:25 +08:00
JustSong
f9490bb72e fix: able to use updated default ratio 2024-03-02 01:32:04 +08:00
JustSong
76467285e8 docs: update readme 2024-03-02 01:25:21 +08:00
JustSong
df1fd9aa81 feat: support minimax's models now (close #354) 2024-03-02 01:24:28 +08:00
JustSong
614c2e0442 feat: support baichuan's models now (close #1057) 2024-03-02 00:55:48 +08:00
JustSong
eac6a0b9aa fix: fix version is blank 2024-03-02 00:03:29 +08:00
JustSong
b747cdbc6f fix: fix getAndValidateTextRequest failed: unexpected end of JSON input (close #1043) 2024-02-26 22:52:16 +08:00
JustSong
6b27d6659a fix: add role for ChatCompletionsStreamResponseChoice.Delta 2024-02-25 19:49:22 +08:00
JustSong
dc5b781191 fix: fix stream response id 2024-02-25 19:47:59 +08:00
JustSong
c880b4a9a3 fix: fix missing index in ChatCompletionsStreamResponseChoice (#1037) 2024-02-25 19:17:37 +08:00
JustSong
565ea58e68 feat: built in retry supported (close #1036, close #770) 2024-02-25 19:01:49 +08:00
JustSong
f141a37a9e fix: fix "error update user quota cache: Error 1040: Too many connections" 2024-02-25 16:58:14 +08:00
JustSong
5b78886ad3 fix: fix i18n 2024-02-25 16:53:46 +08:00
JustSong
87c7c4f0e6 fix: rm history build before building 2024-02-25 02:07:34 +08:00
JustSong
4c4a873890 fix: add an ending line for THEMES 2024-02-25 01:59:40 +08:00
JustSong
0664bdfda1 fix: fix build.sh (close #1026) 2024-02-25 01:53:27 +08:00
JustSong
32387d9c20 fix: fix version is blank 2024-02-21 22:21:01 +08:00
JustSong
bd888f2eb7 fix: fix prompt token is zero (close #1023) 2024-02-21 22:19:42 +08:00
JustSong
cece77e533 fix: fix model list 2024-02-19 22:20:18 +08:00
JustSong
2a5468e23c refactor: remove useless button (close #1014) 2024-02-18 22:21:37 +08:00
JustSong
d0e415893b fix: fix SparkDesk model name 2024-02-18 17:16:11 +08:00
JustSong
6cf5ce9a7a fix: fix SparkDesk model name 2024-02-18 17:11:16 +08:00
JustSong
f598b9df87 feat: add new SparkDesk models 2024-02-18 17:02:36 +08:00
JustSong
532c50d212 fix: fix channel table page copy 2024-02-18 16:19:14 +08:00
JustSong
2acc2f5017 feat: support moonshot now (close #804) 2024-02-18 16:17:19 +08:00
JustSong
604ac56305 fix: set seed parameter for qwen (close #1005) 2024-02-18 15:01:09 +08:00
JustSong
9383b638a6 feat: add ChatPro & ChatStd for tencent (#1010) 2024-02-18 14:40:01 +08:00
JustSong
28d512a675 refactor: delete useless code 2024-02-18 02:23:31 +08:00
JustSong
de9a58ca0b refactor: use config field to save config 2024-02-18 02:22:50 +08:00
JustSong
1aa374ccfb refactor: use adaptor to do relay & test 2024-02-18 00:15:31 +08:00
Laisky.Cai
d548a01c59 feat: Handle errors, validate model names, and calculate quota usage (#978)
- Improved error handling in various modules for better stability and responsiveness.
- Optimized code in several files for improved efficiency and readability.
- Enhanced user experience by providing more detailed error responses in the controller.
- Strengthened security by ignoring sensitive files in `.gitignore`.
2024-02-12 21:35:40 +08:00
JustSong
2cd1a78203 chore: update module name 2024-01-28 19:38:58 +08:00
JustSong
b9d3cb0c45 refactor: split RelayTextHelper function 2024-01-28 19:14:46 +08:00
JustSong
ea407f0054 feat: able to set completion ration now (close #968) 2024-01-28 16:45:54 +08:00
Benny
26e2e646cb feat: sync models with OpenAI (#971)
* add new 0125 chat models and embedding-3 models

* refine the step of manually deploying

* add gpt-4-turbo-preview
2024-01-28 16:09:21 +08:00
yongman
4f214c48c6 fix: fix primary chat button (#951)
Signed-off-by: yongman <yming0221@gmail.com>
2024-01-21 23:27:34 +08:00
JustSong
2d760d4a01 refactor: refactor relay part (#957)
* refactor: refactor relay part

* refactor: refactor config part
2024-01-21 23:21:42 +08:00
Buer
e2ed0399f0 fix: fix aff not effective (#937) 2024-01-20 12:38:06 +08:00
JustSong
eed9f5fdf0 refactor: refactor relay part (#935) 2024-01-14 19:21:03 +08:00
JustSong
f2c51a494c feat: able to login via email (close #921) 2024-01-14 14:08:39 +08:00
Calcium-Ion
8a4d6f3327 fix: remove useless wrong index (#916)
* add sqlite busy_timeout=3000

* chore: update impl

* fix: fix JSON tag in Log struct

* fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费而导致令牌大额欠费

* Revert "fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费而导致令牌大额欠费"

This reverts commit f0ffe14437.

* fix: remove wrong index

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-14 13:48:16 +08:00
Buer
cf4e33cb12 fix: fix bugs with theme berry (#931)
* fix: home page & logo style issue

* improve: Enhanced user experience by improving the channel selection box

* fix: key cannot be activated after expiration
2024-01-14 13:22:31 +08:00
JustSong
5d60305570 docs: update readme (close #930) 2024-01-14 13:00:59 +08:00
JustSong
d062bc60e4 chore: update ui copy 2024-01-07 19:18:52 +08:00
JustSong
39c1882970 chore: add back THEMES 2024-01-07 19:01:31 +08:00
JustSong
9c42c7dfd9 docs: update theme readme 2024-01-07 18:59:26 +08:00
JustSong
903aaeded0 chore: revert change 2024-01-07 18:47:39 +08:00
JustSong
bdd4be562d chore: add theme validation 2024-01-07 18:44:26 +08:00
Buer
37afb313b5 fix: fix some issues with berry (#913)
* fix: login address error

* fix: Normal users display profile menu

* fix: remove redundant code
2024-01-07 18:39:15 +08:00
JustSong
c9ebcab8b8 fix: fix theme logging 2024-01-07 18:02:59 +08:00
JustSong
86261cc656 feat: able to change theme 2024-01-07 17:53:05 +08:00
JustSong
8491785c9d chore: remove useless logging 2024-01-07 17:13:24 +08:00
JustSong
e848a3f7fa fix: fix error loading stats (#912) 2024-01-07 17:10:59 +08:00
JustSong
318adf5985 chore: remove useless lines 2024-01-07 16:28:44 +08:00
JustSong
965d7fc3d2 fix: fix Dockerfile 2024-01-07 15:33:33 +08:00
JustSong
aa3f605894 fix: fix Dockerfile 2024-01-07 15:17:21 +08:00
JustSong
7b8eff1f22 fix: fix Dockerfile 2024-01-07 15:15:02 +08:00
JustSong
e80cd508ba fix: fix Dockerfile 2024-01-07 15:12:30 +08:00
JustSong
d37f836d53 fix: fix Dockerfile 2024-01-07 15:03:36 +08:00
JustSong
e0b2d1ae47 fix: fix Dockerfile 2024-01-07 14:55:04 +08:00
JustSong
797ead686b fix: fix Dockerfile 2024-01-07 14:42:23 +08:00
JustSong
0d22cf9ead fix: fix Dockerfile 2024-01-07 14:38:02 +08:00
Buer
48989d4a0b feat: add new theme berry (#860)
* feat: add theme berry

* docs: add development notes

* fix: fix blank page

* chore: update implementation

* fix: fix package.json

* chore: update ui copy

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-07 14:20:07 +08:00
Seven Yu
6227eee5bc fix: fix token validation exception handling #901
* fix: fix exception handling

1. add error log for ValidateUserToken
2. update en.json

* chore: update log

---------

Co-authored-by: seven.yu <seven.yu@dji.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-07 13:32:39 +08:00
JustSong
cbf8f07747 docs: fix logo 2024-01-01 21:19:37 +08:00
JustSong
4a96031ce6 docs: update readme 2024-01-01 21:14:45 +08:00
JustSong
92886093ae docs: update readme 2024-01-01 21:10:40 +08:00
JustSong
0c022f17cb chore: update theme related code 2024-01-01 20:25:53 +08:00
JustSong
83f95935de ci: fix Dockerfile & ci 2024-01-01 19:23:46 +08:00
JustSong
aa03c89133 feat: able to add more UI theme (#860) 2024-01-01 18:55:03 +08:00
JustSong
505817ca17 chore: update en.json 2024-01-01 17:46:45 +08:00
JustSong
cb5a3df616 fix: fix pr error (#888) 2024-01-01 17:40:10 +08:00
Laisky.Cai
7772064d87 fix: support base64 encoded image_url (#872)
- Add support for base64 encoded image in OpenAI's image_url

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-01-01 17:38:35 +08:00
Seven Yu
c50c609565 fix: fix button copywriting (#880)
* feat: rename Channel button

* fix: update en.json

---------

Co-authored-by: seven.yu <seven.yu@dji.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-01 17:09:12 +08:00
Tailen
498dea2dbb feat: add support for davinci-002 and babbage-002 (#888) 2024-01-01 17:06:17 +08:00
Zhanliang Liu
c725cc8842 fix: base 64 encoded format support of gemini-pro-vision for field image_url/url (#878) 2024-01-01 17:00:23 +08:00
Tisfeng
af8908db54 feat: able to change gemini safety setting (#867)
* perf: adjust gemini safety settings, set BLOCK_NONE by default

* feat: able to adjust by env variable

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-01 16:42:19 +08:00
JustSong
d8029550f7 fix: do not consume user quota if failed (close #881) 2024-01-01 16:18:50 +08:00
JustSong
f44fbe3fe7 docs: update pr template 2023-12-24 19:24:59 +08:00
JustSong
1c8922153d feat: support gemini-vision-pro 2023-12-24 18:54:32 +08:00
Laisky.Cai
f3c07e1451 fix: openai response should contains model (#841)
* fix: openai response should contains `model`

- Update model attributes in `claudeHandler` for `relay-claude.go`
- Implement model type for fullTextResponse in `relay-gemini.go`
- Add new `Model` field to `OpenAITextResponse` struct in `relay.go`

* chore: set model name response for models

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:58:31 +08:00
Bryan
40ceb29e54 fix: fix SearchUsers not working if using PostgreSQL (#778)
* fix SearchUsers

* refactor: using UsingPostgreSQL as condition

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:42:00 +08:00
dependabot[bot]
0699ecd0af chore(deps): bump golang.org/x/crypto from 0.14.0 to 0.17.0 (#840)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.14.0 to 0.17.0.
- [Commits](https://github.com/golang/crypto/compare/v0.14.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-12-24 16:29:48 +08:00
moondie
ee9e746520 feat: update ali stream implementation & enable internet search (#856)
* Update relay-ali.go: 改进stream模式,添加联网搜索能力

通义千问支持stream的增量模式,不需要每次去掉上次的前缀;实测qwen-max联网模式效果不错,添加了联网模式。如果别的模型有问题可以改为单独给qwen-max开放

* 删除"stream参数"

刚发现原来阿里api没有这个参数,上次误加了。

* refactor: only enable search when specified

* fix: remove custom suffix when get model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:17:21 +08:00
Buer
a763681c2e fix: fix base64 image parse error (#858) 2023-12-24 15:35:56 +08:00
JustSong
b7fcb319da chore: check if SESSION_SECRET equals to random_string 2023-12-20 22:50:50 +08:00
JustSong
67c64e71c8 fix: fix max_tokens check 2023-12-20 21:45:33 +08:00
JustSong
97030e27f8 fix: fix gemini panic (close #833) 2023-12-17 23:30:45 +08:00
JustSong
461f5dab56 docs: update readme 2023-12-17 22:25:03 +08:00
JustSong
af378c59af docs: update readme 2023-12-17 22:19:16 +08:00
ShinChven ✨
bc6769826b feat: add condition to validate n value for non-Azure channels (#775)
- Add a condition to validate the n value only for non-Azure channels, ensuring it falls within the acceptable range.
- Fix Azure compatibility
2023-12-17 19:49:08 +08:00
Oliver Lee
0fe26cc4bd feat: update ali relay implementation (#830)
* 修改通译千问最新接口:1.删除history参数,改用官方推荐的messages参数 2.整理messages参数顺序,补充必要上下文信息 3.用autogen调试测试通过

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 19:43:23 +08:00
Calcium-Ion
7d6a169669 feat: able to set sqlite busy_timeout (#818)
* add sqlite busy_timeout=3000

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 19:17:00 +08:00
Ghostz
66f06e5d6f feat: reset image num to 1 when not given (#821)
* Update relay-image.go

* fix: reset image num to 1 when not given

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 18:54:08 +08:00
JustSong
6acb9537a9 fix: try to return a more meaningful error message (close #817) 2023-12-17 18:33:27 +08:00
JustSong
7069c49bdf fix: fix xunfei panic error (close #820) 2023-12-17 18:06:37 +08:00
JustSong
58dee76bf7 fix: fix Gemini stream problem 2023-12-17 16:16:18 +08:00
372 changed files with 21233 additions and 4796 deletions

View File

@@ -7,6 +7,11 @@ on:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
release:
runs-on: ubuntu-latest
@@ -23,8 +28,8 @@ jobs:
CI: ""
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
git describe --tags > VERSION
REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh
cd ..
- name: Set up Go
uses: actions/setup-go@v3
@@ -33,7 +38,7 @@ jobs:
- name: Build Backend (amd64)
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
- name: Build Backend (arm64)
run: |

View File

@@ -7,6 +7,11 @@ on:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
release:
runs-on: macos-latest
@@ -23,8 +28,8 @@ jobs:
CI: ""
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
git describe --tags > VERSION
REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh
cd ..
- name: Set up Go
uses: actions/setup-go@v3
@@ -33,7 +38,7 @@ jobs:
- name: Build Backend
run: |
go mod download
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')

View File

@@ -7,6 +7,11 @@ on:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
release:
runs-on: windows-latest
@@ -25,10 +30,10 @@ jobs:
env:
CI: ""
run: |
cd web
cd web/default
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
cd ..
cd ../..
- name: Set up Go
uses: actions/setup-go@v3
with:
@@ -36,7 +41,7 @@ jobs:
- name: Build Backend
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')

3
.gitignore vendored
View File

@@ -6,4 +6,5 @@ upload
build
*.db-journal
logs
data
data
/web/node_modules

View File

@@ -1,10 +1,15 @@
FROM node:16 as builder
WORKDIR /build
COPY web/package.json .
RUN npm install
COPY ./web .
WORKDIR /web
COPY ./VERSION .
COPY ./web .
WORKDIR /web/default
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
WORKDIR /web/berry
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2
@@ -17,8 +22,8 @@ WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
@@ -30,4 +35,4 @@ RUN apk update \
COPY --from=builder2 /build/one-api /
EXPOSE 3000
WORKDIR /data
ENTRYPOINT ["/one-api"]
ENTRYPOINT ["/one-api"]

View File

@@ -3,7 +3,7 @@
</p>
<p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p>
<div align="center">
@@ -134,12 +134,12 @@ The initial account username is `root` and password is `123456`.
git clone https://github.com/songquanpeng/one-api.git
# Build the frontend
cd one-api/web
cd one-api/web/default
npm install
npm run build
# Build the backend
cd ..
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```

View File

@@ -3,7 +3,7 @@
</p>
<p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p>
<div align="center">
@@ -135,12 +135,12 @@ sudo service nginx restart
git clone https://github.com/songquanpeng/one-api.git
# フロントエンドのビルド
cd one-api/web
cd one-api/web/default
npm install
npm run build
# バックエンドのビルド
cd ..
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```

View File

@@ -4,7 +4,7 @@
<p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p>
<div align="center">
@@ -67,19 +67,20 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+ [x] [Mistral 系列模型](https://mistral.ai/)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
2. 支持配置镜像以及众多第三方代理服务:
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://referer.shadowai.xyz/r/2412)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
+ [x] [Moonshot AI](https://platform.moonshot.cn/)
+ [x] [百川大模型](https://platform.baichuan-ai.com)
+ [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ [x] [MINIMAX](https://api.minimax.chat/)
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**[详见此处](#多机部署)。
@@ -105,6 +106,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
## 部署
### 基于 Docker 进行部署
@@ -179,12 +182,12 @@ docker-compose ps
git clone https://github.com/songquanpeng/one-api.git
# 构建前端
cd one-api/web
cd one-api/web/default
npm install
npm run build
# 构建后端
cd ..
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
````
@@ -371,6 +374,12 @@ graph LR
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
16. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
17. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -416,6 +425,9 @@ https://openai.justsong.cn
8. 升级之前数据库需要做变更吗?
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`
+ 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。
+ 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统

29
common/blacklist/main.go Normal file
View File

@@ -0,0 +1,29 @@
package blacklist
import (
"fmt"
"sync"
)
var blackList sync.Map
func init() {
blackList = sync.Map{}
}
func userId2Key(id int) string {
return fmt.Sprintf("userid_%d", id)
}
func BanUser(id int) {
blackList.Store(userId2Key(id), true)
}
func UnbanUser(id int) {
blackList.Delete(userId2Key(id))
}
func IsUserBanned(id int) bool {
_, ok := blackList.Load(userId2Key(id))
return ok
}

137
common/config/config.go Normal file
View File

@@ -0,0 +1,137 @@
package config
import (
"github.com/songquanpeng/one-api/common/env"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var MessagePusherAddress = ""
var MessagePusherToken = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser int64 = 0
var QuotaForInviter int64 = 0
var QuotaForInvitee int64 = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold int64 = 1000
var PreConsumedQuota int64 = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = env.String("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
}
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)

View File

@@ -1,106 +1,9 @@
package common
import (
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const (
RoleGuestUser = 0
@@ -109,37 +12,10 @@ const (
RoleRootUser = 100
)
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
ImageUploadPermission = RoleGuestUser
ImageDownloadPermission = RoleGuestUser
)
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
const (
@@ -163,57 +39,79 @@ const (
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
}
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

View File

@@ -1,6 +1,12 @@
package common
import (
"github.com/songquanpeng/one-api/common/env"
)
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -15,10 +15,7 @@ type embedFileSystem struct {
func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path)
if err != nil {
return false
}
return true
return err == nil
}
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {

42
common/env/helper.go vendored Normal file
View File

@@ -0,0 +1,42 @@
package env
import (
"os"
"strconv"
)
func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
return defaultValue
}
return num
}
func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
return defaultValue
}
return num
}
func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@@ -8,12 +8,24 @@ import (
"strings"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return err
return nil, err
}
err = c.Request.Body.Close()
_ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil {
return err
}
@@ -31,3 +43,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -1,6 +1,9 @@
package common
import "encoding/json"
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
)
var GroupRatio = map[string]float64{
"default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
SysError("group ratio not found: " + name)
logger.SysError("group ratio not found: " + name)
return 1
}
return ratio

217
common/helper/helper.go Normal file
View File

@@ -0,0 +1,217 @@
package helper
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
return inter
case int:
return fmt.Sprintf("%d", inter)
case float64:
return fmt.Sprintf("%f", inter)
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string {
return GetTimeString() + GetRandomNumberString(8)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -1,6 +1,8 @@
package image
import (
"bytes"
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
@@ -8,11 +10,30 @@ import (
"net/http"
"regexp"
"strings"
"sync"
_ "golang.org/x/image/webp"
)
// Regex to match data URL pattern
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
@@ -25,17 +46,60 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return img.Width, img.Height, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
// Check if the URL is a data URL
matches := dataURLPattern.FindStringSubmatch(url)
if len(matches) == 3 {
// URL is a data URL
mimeType = "image/" + matches[1]
data = matches[2]
return
}
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
var readerPool = sync.Pool{
New: func() interface{} {
return &bytes.Reader{}
},
}
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
encoded = strings.TrimPrefix(encoded, "data:image/png;base64,")
base64 := strings.NewReader(reg.ReplaceAllString(encoded, ""))
img, _, err := image.DecodeConfig(base64)
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
if err != nil {
return
return 0, 0, err
}
reader := readerPool.Get().(*bytes.Reader)
defer readerPool.Put(reader)
reader.Reset(decoded)
img, _, err := image.DecodeConfig(reader)
if err != nil {
return 0, 0, err
}
return img.Width, img.Height, nil
}

View File

@@ -12,7 +12,7 @@ import (
"strings"
"testing"
img "one-api/common/image"
img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
@@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) {
})
}
}
func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@@ -3,6 +3,8 @@ package common
import (
"flag"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"log"
"os"
"path/filepath"
@@ -36,7 +38,11 @@ func init() {
}
if os.Getenv("SESSION_SECRET") != "" {
SessionSecret = os.Getenv("SESSION_SECRET")
if os.Getenv("SESSION_SECRET") == "random_string" {
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
config.SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")
@@ -53,5 +59,6 @@ func init() {
log.Fatal(err)
}
}
logger.LogDir = *LogDir
}
}

View File

@@ -0,0 +1,7 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@@ -1,9 +1,11 @@
package common
package logger
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"io"
"log"
"os"
@@ -13,19 +15,17 @@ import (
)
const (
loggerDEBUG = "DEBUG"
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if *LogDir != "" {
if LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,7 +35,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -55,29 +55,52 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) {
func Debug(ctx context.Context, msg string) {
if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg)
}
}
func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func Debugf(ctx context.Context, format string, a ...any) {
Debug(ctx, fmt.Sprintf(format, a...))
}
func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a...))
}
func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a...))
}
func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
if id == nil {
id = helper.GenRequestID()
}
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
if !setupLogWorking {
setupLogWorking = true
go func() {
SetupLogger()
@@ -90,11 +113,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
func LogQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}

View File

@@ -1,23 +1,27 @@
package common
package message
import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount
if receiver == "" {
return fmt.Errorf("receiver is empty")
}
if config.SMTPFrom == "" { // for compatibility
config.SMTPFrom = config.SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@")
parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
@@ -36,21 +40,21 @@ func SendEmail(subject string, receiver string, content string) error {
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
if SMTPPort == 465 {
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: SMTPServer,
ServerName: config.SMTPServer,
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
client, err := smtp.NewClient(conn, SMTPServer)
client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil {
return err
}
@@ -58,7 +62,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err = client.Auth(auth); err != nil {
return err
}
if err = client.Mail(SMTPFrom); err != nil {
if err = client.Mail(config.SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
@@ -80,7 +84,7 @@ func SendEmail(subject string, receiver string, content string) error {
return err
}
} else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
}
return err
}

22
common/message/main.go Normal file
View File

@@ -0,0 +1,22 @@
package message
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
)
const (
ByAll = "all"
ByEmail = "email"
ByMessagePusher = "message_pusher"
)
func Notify(by string, title string, description string, content string) error {
if by == ByEmail {
return SendEmail(title, config.RootUserEmail, content)
}
if by == ByMessagePusher {
return SendMessage(title, description, content)
}
return fmt.Errorf("unknown notify method: %s", by)
}

View File

@@ -0,0 +1,53 @@
package message
import (
"bytes"
"encoding/json"
"errors"
"github.com/songquanpeng/one-api/common/config"
"net/http"
)
type request struct {
Title string `json:"title"`
Description string `json:"description"`
Content string `json:"content"`
URL string `json:"url"`
Channel string `json:"channel"`
Token string `json:"token"`
}
type response struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func SendMessage(title string, description string, content string) error {
if config.MessagePusherAddress == "" {
return errors.New("message pusher address is not set")
}
req := request{
Title: title,
Description: description,
Content: content,
Token: config.MessagePusherToken,
}
data, err := json.Marshal(req)
if err != nil {
return err
}
resp, err := http.Post(config.MessagePusherAddress,
"application/json", bytes.NewBuffer(data))
if err != nil {
return err
}
var res response
err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil {
return err
}
if !res.Success {
return errors.New(res.Message)
}
return nil
}

View File

@@ -2,107 +2,176 @@ package common
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
"strings"
"time"
)
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
const (
USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // 0.12 / 1k tokens
"Embedding-V1": 0.1429, // 0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
// https://openai.com/pricing
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-ada-002": 0.05,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240229": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB,
"moonshot-v1-128k": 0.06 * RMB,
// https://platform.baichuan-ai.com/price
"Baichuan2-Turbo": 0.008 * RMB,
"Baichuan2-Turbo-192k": 0.016 * RMB,
"Baichuan2-53B": 0.02 * RMB,
// https://api.minimax.chat/document/price
"abab6-chat": 0.1 * RMB,
"abab5.5-chat": 0.015 * RMB,
"abab5.5s-chat": 0.005 * RMB,
// https://docs.mistral.ai/platform/pricing/
"open-mistral-7b": 0.25 / 1000 * USD,
"open-mixtral-8x7b": 0.7 / 1000 * USD,
"mistral-small-latest": 2.0 / 1000 * USD,
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/
"llama2-70b-4096": 0.7 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
}
var CompletionRatio = map[string]float64{}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
func init() {
DefaultModelRatio = make(map[string]float64)
for k, v := range ModelRatio {
DefaultModelRatio[k] = v
}
DefaultCompletionRatio = make(map[string]float64)
for k, v := range CompletionRatio {
DefaultCompletionRatio[k] = v
}
}
func AddNewMissingRatio(oldRatio string) string {
newRatio := make(map[string]float64)
err := json.Unmarshal([]byte(oldRatio), &newRatio)
if err != nil {
logger.SysError("error unmarshalling old ratio: " + err.Error())
return oldRatio
}
for k, v := range DefaultModelRatio {
if _, ok := newRatio[k]; !ok {
newRatio[k] = v
}
}
jsonBytes, err := json.Marshal(newRatio)
if err != nil {
logger.SysError("error marshalling new ratio: " + err.Error())
return oldRatio
}
return string(jsonBytes)
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -113,16 +182,46 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
ratio, ok = DefaultModelRatio[name]
}
if !ok {
logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
}
func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
logger.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
@@ -135,7 +234,7 @@ func GetCompletionRatio(name string) float64 {
return 2
}
}
return 1.333333
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
@@ -143,11 +242,18 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
if strings.HasPrefix(name, "claude-") {
return 3
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1
}

8
common/random.go Normal file
View File

@@ -0,0 +1,8 @@
package common
import "math/rand"
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

View File

@@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger"
"os"
"time"
)
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
SysLog("Redis is enabled")
logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}

View File

@@ -2,208 +2,13 @@ package common
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
"github.com/songquanpeng/one-api/common/config"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
func LogQuota(quota int64) string {
if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter.(type) {
case string:
return inter.(string)
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
return fmt.Sprintf("%d 点额度", quota)
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -2,17 +2,18 @@ package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func GetSubscription(c *gin.Context) {
var remainQuota int
var usedQuota int
var remainQuota int64
var usedQuota int64
var err error
var token *model.Token
var expiredTime int64
if common.DisplayTokenStatEnabled {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
@@ -21,25 +22,27 @@ func GetSubscription(c *gin.Context) {
} else {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
usedQuota, err = model.GetUserUsedQuota(userId)
if err != nil {
usedQuota, err = model.GetUserUsedQuota(userId)
}
}
if expiredTime <= 0 {
expiredTime = 0
}
if err != nil {
openAIError := OpenAIError{
Error := relaymodel.Error{
Message: err.Error(),
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,
"error": Error,
})
return
}
quota := remainQuota + usedQuota
amount := float64(quota)
if common.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit
if config.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit
}
if token != nil && token.UnlimitedQuota {
amount = 100000000
@@ -57,10 +60,10 @@ func GetSubscription(c *gin.Context) {
}
func GetUsage(c *gin.Context) {
var quota int
var quota int64
var err error
var token *model.Token
if common.DisplayTokenStatEnabled {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
@@ -69,18 +72,18 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
openAIError := OpenAIError{
Error := relaymodel.Error{
Message: err.Error(),
Type: "one_api_error",
}
c.JSON(200, gin.H{
"error": openAIError,
"error": Error,
})
return
}
amount := float64(quota)
if common.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit
if config.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit
}
usage := OpenAIUsageResponse{
Object: "list",

View File

@@ -4,10 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
@@ -92,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := httpClient.Do(req)
res, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
@@ -292,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) {
}
func updateAllChannelsBalance() error {
channels, err := model.GetAllChannels(0, 0, true)
channels, err := model.GetAllChannels(0, 0, "all")
if err != nil {
return err
}
@@ -310,24 +314,23 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(common.RequestInterval)
time.Sleep(config.RequestInterval)
}
return nil
}
func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async
err := updateAllChannelsBalance()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
//err := updateAllChannelsBalance()
//if err != nil {
// c.JSON(http.StatusOK, gin.H{
// "success": false,
// "message": err.Error(),
// })
// return
//}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("updating all channels")
logger.SysLog("updating all channels")
_ = updateAllChannelsBalance()
common.SysLog("channels update done")
logger.SysLog("channels update done")
}
}

View File

@@ -5,98 +5,36 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"one-api/common"
"one-api/model"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeAnthropic:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
defer func() {
if err != nil {
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
default:
request.Model = "gpt-3.5-turbo"
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else {
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL
}
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
return err, nil
}
defer resp.Body.Close()
var response TextResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
return err, nil
}
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
if response.Usage.CompletionTokens == 0 {
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"
}
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
return nil, nil
}
func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 1,
Stream: false,
Model: "gpt-3.5-turbo",
}
testMessage := Message{
testMessage := relaymodel.Message{
Role: "user",
Content: "hi",
}
@@ -104,6 +42,72 @@ func buildTestRequest() *ChatRequest {
return testRequest
}
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := adaptor.GetModelList()[0]
if !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -121,9 +125,8 @@ func TestChannel(c *gin.Context) {
})
return
}
testRequest := buildTestRequest()
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
err, _ = testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -147,35 +150,9 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
// enable & notify
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
func testChannels(notify bool, scope string) error {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -184,12 +161,11 @@ func testAllChannels(notify bool) error {
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
channels, err := model.GetAllChannels(0, 0, scope)
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
@@ -197,37 +173,45 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest)
err, openaiErr := testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error())
if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error())
}
}
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
disableChannel(channel.Id, channel.Name, err.Error())
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
time.Sleep(config.RequestInterval)
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
err := message.Notify(message.ByAll, "通道测试完成", "", "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(true)
func TestChannels(c *gin.Context) {
scope := c.Query("scope")
if scope == "" {
scope = "all"
}
err := testChannels(true, scope)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -245,8 +229,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")
_ = testAllChannels(false)
common.SysLog("channel test finished")
logger.SysLog("testing all channels")
_ = testChannels(false, "all")
logger.SysLog("channel test finished")
}
}

View File

@@ -2,9 +2,10 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
@@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 {
p = 0
}
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
})
return
}
channel.CreatedTime = common.GetTimestamp()
channel.CreatedTime = helper.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {

View File

@@ -7,9 +7,12 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
@@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
@@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
@@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
if !common.GitHubOAuthEnabled {
if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
} else {
if common.RegisterEnabled {
if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
}
func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled {
if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
state := helper.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {

View File

@@ -2,13 +2,13 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"net/http"
"one-api/common"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -2,9 +2,9 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
@@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -3,9 +3,11 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
@@ -18,55 +20,55 @@ func GetStatus(c *gin.Context) {
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink,
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
},
})
return
}
func GetNotice(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
config.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["Notice"],
"data": config.OptionMap["Notice"],
})
return
}
func GetAbout(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
config.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["About"],
"data": config.OptionMap["About"],
})
return
}
func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
config.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["HomePageContent"],
"data": config.OptionMap["HomePageContent"],
})
return
}
@@ -80,9 +82,9 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
if common.EmailDomainRestrictionEnabled {
if config.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
@@ -105,11 +107,11 @@ func SendEmailVerification(c *gin.Context) {
}
code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
"<p>您的验证码为: <strong>%s</strong></p>"+
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes)
err := message.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -142,13 +144,13 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes)
err := message.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -2,8 +2,14 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -35,6 +41,7 @@ type OpenAIModels struct {
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() {
var permission []OpenAIModelPermission
@@ -53,507 +60,63 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "dall-e-2",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil,
},
{
Id: "whisper-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "whisper-1",
Parent: nil,
},
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-1106",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "gpt-4-1106-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-preview",
Parent: nil,
},
{
Id: "gpt-4-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "claude-2.1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.1",
Parent: nil,
},
{
Id: "claude-2.0",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.0",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding-bert-512-v1",
Parent: nil,
},
{
Id: "embedding_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding_s1_v1",
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Parent: nil,
},
{
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "tencent",
Permission: permission,
Root: "hunyuan",
Parent: nil,
},
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
continue
}
adaptor := helper.GetAdaptor(i)
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
meta := &util.RelayMeta{
ChannelType: i,
}
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
}
func DashboardListModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelId2Models,
})
}
func ListModels(c *gin.Context) {
@@ -568,14 +131,14 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
Error := relaymodel.Error{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
}
c.JSON(200, gin.H{
"error": openAIError,
"error": Error,
})
}
}

View File

@@ -2,9 +2,10 @@ package controller
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
@@ -12,17 +13,17 @@ import (
func GetOptions(c *gin.Context) {
var options []*model.Option
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
config.OptionMapRWMutex.Lock()
for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
options = append(options, &model.Option{
Key: k,
Value: common.Interface2String(v),
Value: helper.Interface2String(v),
})
}
common.OptionMapRWMutex.Unlock()
config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -42,8 +43,16 @@ func UpdateOption(c *gin.Context) {
return
}
switch option.Key {
case "Theme":
if !config.ValidThemes[option.Value] {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的主题",
})
return
}
case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" {
if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
@@ -51,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@@ -59,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "WeChatAuthEnabled":
if option.Value == "true" && common.WeChatServerAddress == "" {
if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@@ -67,7 +76,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "TurnstileCheckEnabled":
if option.Value == "true" && common.TurnstileSiteKey == "" {
if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",

View File

@@ -2,9 +2,10 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
@@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 {
p = 0
}
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
key := common.GetUUID()
key := helper.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
Name: redemption.Name,
Key: key,
CreatedTime: common.GetTimestamp(),
CreatedTime: helper.GetTimestamp(),
Quota: redemption.Quota,
}
err = cleanRedemption.Insert()

View File

@@ -1,329 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: request.Messages[i+1].StringContent(),
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,222 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
completionTokens := countTokenText(claudeResponse.Completion, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -1,281 +0,0 @@
package controller
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"github.com/gin-gonic/gin"
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{
// {
// Category: "HARM_CATEGORY_HARASSMENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_HATE_SPEECH",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
//},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
},
}
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "ok",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
Index: i,
Message: Message{
Role: "assistant",
Content: candidate.Content.Parts[0].Text,
},
FinishReason: stopFinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
}
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
common.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -1,213 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}
min := common.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1]
return value >= min && value <= max
}
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e-2"
imageSize := "1024x1024"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
var imageRequest ImageRequest
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation
if imageRequest.Model != "" {
imageModel = imageRequest.Model
}
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
// Check if model is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// Check prompt length
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
}
var requestBody io.Reader
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse ImageResponse
defer func(ctx context.Context) {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,287 +0,0 @@
package controller
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}

View File

@@ -1,693 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
apiVersion := GetAPIVersion(c)
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
// actually gemini does not support stream, it's fake
//if textRequest.Stream {
// action = "streamGenerateContent"
//}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if channelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
case APITypeGemini:
// do not set Authorization header
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}

View File

@@ -1,326 +1,132 @@
package controller
import (
"bytes"
"context"
"fmt"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == "text" {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech:
fallthrough
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type WhisperJSONResponse struct {
Text string `json:"text,omitempty"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
return err
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
channelId := c.GetInt("channel_id")
bizErr := relay(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
return
}
if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
retryTimes = 0
}
for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
break
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
if channel.Id == lastFailedChannelId {
continue
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
lastFailedChannelId = channelId
channelName := c.GetString("channel_name")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,
})
}
}
func shouldRetry(c *gin.Context, statusCode int) bool {
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if statusCode == http.StatusTooManyRequests {
return true
}
if statusCode/100 == 5 {
return true
}
if statusCode == http.StatusBadRequest {
return false
}
if statusCode/100 == 2 {
return false
}
return true
}
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
err := model.Error{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
@@ -332,7 +138,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
err := OpenAIError{
err := model.Error{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",

View File

@@ -2,9 +2,11 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
@@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) {
if p < 0 {
p = 0
}
tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -119,9 +121,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
Key: helper.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
@@ -187,7 +189,7 @@ func UpdateToken(c *gin.Context) {
return
}
if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",

View File

@@ -3,10 +3,13 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
@@ -18,7 +21,7 @@ type LoginRequest struct {
}
func Login(c *gin.Context) {
if !common.PasswordLoginEnabled {
if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录",
"success": false,
@@ -105,14 +108,14 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
if !common.RegisterEnabled {
if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
"success": false,
})
return
}
if !common.PasswordRegisterEnabled {
if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false,
@@ -135,7 +138,7 @@ func Register(c *gin.Context) {
})
return
}
if common.EmailVerificationEnabled {
if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -159,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username,
InviterId: inviterId,
}
if common.EmailVerificationEnabled {
if config.EmailVerificationEnabled {
cleanUser.Email = user.Email
}
if err := cleanUser.Insert(inviterId); err != nil {
@@ -181,7 +184,7 @@ func GetAllUsers(c *gin.Context) {
if p < 0 {
p = 0
}
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -248,6 +251,29 @@ func GetUser(c *gin.Context) {
return
}
func GetUserDashboard(c *gin.Context) {
id := c.GetInt("id")
now := time.Now()
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法获取统计信息",
"data": nil,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dashboards,
})
return
}
func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
@@ -258,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
user.AccessToken = common.GetUUID()
user.AccessToken = helper.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -295,7 +321,7 @@ func GetAffCode(c *gin.Context) {
return
}
if user.AffCode == "" {
user.AffCode = common.GetRandomString(4)
user.AffCode = helper.GetRandomString(4)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -702,7 +728,7 @@ func EmailBind(c *gin.Context) {
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
config.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,

View File

@@ -5,9 +5,10 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
@@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" {
return "", errors.New("无效的参数")
}
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", common.WeChatServerToken)
req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{
Timeout: 5 * time.Second,
}
@@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
}
func WeChatAuth(c *gin.Context) {
if !common.WeChatAuthEnabled {
if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
@@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) {
return
}
} else {
if common.RegisterEnabled {
if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser
@@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
}
func WeChatBind(c *gin.Context) {
if !common.WeChatAuthEnabled {
if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,

6
go.mod
View File

@@ -1,4 +1,4 @@
module one-api
module github.com/songquanpeng/one-api
// +heroku goVersion go1.18
go 1.18
@@ -16,7 +16,7 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.14.0
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
@@ -58,7 +58,7 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

8
go.sum
View File

@@ -150,8 +150,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -164,8 +164,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View File

@@ -86,6 +86,7 @@
"该令牌已过期": "The token has expired",
"该令牌额度已用尽": "The token quota has been used up",
"无效的令牌": "Invalid token",
"令牌验证失败": "Token verification failed",
"id 或 userId 为空!": "id or userId is empty!",
"quota 不能为负数!": "quota cannot be negative!",
"令牌额度不足": "Insufficient token quota",
@@ -455,9 +456,11 @@
"已绑定的邮箱账户": "Email Account Bound",
"用户信息更新成功!": "User information updated successfully!",
"模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f",
"模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f",
"使用明细(总消耗额度:{renderQuota(stat.quota)}": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
"用户名称": "User Name",
"令牌名称": "Token Name",
"默认令牌": "Default Token",
"留空则查询全部用户": "Leave blank to query all users",
"留空则查询全部令牌": "Leave blank to query all tokens",
"模型名称": "Model Name",
@@ -526,5 +529,250 @@
"模型版本": "Model version",
"请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
"点击查看": "click to view",
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!",
"处理中...": "Processing...",
"绑定成功!": "Binding successful!",
"登录成功!": "Login successful!",
"操作失败,重定向至登录界面中...": "Operation failed, redirecting to login screen...",
"出现错误,第 ${count} 次重试中...": "An error occurred, retrying ${count}...",
"首页": "Home",
"渠道": "Channel",
"令牌": "API Keys",
"兑换": "Redeem",
"充值": "Recharge",
"用户": "Users",
"日志": "Logs",
"设置": "Settings",
"关于": "About",
"聊天": "Chat",
"注销成功!": "Logout successful!",
"注销": "Log out",
"登录": "Log in",
"注册": "Sign up",
"加载{name}中...": "Loading {name}...",
"未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!",
"请立刻修改默认密码!": "Please change the default password immediately!",
"欢迎回来": "Welcome back",
"没有账户?": "No account?",
"立刻注册": "Sign up now",
"用户名": "Username",
"密码": "Password",
"正在登录……": "Logging in...",
"忘记密码": "Forgot password",
"其他方式": "Other methods",
"微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code with WeChat, follow the official account and enter 'verification code' to get the verification code (valid within three minutes)",
"验证码": "Verification code",
"全部用户": "All users",
"当前用户": "Current user",
"全部": "All",
"消费": "Consumption",
"管理": "Management",
"系统": "System",
"未知": "Unknown",
"其他模型": "Other models",
"复制成功": "Copy successful",
"使用明细": "Usages",
"刷新": "Refresh",
"收起面板": "Collapse panel",
"展开面板": "Expand panel",
"显示查询选项": "Show search options",
"隐藏查询选项": "Hide search options",
"用户名称": "User name",
"可选值": "Optional values",
"渠道 ID": "Channel ID",
"令牌名称": "Key name",
"模型名称": "Model name",
"起始时间": "Start time",
"结束时间": "End time",
"查询": "Query",
"隐藏条形图": "Hide bar chart",
"显示条形图": "Show bar chart",
"折线条形图只展示最新50条数据": "Line and bar charts only show the latest 50 pieces of data",
"总消耗": "Total consumption",
"总共调用了 {payload[0].value} 次": "A total of {payload[0].value} calls were made",
"{model.name}: {model.value} 次": "{model.name}: {model.value} times",
"总共调用了 {payload[0].value} 次 {payload[0].name}": "A total of {payload[0].value} {payload[0].name} calls were made",
"总消耗额度": "Total consumption limit",
"暂无数据": "No data available",
"更多数据统计图形即将到来,敬请期待!": "More data statistics graphics are coming soon, stay tuned!",
"复制用户名": "Copy username",
"{`共 ${counts} 条数据`}": "{`A total of ${counts} pieces of data`}",
"共 0 条数据": "A total of 0 pieces of data",
"选择明细分类": "Select detail category",
"模型倍率": "model rate",
"分组倍率": "group rate",
"新密码已复制到剪贴板:": "New password has been copied to the clipboard:",
"密码重置确认": "Password reset confirmation",
"邮箱地址": "Email address",
"新密码": "New password",
"密码已复制到剪贴板:": "Password has been copied to the clipboard:",
"密码重置完成": "Password reset complete",
"提交": "Submit",
"返回登录": "Return to login",
"请稍后重试,浏览器环境检查未通过": "Please try again later, browser environment check failed",
"重置邮件发送成功,请检查邮箱!": "Reset email sent successfully, please check your email!",
"密码重置": "Password reset",
"重试": "Retry",
"组": "Group",
"令牌已重置并已复制到剪贴板": "Token has been reset and copied to the clipboard",
"邀请链接已复制到剪切板": "Invitation link has been copied to the clipboard",
"系统令牌已复制到剪切板": "System token has been copied to the clipboard",
"请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!",
"账户已删除!": "Account has been deleted!",
"微信账户绑定成功!": "WeChat account binding successful!",
"请稍后几秒重试Turnstile 正在检查用户环境!": "Please try again in a few seconds, Turnstile is checking the user environment!",
"验证码发送成功,请检查邮箱!": "Verification code sent successfully, please check your email!",
"邮箱账户绑定成功!": "Email account binding successful!",
"个人信息": "Personal information",
"编辑个人信息": "Edit personal information",
"生成系统访问令牌": "Generate system access token",
"复制邀请链接": "Copy invitation link",
"删除个人帐户": "Delete personal account",
"普通用户": "Regular user",
"管理员": "Administrator",
"超级管理员": "Super administrator",
"显示名称": "Display name",
"GitHub 账号": "GitHub account",
"微信账号": "WeChat account",
"修改个人信息只允许在电脑端进行。生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。": "Modifying personal information is only allowed on a computer. The generated token is for system management, not for requesting OpenAI related services. Please be aware.",
"可用模型": "Available models",
"账号绑定": "Account binding",
"绑定微信": "Bind WeChat",
"绑定 GitHub": "Bind GitHub",
"绑定邮箱": "Bind Email",
"绑定": "Bind",
"绑定邮箱地址": "Bind email address",
"输入邮箱地址": "Enter email address",
"重新发送": "Resend",
"获取验证码": "Get verification code",
"确认绑定": "Confirm binding",
"取消": "Cancel",
"危险操作": "Dangerous operation",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your own account, all data will be cleared and cannot be recovered",
"输入你的账户名": "Enter your account name",
"以确认删除": "To confirm deletion",
"确认删除": "Confirm deletion",
"未使用": "Not used",
"已禁用": "Disabled",
"已使用": "Used",
"未知状态": "Unknown status",
"操作成功完成!": "Operation successfully completed!",
"搜索兑换码的 ID 和名称 ...": "Search for the ID and name of the redemption code ...",
"名称": "Name",
"状态": "Status",
"额度": "Quota",
"创建时间": "Creation time",
"兑换时间": "Redemption time",
"操作": "Operation",
"尚未兑换": "Not yet redeemed",
"已复制到剪贴板!": "Copied to clipboard!",
"无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually. The redemption code has been filled in the search box.",
"复制": "Copy",
"删除": "Delete",
"禁用": "Disable",
"启用": "Enable",
"编辑": "Edit",
"添加新的兑换码": "Add new redemption code",
"密码长度不得小于 8 位!": "Password length must not be less than 8 characters!",
"两次输入的密码不一致": "The two passwords entered do not match",
"注册成功!": "Registration successful!",
"请填写注册邮箱!": "Please fill in the registration email!",
"请在${verificationTimeout}秒后再试": "Please try again after ${verificationTimeout} seconds",
"验证码发送成功,请检查你的邮箱!": "Verification code sent successfully, please check your email!",
"已有账户?": "Already have an account?",
"请输入用户名(最长 12 位)": "Please enter a username (up to 12 characters)",
"请输入密码(最短 8 位,最长 20 位)": "Please enter a password (minimum 8 characters, maximum 20 characters)",
"请再次输入密码": "Please enter the password again",
"请输入邮箱地址": "Please enter an email address",
"秒后可重发": "Can be resent after seconds",
"请输入邮箱验证码": "Please enter the email verification code",
"已过期": "Expired",
"已启用": "Enabled",
"已耗尽": "Exhausted",
"无": "None",
"令牌密钥": "API Key",
"令牌状态": "Key status",
"已用额度": "Used quota",
"剩余额度": "Remaining quota",
"过期时间": "Expiration time",
"你确定要删除这个令牌吗?": "Are you sure you want to delete this key?",
"无法复制到剪贴板,请手动复制,已将令牌密钥填入搜索框": "Unable to copy to clipboard, please copy manually. The key key has been filled in the search box.",
"无限制": "Unlimited",
"永不过期": "Never expires",
"使用 API 访问令牌进行服务鉴权和计费。": "Use API Key for service authentication and billing.",
"API 访问令牌关系到您的个人利益,请妥善留存,不要与其他人共享,也不要保存在客户端代码中。": "API Key is related to your personal interests. Please keep it properly. Do not share it with others or save it in client code.",
"创建令牌": "Create Key",
"什么都还没有,快去创建一个令牌开始使用吧!": "Nothing yet, go create a key to start using!",
"你确定要删除该令牌吗": "Are you sure you want to delete this key",
"导出令牌信息": "Export key information",
"错误:未登录或登录已过期,请重新登录!": "Error: Not logged in or login has expired, please log in again!",
"错误:请求次数过多,请稍后再试!": "Error: Too many requests, please try again later!",
"错误:服务器内部错误,请联系管理员!": "Error: Server internal error, please contact the online customer service!",
"本站仅作演示之用,无服务端!": "This site is for demonstration purposes only, no server!",
"错误:": "Error:",
"加载首页内容失败...": "Failed to load homepage content...",
"系统状况": "System status",
"系统信息": "System information",
"系统信息总览": "System information overview",
"名称:": "Name:",
"版本:": "Version:",
"源码:": "Source code:",
"启动时间:": "Startup time:",
"系统配置": "System configuration",
"系统配置总览": "System configuration overview",
"邮箱验证:": "Email verification:",
"未启用": "Not enabled",
"Turnstile 用户校验:": "Turnstile user verification:",
"页面不存在": "Page does not exist",
"请检查你的浏览器地址是否正确": "Please check if your browser address is correct",
"个人设置": "Personal settings",
"运营设置": "Operations settings",
"系统设置": "System settings",
"其他设置": "Other settings",
"默认令牌": "Default key",
"过期时间必须在当前时间之后!": "Expiration time must be after the current time!",
"额度必须大于等于 0": "Quota must be greater than or equal to 0!",
"过期时间格式错误!": "Expiration time format error!",
"创建令牌数量必须大于等于 1": "The number of keys to create must be greater than or equal to 1!",
"令牌修改成功": "API Key modification successful",
"令牌创建成功": "API Key creation successful",
"更新令牌信息": "Update key information",
"创建新的令牌": "Create a new key",
"请输入名称": "Please enter a name",
"请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means unlimited",
"无限额度": "Unlimited quota",
"注意:启用无限额度后,已用额度将不再进行计算。": "Note: After enabling unlimited quota, the used quota will no longer be calculated.",
"等于": "Equals",
"请输入额度单位token": "Please enter the quota (unit: token)",
"创建令牌数量": "Create key quantity",
"请输入令牌数量": "Please enter the number of keys",
"注意:令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note: The quota of the key is only used to limit the maximum quota usage of the key itself, and the actual usage is subject to the remaining quota of the account.",
"我的令牌": "My keys",
"请输入额度兑换码!": "Please enter the redeem code!",
"充值成功!": "Recharge successful!",
"请求失败": "Request failed",
"超级管理员未设置充值链接!": "The super administrator did not set a recharge link!",
"充值额度": "Recharge quota",
"兑换中...": "Redeeming...",
"请点击充值以获取额度兑换码。": "Please click recharge to get the quota redemption code.",
"用户信息更新成功!": "User information updated successfully!",
"更新用户信息": "Update user information",
"请输入新的用户名": "Please enter a new username",
"请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characters",
"请输入新的显示名称": "Please enter a new display name",
"分组": "Group",
"请选择分组": "Please select a group",
"请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group rate on the system settings page to add a new group:",
"请输入新的剩余额度": "Please enter a new remaining quota",
"已绑定的 GitHub 账户": "Bound GitHub account",
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only, users need to bind through the relevant binding button on the personal settings page, cannot be directly modified",
"已绑定的微信账户": "Bound WeChat account",
"已绑定的邮箱账户": "Bound email account",
"新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using the shortcut key Shift + F5",
"无法正常连接至服务器!": "Unable to connect to the server normally!",
"提示:": "Input:",
"补全:": "Output:",
"搜索令牌名称": "Search key name",
"测试所有渠道": "Test all channels",
"更新已启用渠道余额": "Update the balance of enabled channels"
}

71
main.go
View File

@@ -6,83 +6,80 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/controller"
"one-api/middleware"
"one-api/model"
"one-api/router"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/router"
"os"
"strconv"
)
//go:embed web/build
//go:embed web/build/*
var buildFS embed.FS
//go:embed web/build/index.html
var indexPage []byte
func main() {
common.SetupLogger()
common.SysLog("One API " + common.Version + " started")
logger.SetupLogger()
logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
if common.DebugEnabled {
common.SysLog("running in debug mode")
if config.DebugEnabled {
logger.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
logger.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
common.FatalLog("failed to close database: " + err.Error())
logger.FatalLog("failed to close database: " + err.Error())
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
common.FatalLog("failed to initialize Redis: " + err.Error())
logger.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize options
model.InitOptionMap()
logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
config.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
if config.MemoryCacheEnabled {
logger.SysLog("memory cache enabled")
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache()
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
if config.MemoryCacheEnabled {
go model.SyncOptions(config.SyncFrequency)
go model.SyncChannelCache(config.SyncFrequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
config.BatchUpdateEnabled = true
logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
controller.InitTokenEncoders()
if config.EnableMetric {
logger.SysLog("metric enabled, will disable channel if too much request failed")
}
openai.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()
@@ -92,16 +89,16 @@ func main() {
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage)
router.SetRouter(server, buildFS)
var port = os.Getenv("PORT")
if port == "" {
port = strconv.Itoa(*common.Port)
}
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}

View File

@@ -3,9 +3,10 @@ package middleware
import (
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
@@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if status.(int) == common.UserStatusDisabled {
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
})
session := sessions.Default(c)
session.Clear()
_ = session.Save()
c.Abort()
return
}
@@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
@@ -108,7 +112,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("token_name", token.Name)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
c.Set("specific_channel_id", parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return

View File

@@ -2,9 +2,10 @@ package middleware
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
@@ -20,8 +21,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var requestModel string
var channel *model.Channel
channelId, ok := c.Get("channelId")
channelId, ok := c.Get("specific_channel_id")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -65,35 +67,46 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "whisper-1"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
SetupContextForSelectedChannel(c, channel, requestModel)
c.Next()
}
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
}
}

View File

@@ -3,14 +3,14 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"github.com/songquanpeng/one-api/common/logger"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
requestID = param.Keys[logger.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@@ -4,8 +4,9 @@ import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"net/http"
"one-api/common"
"time"
)
@@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
}
}
}
@@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
} else {
// It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark)
}
@@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
}
func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
}
func CriticalRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
}
func DownloadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
}
func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
}

View File

@@ -3,18 +3,25 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"one-api/common"
"runtime/debug"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
ctx := c.Request.Context()
logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err))
logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path))
body, _ := common.GetRequestBody(c)
logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body)))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err),
"type": "one_api_panic",
},
})

View File

@@ -3,16 +3,17 @@ package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Header(logger.RequestIdKey, id)
c.Next()
}
}

View File

@@ -4,9 +4,10 @@ import (
"encoding/json"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"net/url"
"one-api/common"
)
type turnstileCheckResponse struct {
@@ -15,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) {
if common.TurnstileCheckEnabled {
if config.TurnstileCheckEnabled {
session := sessions.Default(c)
turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil {
@@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc {
return
}
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
"secret": {common.TurnstileSecretKey},
"secret": {config.TurnstileSecretKey},
"response": {response},
"remoteip": {c.ClientIP()},
})
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -2,16 +2,17 @@ package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
logger.Error(c.Request.Context(), message)
}

View File

@@ -1,7 +1,7 @@
package model
import (
"one-api/common"
"github.com/songquanpeng/one-api/common"
"strings"
)

View File

@@ -1,11 +1,14 @@
package model
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
@@ -14,10 +17,10 @@ import (
)
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
TokenCacheSeconds = config.SyncFrequency
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -42,7 +45,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set token error: " + err.Error())
logger.SysError("Redis set token error: " + err.Error())
}
return &token, nil
}
@@ -62,37 +65,48 @@ func CacheGetUserGroup(id int) (group string, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
logger.SysError("Redis set user group error: " + err.Error())
}
}
return group, err
}
func CacheGetUserQuota(id int) (quota int, err error) {
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
logger.Error(ctx, "Redis set user quota error: "+err.Error())
}
return
}
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
}
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
return fetchAndUpdateUserQuota(ctx, id)
}
quota, err = strconv.Atoi(quotaString)
return quota, err
quota, err = strconv.ParseInt(quotaString, 10, 64)
if err != nil {
return 0, nil
}
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
return fetchAndUpdateUserQuota(ctx, id)
}
return quota, nil
}
func CacheUpdateUserQuota(id int) error {
func CacheUpdateUserQuota(ctx context.Context, id int) error {
if !common.RedisEnabled {
return nil
}
quota, err := GetUserQuota(id)
quota, err := CacheGetUserQuota(ctx, id)
if err != nil {
return err
}
@@ -100,7 +114,7 @@ func CacheUpdateUserQuota(id int) error {
return err
}
func CacheDecreaseUserQuota(id int, quota int) error {
func CacheDecreaseUserQuota(id int, quota int64) error {
if !common.RedisEnabled {
return nil
}
@@ -127,7 +141,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
logger.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
@@ -178,19 +192,19 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database")
logger.SysLog("syncing channels from database")
InitChannelCache()
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if !common.MemoryCacheEnabled {
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
@@ -211,5 +225,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil
}

View File

@@ -1,14 +1,19 @@
package model
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"one-api/common"
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null;index"`
Key string `json:"key" gorm:"type:text"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
@@ -16,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"`
Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
@@ -24,25 +29,25 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
var channels []*Channel
var err error
if selectAll {
switch scope {
case "all":
err = DB.Order("id desc").Find(&channels).Error
} else {
case "disabled":
err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
return channels, err
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error
return channels, err
}
@@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string {
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
func (channel *Channel) GetModelMapping() map[string]string {
if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
return nil
}
return *channel.ModelMapping
modelMapping := make(map[string]string)
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
if err != nil {
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
return nil
}
return modelMapping
}
func (channel *Channel) Insert() error {
@@ -116,21 +127,21 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: common.GetTimestamp(),
TestTime: helper.GetTimestamp(),
ResponseTime: int(responseTime),
}).Error
if err != nil {
common.SysError("failed to update response time: " + err.Error())
logger.SysError("failed to update response time: " + err.Error())
}
}
func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(),
BalanceUpdatedTime: helper.GetTimestamp(),
Balance: balance,
}).Error
if err != nil {
common.SysError("failed to update balance: " + err.Error())
logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -144,29 +155,41 @@ func (channel *Channel) Delete() error {
return err
}
func (channel *Channel) LoadConfig() (map[string]string, error) {
if channel.Config == "" {
return nil, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
logger.SysError("failed to update ability status: " + err.Error())
}
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
logger.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
func UpdateChannelUsedQuota(id int, quota int64) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
updateChannelUsedQuota(id, quota)
}
func updateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int64) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
common.SysError("failed to update channel used quota: " + err.Error())
logger.SysError("failed to update channel used quota: " + err.Error())
}
}

View File

@@ -3,14 +3,18 @@ package model
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"one-api/common"
)
type Log struct {
Id int `json:"id;index:idx_created_at_id,priority:1"`
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
@@ -31,43 +35,43 @@ const (
)
func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
if logType == LogTypeConsume && !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(),
CreatedAt: helper.GetTimestamp(),
Type: logType,
Content: content,
}
err := DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(),
CreatedAt: helper.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
Quota: int(quota),
ChannelId: channelId,
}
err := DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
logger.Error(ctx, "failed to record log: "+err.Error())
}
}
@@ -124,16 +128,16 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" {
tx = tx.Where("username = ?", username)
@@ -182,3 +186,40 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}
type LogStatistic struct {
Day string `gorm:"column:day"`
ModelName string `gorm:"column:model_name"`
RequestCount int `gorm:"column:request_count"`
Quota int `gorm:"column:quota"`
PromptTokens int `gorm:"column:prompt_tokens"`
CompletionTokens int `gorm:"column:completion_tokens"`
}
func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) {
groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
if common.UsingPostgreSQL {
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
}
if common.UsingSQLite {
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
}
err = DB.Raw(`
SELECT `+groupSelect+`,
model_name, count(1) as request_count,
sum(quota) as quota,
sum(prompt_tokens) as prompt_tokens,
sum(completion_tokens) as completion_tokens
FROM logs
WHERE type=2
AND user_id= ?
AND created_at BETWEEN ? AND ?
GROUP BY day, model_name
ORDER BY day, model_name
`, userId, start, end).Scan(&LogStatistics).Error
return LogStatistics, err
}

View File

@@ -1,11 +1,16 @@
package model
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"os"
"strings"
"time"
@@ -15,9 +20,9 @@ var DB *gorm.DB
func createRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -28,7 +33,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: common.GetUUID(),
AccessToken: helper.GetUUID(),
Quota: 100000000,
}
DB.Create(&rootUser)
@@ -41,7 +46,7 @@ func chooseDB() (*gorm.DB, error) {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
@@ -51,15 +56,17 @@ func chooseDB() (*gorm.DB, error) {
})
}
// Use MySQL
common.SysLog("using MySQL as database")
logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
@@ -67,7 +74,7 @@ func chooseDB() (*gorm.DB, error) {
func InitDB() (err error) {
db, err := chooseDB()
if err == nil {
if common.DebugEnabled {
if config.DebugSQLEnabled {
db = db.Debug()
}
DB = db
@@ -75,14 +82,17 @@ func InitDB() (err error) {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
if !config.IsMasterNode {
return nil
}
common.SysLog("database migration started")
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
@@ -111,11 +121,11 @@ func InitDB() (err error) {
if err != nil {
return err
}
common.SysLog("database migrated")
logger.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
} else {
common.FatalLog(err)
logger.FatalLog(err)
}
return err
}

View File

@@ -1,7 +1,9 @@
package model
import (
"one-api/common"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"strconv"
"strings"
"time"
@@ -20,68 +22,71 @@ func AllOption() ([]*Option, error) {
}
func InitOptionMap() {
common.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string)
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = ""
common.OptionMap["Notice"] = ""
common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = ""
common.OptionMap["Footer"] = common.Footer
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMapRWMutex.Unlock()
config.OptionMapRWMutex.Lock()
config.OptionMap = make(map[string]string)
config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
config.OptionMap["SMTPServer"] = ""
config.OptionMap["SMTPFrom"] = ""
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
config.OptionMap["SMTPAccount"] = ""
config.OptionMap["SMTPToken"] = ""
config.OptionMap["Notice"] = ""
config.OptionMap["About"] = ""
config.OptionMap["HomePageContent"] = ""
config.OptionMap["Footer"] = config.Footer
config.OptionMap["SystemName"] = config.SystemName
config.OptionMap["Logo"] = config.Logo
config.OptionMap["ServerAddress"] = ""
config.OptionMap["GitHubClientId"] = ""
config.OptionMap["GitHubClientSecret"] = ""
config.OptionMap["WeChatServerAddress"] = ""
config.OptionMap["WeChatServerToken"] = ""
config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
config.OptionMap["MessagePusherAddress"] = ""
config.OptionMap["MessagePusherToken"] = ""
config.OptionMap["TurnstileSiteKey"] = ""
config.OptionMap["TurnstileSecretKey"] = ""
config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
config.OptionMap["Theme"] = config.Theme
config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
if option.Key == "ModelRatio" {
option.Value = common.AddNewMissingRatio(option.Value)
}
err := updateOptionMap(option.Key, option.Value)
if err != nil {
common.SysError("failed to update option map: " + err.Error())
logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -89,7 +94,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing options from database")
logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
@@ -111,115 +116,110 @@ func UpdateOption(key string, value string) error {
}
func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
case "FileUploadPermission":
common.FileUploadPermission = intValue
case "FileDownloadPermission":
common.FileDownloadPermission = intValue
case "ImageUploadPermission":
common.ImageUploadPermission = intValue
case "ImageDownloadPermission":
common.ImageDownloadPermission = intValue
}
}
config.OptionMapRWMutex.Lock()
defer config.OptionMapRWMutex.Unlock()
config.OptionMap[key] = value
if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
common.PasswordRegisterEnabled = boolValue
config.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled":
common.PasswordLoginEnabled = boolValue
config.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled":
common.EmailVerificationEnabled = boolValue
config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
config.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue
config.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
config.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
config.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue
config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue
config.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
config.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue
config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue
config.DisplayTokenStatEnabled = boolValue
}
}
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
common.SMTPServer = value
config.SMTPServer = value
case "SMTPPort":
intValue, _ := strconv.Atoi(value)
common.SMTPPort = intValue
config.SMTPPort = intValue
case "SMTPAccount":
common.SMTPAccount = value
config.SMTPAccount = value
case "SMTPFrom":
common.SMTPFrom = value
config.SMTPFrom = value
case "SMTPToken":
common.SMTPToken = value
config.SMTPToken = value
case "ServerAddress":
common.ServerAddress = value
config.ServerAddress = value
case "GitHubClientId":
common.GitHubClientId = value
config.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
config.GitHubClientSecret = value
case "Footer":
common.Footer = value
config.Footer = value
case "SystemName":
common.SystemName = value
config.SystemName = value
case "Logo":
common.Logo = value
config.Logo = value
case "WeChatServerAddress":
common.WeChatServerAddress = value
config.WeChatServerAddress = value
case "WeChatServerToken":
common.WeChatServerToken = value
config.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value
config.WeChatAccountQRCodeImageURL = value
case "MessagePusherAddress":
config.MessagePusherAddress = value
case "MessagePusherToken":
config.MessagePusherToken = value
case "TurnstileSiteKey":
common.TurnstileSiteKey = value
config.TurnstileSiteKey = value
case "TurnstileSecretKey":
common.TurnstileSecretKey = value
config.TurnstileSecretKey = value
case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value)
config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInviter":
common.QuotaForInviter, _ = strconv.Atoi(value)
config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInvitee":
common.QuotaForInvitee, _ = strconv.Atoi(value)
config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
config.TopUpLink = value
case "ChatLink":
common.ChatLink = value
config.ChatLink = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "Theme":
config.Theme = value
}
return err
}

View File

@@ -3,8 +3,9 @@ package model
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm"
"one-api/common"
)
type Redemption struct {
@@ -13,7 +14,7 @@ type Redemption struct {
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
Quota int64 `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
@@ -41,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err
}
func Redeem(key string, userId int) (quota int, err error) {
func Redeem(key string, userId int) (quota int64, err error) {
if key == "" {
return 0, errors.New("未提供兑换码")
}
@@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return err
}
redemption.RedeemedTime = common.GetTimestamp()
redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err

View File

@@ -3,8 +3,12 @@ package model
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"gorm.io/gorm"
"one-api/common"
)
type Token struct {
@@ -16,9 +20,9 @@ type Token struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
RemainQuota int64 `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@@ -38,39 +42,43 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("未提供令牌")
}
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
if err != nil {
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("无效的令牌")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
}
return token, nil
return nil, errors.New("令牌验证失败")
}
return nil, errors.New("无效的令牌")
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
}
return token, nil
}
func GetTokenByIds(id int, userId int) (*Token, error) {
@@ -130,51 +138,51 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
func IncreaseTokenQuota(id int, quota int) (err error) {
func IncreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
return increaseTokenQuota(id, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
func increaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(),
"accessed_time": helper.GetTimestamp(),
},
).Error
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
func DecreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
return decreaseTokenQuota(id, quota)
}
func decreaseTokenQuota(id int, quota int) (err error) {
func decreaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(),
"accessed_time": helper.GetTimestamp(),
},
).Error
return err
}
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -192,24 +200,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota {
return errors.New("用户额度不足")
}
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(token.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
logger.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email,
topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
err = message.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
logger.SysError("failed to send email" + err.Error())
}
}
}()
@@ -224,7 +232,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
return err
}
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)

View File

@@ -3,8 +3,12 @@ package model
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"one-api/common"
"strings"
)
@@ -15,15 +19,15 @@ type User struct {
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Role int `json:"role" gorm:"type:int;default:1"` // admin, util
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
Quota int64 `json:"quota" gorm:"type:int;default:0"`
UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
@@ -37,12 +41,16 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error
return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
if !common.UsingPostgreSQL {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
} else {
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
}
return users, err
}
@@ -85,24 +93,24 @@ func (user *User) Insert(inviterId int) error {
return err
}
}
user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID()
user.AffCode = common.GetRandomString(4)
user.Quota = config.QuotaForNewUser
user.AccessToken = helper.GetUUID()
user.AffCode = helper.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
}
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
return nil
@@ -116,6 +124,11 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
if user.Status == common.UserStatusDisabled {
blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error
return err
}
@@ -124,7 +137,10 @@ func (user *User) Delete() error {
if user.Id == 0 {
return errors.New("id 为空!")
}
err := DB.Delete(user).Error
blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
user.Status = common.UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err
}
@@ -137,7 +153,15 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" {
return errors.New("用户名或密码为空")
}
DB.Where(User{Username: user.Username}).First(user)
err = DB.Where("username = ?", user.Username).First(user).Error
if err != nil {
// we must make sure check username firstly
// consider this case: a malicious user set his username as other's email
err := DB.Where("email = ?", user.Username).First(user).Error
if err != nil {
return errors.New("用户名或密码错误,或用户已被封禁")
}
}
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
@@ -220,7 +244,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
common.SysError("no such user " + err.Error())
logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -250,12 +274,12 @@ func ValidateAccessToken(token string) (user *User) {
return nil
}
func GetUserQuota(id int) (quota int, err error) {
func GetUserQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
return quota, err
}
func GetUserUsedQuota(id int) (quota int, err error) {
func GetUserUsedQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
return quota, err
}
@@ -275,34 +299,34 @@ func GetUserGroup(id int) (group string, err error) {
return group, err
}
func IncreaseUserQuota(id int, quota int) (err error) {
func IncreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
return increaseUserQuota(id, quota)
}
func increaseUserQuota(id int, quota int) (err error) {
func increaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err
}
func DecreaseUserQuota(id int, quota int) (err error) {
func DecreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
return decreaseUserQuota(id, quota)
}
func decreaseUserQuota(id int, quota int) (err error) {
func decreaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
@@ -312,8 +336,8 @@ func GetRootUserEmail() (email string) {
return email
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled {
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
@@ -321,7 +345,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
updateUserUsedQuotaAndRequestCount(id, quota, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
@@ -329,25 +353,25 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error())
logger.SysError("failed to update user used quota and request count: " + err.Error())
}
}
func updateUserUsedQuota(id int, quota int) {
func updateUserUsedQuota(id int, quota int64) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
logger.SysError("failed to update user request count: " + err.Error())
}
}

View File

@@ -1,7 +1,8 @@
package model
import (
"one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"sync"
"time"
)
@@ -15,12 +16,12 @@ const (
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
var batchUpdateStores []map[int]int64
var batchUpdateLocks []sync.Mutex
func init() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
}
}
@@ -28,13 +29,13 @@ func init() {
func InitBatchUpdater() {
go func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
}
func addNewRecord(type_ int, id int, value int) {
func addNewRecord(type_ int, id int, value int64) {
batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok {
@@ -45,11 +46,11 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
common.SysLog("batch update started")
logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateStores[i] = make(map[int]int64)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
@@ -57,21 +58,21 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysError("failed to batch update user quota: " + err.Error())
logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
updateUserRequestCount(key, int(value))
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
common.SysLog("batch update finished")
logger.SysLog("batch update finished")
}

55
monitor/channel.go Normal file
View File

@@ -0,0 +1,55 @@
package monitor
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/model"
)
func notifyRootUser(subject string, content string) {
if config.MessagePusherAddress != "" {
err := message.SendMessage(subject, content, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error()))
} else {
return
}
}
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
err := message.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("通道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
notifyRootUser(subject, content)
}
// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

79
monitor/metric.go Normal file
View File

@@ -0,0 +1,79 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
)
var store = make(map[int][]bool)
var metricSuccessChan = make(chan int, config.MetricSuccessChanSize)
var metricFailChan = make(chan int, config.MetricFailChanSize)
func consumeSuccess(channelId int) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], true)
}
func consumeFail(channelId int) (bool, float64) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], false)
successCount := 0
for _, success := range store[channelId] {
if success {
successCount++
}
}
successRate := float64(successCount) / float64(len(store[channelId]))
if len(store[channelId]) < config.MetricQueueSize {
return false, successRate
}
if successRate < config.MetricSuccessRateThreshold {
store[channelId] = make([]bool, 0)
return true, successRate
}
return false, successRate
}
func metricSuccessConsumer() {
for {
select {
case channelId := <-metricSuccessChan:
consumeSuccess(channelId)
}
}
}
func metricFailConsumer() {
for {
select {
case channelId := <-metricFailChan:
disable, successRate := consumeFail(channelId)
if disable {
go MetricDisableChannel(channelId, successRate)
}
}
}
}
func init() {
if config.EnableMetric {
go metricSuccessConsumer()
go metricFailConsumer()
}
}
func Emit(channelId int, success bool) {
if !config.EnableMetric {
return
}
go func() {
if success {
metricSuccessChan <- channelId
} else {
metricFailChan <- channelId
}
}()
}

View File

@@ -1,3 +1,9 @@
[//]: # (请按照以下格式关联 issue)
[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR谢谢)
[//]: # (项目维护者一般仅在周末处理 PR因此如若未能及时回复希望能理解)
[//]: # (开发者交流群910657413)
[//]: # (请在提交 PR 之前删除上面的注释)
close #issue_number
我已确认该 PR 已自测通过,相关截图如下:

View File

@@ -0,0 +1,8 @@
package ai360
var ModelList = []string{
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}

View File

@@ -0,0 +1,60 @@
package aiproxy
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "aiproxy"
}

View File

@@ -0,0 +1,9 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
var ModelList = []string{""}
func init() {
ModelList = openai.ModelList
}

View File

@@ -1,63 +1,37 @@
package controller
package aiproxy
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
}
return &AIProxyLibraryRequest{
return &LibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
if len(documents) == 0 {
return ""
}
@@ -68,52 +42,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
return content
}
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := OpenAITextResponseChoice{
choice := openai.TextResponseChoice{
Index: 0,
Message: Message{
Message: model.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: common.GetUUID(),
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Created: helper.GetTimestamp(),
Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Created: helper.GetTimestamp(),
Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -143,15 +117,15 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
}
stopChan <- true
}()
setEventStreamHeaders(c)
var documents []AIProxyLibraryDocument
common.SetEventStreamHeaders(c)
var documents []LibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
@@ -160,7 +134,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -169,7 +143,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -179,28 +153,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
@@ -211,10 +185,13 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,32 @@
package aiproxy
type LibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type LibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type LibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type LibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []LibraryDocument `json:"documents"`
LibraryError
}
type LibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []LibraryDocument `json:"documents"`
}

View File

@@ -0,0 +1,83 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}

View File

@@ -0,0 +1,6 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}

263
relay/channel/ali/main.go Normal file
View File

@@ -0,0 +1,263 @@
package ali
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, Message{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
if request.TopP >= 1 {
request.TopP = 0.9999
}
return &ChatRequest{
Model: aliModel,
Input: Input{
Messages: messages,
},
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
},
}
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "qwen",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,73 @@
package ali
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -0,0 +1,63 @@
package anthropic
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "anthropic"
}

View File

@@ -0,0 +1,8 @@
package anthropic
var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240229",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
}

View File

@@ -0,0 +1,272 @@
package anthropic
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func stopReasonClaude2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
// legacy model name mapping
if claudeRequest.Model == "claude-instant-1" {
claudeRequest.Model = "claude-instant-1.1"
} else if claudeRequest.Model == "claude-2" {
claudeRequest.Model = "claude-2.1"
}
for _, message := range textRequest.Messages {
if message.Role == "system" && claudeRequest.System == "" {
claudeRequest.System = message.StringContent()
continue
}
claudeMessage := Message{
Role: message.Role,
}
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
claudeMessage.Content = append(claudeMessage.Content, content)
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &ImageSource{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
}
contents = append(contents, content)
}
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
return &claudeRequest
}
// https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,75 @@
package anthropic
// https://docs.anthropic.com/claude/reference/messages_post
type Metadata struct {
UserId string `json:"user_id"`
}
type ImageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []Content `json:"content"`
Model string `json:"model"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Usage Usage `json:"usage"`
Error Error `json:"error"`
}
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
type StreamResponse struct {
Type string `json:"type"`
Message *Response `json:"message"`
Index int `json:"index"`
ContentBlock *Content `json:"content_block"`
Delta *Delta `json:"delta"`
Usage *Usage `json:"usage"`
}

View File

@@ -0,0 +1,7 @@
package baichuan
var ModelList = []string{
"Baichuan2-Turbo",
"Baichuan2-Turbo-192k",
"Baichuan-Text-Embedding",
}

View File

@@ -0,0 +1,105 @@
package baidu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix("Embedding", meta.ActualModelName) {
suffix = "embeddings/"
}
switch meta.ActualModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
default:
suffix += meta.ActualModelName
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string
var err error
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "baidu"
}

View File

@@ -0,0 +1,10 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
}

View File

@@ -1,4 +1,4 @@
package controller
package baidu
import (
"bufio"
@@ -6,9 +6,14 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
@@ -16,148 +21,111 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
type TokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
type BaiduMessage struct {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
type ChatRequest struct {
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}
type BaiduError struct {
type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
baiduRequest := ChatRequest{
Messages: make([]Message, 0, len(request.Messages)),
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: request.MaxTokens,
UserId: request.User,
}
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
baiduRequest.System = message.StringContent()
} else {
messages = append(messages, BaiduMessage{
baiduRequest.Messages = append(baiduRequest.Messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
}
return &baiduRequest
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: Message{
Message: model.Message{
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
fullTextResponse := openai.TextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []OpenAITextResponseChoice{choice},
Choices: []openai.TextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason
choice.FinishReason = &constant.StopFinishReason
}
response := ChatCompletionsStreamResponse{
response := openai.ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Input: request.ParseInput(),
}
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
@@ -166,8 +134,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -194,14 +162,14 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}
stopChan <- true
}()
setEventStreamHeaders(c)
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -212,7 +180,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -224,28 +192,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduChatResponse
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -255,9 +223,10 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
}, nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -265,23 +234,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduEmbeddingResponse
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -293,7 +262,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -301,10 +270,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
return nil, &fullTextResponse.Usage
}
func getBaiduAccessToken(apiKey string) (string, error) {
func GetAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken
if accessToken, ok = val.(BaiduAccessToken); ok {
var accessToken AccessToken
if accessToken, ok = val.(AccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
@@ -319,12 +288,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
return "", err
}
if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token")
return "", errors.New("GetAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
@@ -336,13 +305,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req)
res, err := util.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var accessToken BaiduAccessToken
var accessToken AccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err

View File

@@ -0,0 +1,50 @@
package baidu
import (
"github.com/songquanpeng/one-api/relay/model"
"time"
)
type ChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage model.Usage `json:"usage"`
Error
}
type ChatStreamResponse struct {
ChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type EmbeddingRequest struct {
Input []string `json:"input"`
}
type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type EmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
Usage model.Usage `json:"usage"`
Error
}
type AccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}

51
relay/channel/common.go Normal file
View File

@@ -0,0 +1,51 @@
package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, meta)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := DoRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}

View File

@@ -0,0 +1,66 @@
package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "google gemini"
}

View File

@@ -0,0 +1,6 @@
package gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
}

View File

@@ -0,0 +1,303 @@
package gemini
import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
VisionMaxImageNum = 16
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
SafetySettings: []ChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := ChatContent{
Role: message.Role,
Parts: []Part{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
var parts []Part
imageNum := 0
for _, part := range openaiContent {
if part.Type == model.ContentTypeText {
parts = append(parts, Part{
Text: part.Text,
})
} else if part.Type == model.ContentTypeImageURL {
imageNum += 1
if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, Part{
InlineData: &InlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "model",
Parts: []Part{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}
func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type ChatCandidate struct {
Content ChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
type ChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type ChatPromptFeedback struct {
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
Message: model.Message{
Role: "assistant",
Content: "",
},
FinishReason: constant.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse ChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,41 @@
package gemini
type ChatRequest struct {
Contents []ChatContent `json:"contents"`
SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
Tools []ChatTools `json:"tools,omitempty"`
}
type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
}
type ChatContent struct {
Role string `json:"role,omitempty"`
Parts []Part `json:"parts"`
}
type ChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type ChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type ChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}

View File

@@ -0,0 +1,10 @@
package groq
// https://console.groq.com/docs/models
var ModelList = []string{
"gemma-7b-it",
"llama2-7b-2048",
"llama2-70b-4096",
"mixtral-8x7b-32768",
}

View File

@@ -0,0 +1,20 @@
package channel
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *util.RelayMeta)
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

Some files were not shown because too many files have changed in this diff Show More