Compare commits

..

71 Commits

Author SHA1 Message Date
ckt1031
a891a3e64f feat: bump deps 2023-07-18 22:29:47 +08:00
ckt1031
fd72565011 feat: support Discord Guild Join 2023-07-18 22:24:38 +08:00
ckt1031
4b9756b257 feat: support chatbot ui 2023-07-17 15:35:02 +08:00
ckt1031
a6ae20ed54 fix: chatgptweb 2023-07-16 21:48:54 +08:00
ckt1031
617149d731 fix: custom models 2023-07-16 21:23:56 +08:00
ckt1031
edd2c4f6e9 fix: testing issue 2023-07-16 16:01:52 +08:00
ckt1031
481c4ebf49 fix: chatgptweb issue 2023-07-16 15:35:32 +08:00
ckt1031
203471d7a9 Merge remote-tracking branch 'upstream/main' 2023-07-16 13:12:45 +08:00
ckt1031
7ad6f7d99d fix: update i18n 2023-07-15 23:13:43 +08:00
ckt1031
68abcd48ab fix: scanner issue 2023-07-15 23:05:43 +08:00
ckt1031
0c175b4e44 fix: model issue from upstream 2023-07-15 23:05:33 +08:00
ckt1031
4e31c3991d fix: improved checking for chatgptweb 2023-07-15 22:16:52 +08:00
ckt1031
5b8a826cf9 fix: let user to define api path for chatgptweb 2023-07-15 22:14:42 +08:00
ckt1031
f5f21dffd8 fix: remove printing invalid stream response 2023-07-15 21:51:28 +08:00
ckt1031
4e94c85a9a feat: move to vite for faster builld 2023-07-15 21:41:23 +08:00
ckt1031
caabdd1e21 fix: run prettier 2023-07-15 21:14:40 +08:00
ckt1031
0424baef6a fix: merge 2 2023-07-15 21:13:26 +08:00
ckt1031
256d290507 fix: merge latest change from remote 2023-07-15 21:12:55 +08:00
ckt1031
8f0799d909 feat: support reverse proxy of Chanzhaoyu/chatgpt-web 2023-07-15 21:03:27 +08:00
ckt1031
349e3a3661 feat: add default models for token creation 2023-07-15 11:47:09 +08:00
ckt1031
8cc7f983e1 fix: model creation issue 2023-07-14 23:53:23 +08:00
ckt1031
455643e317 fix: model token creation issue 2023-07-14 23:29:11 +08:00
ckt1031
1c7bad7b87 fix: token model list 2023-07-14 23:07:22 +08:00
ckt1031
3141292026 fix: i18n 2023-07-14 22:42:27 +08:00
ckt1031
e4500bf8bf featL add token-side model selection 2023-07-14 22:41:22 +08:00
ckt1031
4043fccedb feat: support ip randomize in http header 2023-07-14 21:30:13 +08:00
ckt1031
164df4e708 fix: resp body when error 2023-07-14 20:21:25 +08:00
ckt1031
d850f465cd Merge remote-tracking branch 'upstream/main' 2023-07-13 22:27:29 +08:00
ckt1031
e2f5c1eb8c fix: channel testing for reverse proxy 2023-07-13 22:07:07 +08:00
ckt1031
d68aa4c96f fix: removing maxtokens 2023-07-13 21:28:14 +08:00
ckt1031
47cb77de53 fix: better text phrasing 2023-07-13 20:49:57 +08:00
ckt1031
61912f5e2c fix: patch testing 2023-07-13 19:40:36 +08:00
ckt1031
379d03798c fix: add user edit discord 2023-07-12 21:14:30 +08:00
ckt1031
520eb34b72 fix: json i18n 2023-07-12 18:05:25 +08:00
ckt1031
855bb82ae7 feat: improve i18n 2023-07-12 17:58:09 +08:00
ckt1031
8c91bd9c97 feat: enforce streaming in channel testing 2023-07-12 17:43:43 +08:00
ckt1031
7c7a45a4f5 feat: support account deletion 2023-07-12 15:57:40 +08:00
ckt1031
0ac0214c41 fix: billing date json issue 2023-07-12 15:15:07 +08:00
ckt1031
b63400ebe2 feat: add Discord Oauth2 support (1) 2023-07-12 15:11:02 +08:00
ckt1031
b17d9bc649 fix: add stream body if not exist 2023-07-11 23:05:01 +08:00
ckt1031
9ef8167e5d feat: strict testing 2023-07-11 23:01:36 +08:00
ckt1031
3baad1d926 chore: update readme 2023-07-11 17:51:41 +08:00
ckt1031
80d5d6edfb feat: support return date for billing 2023-07-11 17:40:52 +08:00
ckt1031
12365ccf69 feat: optimized channel testing (1) 2023-07-11 17:11:55 +08:00
ckt1031
4928319494 fix: docekrfile 2023-07-11 17:06:03 +08:00
ckt1031
839dcc3ab2 feat: better dockerfile 2023-07-11 17:03:22 +08:00
ckt1031
270e366cd9 fix: add @babel/plugin-proposal-private-property-in-object 2023-07-11 16:59:44 +08:00
ckt1031
67b8e82457 fix: billing status code check 2023-07-10 23:15:30 +08:00
ckt1031
4b2cb573b6 fix: testing channel reject if not 200 2023-07-10 23:09:15 +08:00
ckt1031
bde43cc358 feat: support dotenv 2023-07-10 23:05:21 +08:00
ckt
6e06dcfcf8 Update english.dockerfile 2023-07-10 20:27:42 +08:00
ckt
b617599211 Update english.dockerfile 2023-07-10 20:24:46 +08:00
ckt
1656b9a1de Update english.dockerfile 2023-07-10 20:21:30 +08:00
ckt
eb5f8f2d75 Create english.dockerfile 2023-07-10 20:17:32 +08:00
ckt1031
2ae5741214 fix: move back to react-scripts 2023-07-09 21:18:22 +08:00
ckt1031
28d58849a0 fix: environmental issue 2023-07-09 21:12:56 +08:00
ckt1031
adc9679d56 feat: optimized env for docker 2023-07-09 20:43:18 +08:00
ckt1031
07589ae305 fix: bump dependencies 2023-07-09 20:37:51 +08:00
ckt1031
95bc32c555 Merge branch 'dall-e-image-creation' 2023-07-09 20:28:57 +08:00
ckt1031
d61dc4a9ca feat: initial support of Dall-E 2023-07-09 19:15:15 +08:00
ckt1031
b29acb0c89 chore: use personal docker 2023-07-09 17:26:26 +08:00
ckt1031
a8e418275d fix: about word 2023-07-09 17:16:26 +08:00
ckt1031
d7ab9b0935 fix: turnstil should show in login form 2023-07-09 17:12:31 +08:00
ckt1031
ebd62c3bfc fix: vite process 2023-07-09 17:01:14 +08:00
ckt1031
4e6f9f67b3 feat: better chn translation 2023-07-09 16:43:39 +08:00
ckt1031
77d295bbf5 fix: removed local update notice 2023-07-08 16:43:57 +08:00
ckt1031
7e3e25fbd9 Fix faulty outdir 2023-07-08 07:56:31 +00:00
ckt1031
9bf98ab53a Fix vite react variable issue 2023-07-08 07:52:37 +00:00
ckt
3ed9a219c7 Update package.json 2023-07-08 15:08:46 +08:00
ckt1031
37d7afcedc feat: use vite 2023-07-08 14:58:42 +08:00
ckt1031
2756554f7c fix: channel testing issue 2023-07-08 14:24:51 +08:00
366 changed files with 11648 additions and 23150 deletions

View File

@@ -1,46 +0,0 @@
root = "."
testdata_dir = "testdata"
tmp_dir = "tmp"
[build]
args_bin = []
bin = "./tmp/main"
cmd = "go build -o ./tmp/main ."
delay = 1000
exclude_dir = ["assets", "tmp", "vendor", "testdata", "web"]
exclude_file = []
exclude_regex = ["_test.go"]
exclude_unchanged = false
follow_symlink = false
full_bin = ""
include_dir = []
include_ext = ["go", "tpl", "tmpl", "html"]
include_file = []
kill_delay = "0s"
log = "build-errors.log"
poll = false
poll_interval = 0
post_cmd = []
pre_cmd = []
rerun = false
rerun_delay = 500
send_interrupt = false
stop_on_error = false
[color]
app = ""
build = "yellow"
main = "magenta"
runner = "green"
watcher = "cyan"
[log]
main_only = false
time = false
[misc]
clean_on_exit = false
[screen]
clear_on_rebuild = false
keep_scroll = true

View File

@@ -0,0 +1,49 @@
name: Publish Docker image (amd64, English)
on:
push:
tags:
- '*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Save version info
run: |
git describe --tags > VERSION
- name: Translate
run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
ckt1031/one-api-en
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -0,0 +1,54 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- '*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
ckt1031/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -0,0 +1,62 @@
name: Publish Docker image (arm64)
on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Save version info
run: |
git describe --tags > VERSION
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
ckt1031/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,62 +0,0 @@
name: one-api docker image
on:
push:
branches:
- main
tags:
- "v*"
env:
# github.repository as <account>/<repo>
IMAGE_NAME: martialbe/one-api
jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to GHCR
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GT_Token }}
- name: Docker meta
id: meta
uses: docker/metadata-action@v4
with:
# list of Docker images to use as base name for tags
images: ghcr.io/${{ env.IMAGE_NAME }}
# generate Docker tags based on the following events/attributes
tags: |
type=raw,value=dev,enable=${{ github.ref == 'refs/heads/main' }}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=pep440,pattern={{raw}},enable=${{ startsWith(github.ref, 'refs/tags/') }}
- name: Build and push
uses: docker/build-push-action@v4
with:
context: .
platforms: linux/amd64
build-args: |
COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@@ -5,8 +5,8 @@ permissions:
on:
push:
tags:
- "*"
- "!*-alpha*"
- '*'
- '!*-alpha*'
jobs:
release:
runs-on: ubuntu-latest
@@ -24,12 +24,12 @@ jobs:
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
VITE_REACT_APP_VERSION=$(git describe --tags) npm run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ">=1.18.0"
go-version: '>=1.18.0'
- name: Build Backend (amd64)
run: |
go mod download
@@ -51,4 +51,4 @@ jobs:
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GT_Token }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -5,8 +5,8 @@ permissions:
on:
push:
tags:
- "*"
- "!*-alpha*"
- '*'
- '!*-alpha*'
jobs:
release:
runs-on: macos-latest
@@ -24,12 +24,12 @@ jobs:
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
VITE_REACT_APP_VERSION=$(git describe --tags) npm run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ">=1.18.0"
go-version: '>=1.18.0'
- name: Build Backend
run: |
go mod download
@@ -42,4 +42,4 @@ jobs:
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GT_Token }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -5,8 +5,8 @@ permissions:
on:
push:
tags:
- "*"
- "!*-alpha*"
- '*'
- '!*-alpha*'
jobs:
release:
runs-on: windows-latest
@@ -27,12 +27,12 @@ jobs:
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
VITE_REACT_APP_VERSION=$(git describe --tags) npm run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ">=1.18.0"
go-version: '>=1.18.0'
- name: Build Backend
run: |
go mod download
@@ -45,4 +45,4 @@ jobs:
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GT_Token }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

5
.gitignore vendored
View File

@@ -5,7 +5,4 @@ upload
*.db
build
*.db-journal
logs
data
tmp/
.env
.env*

View File

@@ -1,33 +1,29 @@
FROM node:16 as builder
# Node build stage
FROM node:18 as builder
WORKDIR /build
COPY web/package.json .
RUN npm install
COPY ./web/package*.json ./
RUN npm ci
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
RUN VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
# Go build stage
FROM golang AS builder2
ENV GO111MODULE=on \
CGO_ENABLED=1 \
GOOS=linux
WORKDIR /build
ADD go.mod go.sum ./
COPY go.mod .
COPY go.sum .
RUN go mod download
COPY . .
COPY --from=builder /build/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
# Final stage
FROM alpine
RUN apk update \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates 2>/dev/null || true
RUN apk update && apk upgrade && apk add --no-cache ca-certificates tzdata && update-ca-certificates 2>/dev/null || true
WORKDIR /data
COPY --from=builder2 /build/one-api /
EXPOSE 3000
WORKDIR /data
ENTRYPOINT ["/one-api"]

View File

@@ -1,5 +1,5 @@
<p align="right">
<a href="./README.md">中文</a> | <strong>English</strong> | <a href="./README.ja.md">日本語</a>
<a href="./README.md">中文</a> | <strong>English</strong>
</p>
<p align="center">
@@ -10,38 +10,28 @@
# One API
_This project is a derivative of [one-api](https://github.com/songquanpeng/one-api), where the main focus has been on modularizing the module code from the original project and modifying the frontend interface. This project also adheres to the MIT License._
<p align="center">
<a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE">
<img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license">
</a>
<a href="https://github.com/MartialBE/one-api/releases/latest">
<img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://github.com/users/MartialBE/packages/container/package/one-api">
<img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
</a>
<a href="https://goreportcard.com/report/github.com/MartialBE/one-api">
<img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard">
</a>
</p>
**Please do not mix with the original version, as the different channel ID may cause data disorder.**
## Screenshots
![dashboard](https://github.com/MartialBE/one-api/assets/42402987/c7f95d64-e7e3-4d0f-8ad8-36d6740da8db)
![topup](https://github.com/MartialBE/one-api/assets/42402987/4bc9dbfd-84f6-4700-9ea5-308c09230c7a)
_The following is the original project description:_
---
_✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_
_✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_
</div>
<p align="center">
<a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE">
<img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license">
</a>
<a href="https://github.com/songquanpeng/one-api/releases/latest">
<img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://hub.docker.com/repository/docker/justsong/one-api">
<img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull">
</a>
<a href="https://github.com/songquanpeng/one-api/releases/latest">
<img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://goreportcard.com/report/github.com/songquanpeng/one-api">
<img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard">
</a>
</p>
<p align="center">
<a href="#deployment">Deployment Tutorial</a>
·
@@ -67,14 +57,15 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
> **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
## Features
1. Support for multiple large models:
- [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
- [x] [Anthropic Claude Series Models](https://anthropic.com)
- [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google)
- [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
- [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
- [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
1. Supports multiple API access channels:
+ [x] Official OpenAI channel (support proxy configuration)
+ [x] **Azure OpenAI API**
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
+ [x] Custom channel: Various third-party proxy services not included in the list
2. Supports access to multiple channels through **load balancing**.
3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
@@ -93,15 +84,13 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
15. Supports management API access through system access tokens.
16. Supports Cloudflare Turnstile user verification.
17. Supports user management and multiple user login/registration methods:
- Email login/registration and password reset via email.
- [GitHub OAuth](https://github.com/settings/applications/new).
- WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)).
+ Email login/registration and password reset via email.
+ [GitHub OAuth](https://github.com/settings/applications/new).
+ WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)).
18. Immediate support and encapsulation of other major model APIs as they become available.
## Deployment
### Docker Deployment
Deployment command: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`
Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
@@ -111,7 +100,6 @@ The first `3000` in `-p 3000:3000` is the port of the host, which can be modifie
Data will be saved in the `/home/ubuntu/data/one-api` directory on the host. Ensure that the directory exists and has write permissions, or change it to a suitable directory.
Nginx reference configuration:
```
server{
server_name openai.justsong.cn; # Modify your domain name accordingly
@@ -129,7 +117,6 @@ server{
```
Next, configure HTTPS with Let's Encrypt certbot:
```bash
# Install certbot on Ubuntu:
sudo snap install --classic certbot
@@ -144,9 +131,7 @@ sudo service nginx restart
The initial account username is `root` and password is `123456`.
### Manual Deployment
1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source:
```shell
git clone https://github.com/songquanpeng/one-api.git
@@ -160,7 +145,6 @@ The initial account username is `root` and password is `123456`.
go mod download
go build -ldflags "-s -w" -o one-api
```
2. Run:
```shell
chmod u+x one-api
@@ -171,7 +155,6 @@ The initial account username is `root` and password is `123456`.
For more detailed deployment tutorials, please refer to [this page](https://iamazing.cn/page/how-to-deploy-a-website).
### Multi-machine Deployment
1. Set the same `SESSION_SECRET` for all servers.
2. Set `SQL_DSN` and use MySQL instead of SQLite. All servers should connect to the same database.
3. Set the `NODE_TYPE` for all non-master nodes to `slave`.
@@ -183,22 +166,16 @@ For more detailed deployment tutorials, please refer to [this page](https://iama
Please refer to the [environment variables](#environment-variables) section for details on using environment variables.
### Deployment on Control Panels (e.g., Baota)
Refer to [#175](https://github.com/songquanpeng/one-api/issues/175) for detailed instructions.
If you encounter a blank page after deployment, refer to [#97](https://github.com/songquanpeng/one-api/issues/97) for possible solutions.
### Deployment on Third-Party Platforms
<details>
<summary><strong>Deploy on Sealos</strong></summary>
<div>
> Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users.
> Click the button below to deploy with one click.👇
[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md).
</div>
</details>
@@ -209,12 +186,10 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. First, fork the code.
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
2. Go to [Zeabur](https://zeabur.com/), log in, and enter the console.
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
4. Copy the connection parameters and run `` create database `one-api` `` to create the database.
4. Copy the connection parameters and run ```create database `one-api` ``` to create the database.
5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository.
6. Automatic deployment will start, but please cancel it for now. Go to the Variable tab, add a `PORT` with a value of `3000`, and then add a `SQL_DSN` with a value of `<username>:<password>@tcp(<addr>:<port>)/one-api`. Save the changes. Please note that if `SQL_DSN` is not set, data will not be persisted, and the data will be lost after redeployment.
7. Select Redeploy.
@@ -225,7 +200,6 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
</details>
## Configuration
The system is ready to use out of the box.
You can configure it by setting environment variables or command line parameters.
@@ -233,7 +207,6 @@ You can configure it by setting environment variables or command line parameters
After the system starts, log in as the `root` user to further configure the system.
## Usage
Add your API Key on the `Channels` page, and then add an access token on the `Tokens` page.
You can then use your access token to access One API. The usage is consistent with the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction).
@@ -257,65 +230,59 @@ Note that the token needs to be created by an administrator to specify the chann
If the channel ID is not provided, load balancing will be used to distribute the requests to multiple channels.
### Environment Variables
1. `REDIS_CONN_STRING`: When set, Redis will be used as the storage for request rate limiting instead of memory.
- Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+ Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `SESSION_SECRET`: When set, a fixed session key will be used to ensure that cookies of logged-in users are still valid after the system restarts.
- Example: `SESSION_SECRET=random_string`
+ Example: `SESSION_SECRET=random_string`
3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0.
- Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+ Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
- Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
- Example: `SYNC_FREQUENCY=60`
+ Example: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
- Example: `NODE_TYPE=slave`
+ Example: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
- Example: `CHANNEL_UPDATE_FREQUENCY=1440`
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
- Example: `CHANNEL_TEST_FREQUENCY=1440`
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
- Example: `POLLING_INTERVAL=5`
+ Example: `POLLING_INTERVAL=5`
### Command Line Parameters
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.
- Example: `--port 3000`
+ Example: `--port 3000`
2. `--log-dir <log_dir>`: Specifies the log directory. If not set, the logs will not be saved.
- Example: `--log-dir ./logs`
+ Example: `--log-dir ./logs`
3. `--version`: Prints the system version number and exits.
4. `--help`: Displays the command usage help and parameter descriptions.
## Screenshots
![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png)
![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png)
## FAQ
1. What is quota? How is it calculated? Does One API have quota calculation issues?
- Quota = Group multiplier _ Model multiplier _ (number of prompt tokens + number of completion tokens \* completion multiplier)
- The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition.
- If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different.
+ Quota = Group multiplier * Model multiplier * (number of prompt tokens + number of completion tokens * completion multiplier)
+ The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition.
+ If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different.
2. Why does it prompt "insufficient quota" even though my account balance is sufficient?
- Please check if your token quota is sufficient. It is separate from the account balance.
- The token quota is used to set the maximum usage and can be freely set by the user.
+ Please check if your token quota is sufficient. It is separate from the account balance.
+ The token quota is used to set the maximum usage and can be freely set by the user.
3. It says "No available channels" when trying to use a channel. What should I do?
- Please check the user and channel group settings.
- Also check the channel model settings.
+ Please check the user and channel group settings.
+ Also check the channel model settings.
4. Channel testing reports an error: "invalid character '<' looking for beginning of value"
- This error occurs when the returned value is not valid JSON but an HTML page.
- Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare.
+ This error occurs when the returned value is not valid JSON but an HTML page.
+ Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare.
5. ChatGPT Next Web reports an error: "Failed to fetch"
- Do not set `BASE_URL` during deployment.
- Double-check that your interface address and API Key are correct.
+ Do not set `BASE_URL` during deployment.
+ Double-check that your interface address and API Key are correct.
## Related Projects
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
[FastGPT](https://github.com/c121914yu/FastGPT): Build an AI knowledge base in three minutes
## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page.

View File

@@ -1,328 +0,0 @@
<p align="right">
<a href="./README.md">中文</a> | <a href="./README.en.md">English</a> | <strong>日本語</strong>
</p>
<p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p>
<div align="center">
# One API
_このプロジェクトは、[one-api](https://github.com/songquanpeng/one-api)をベースにしており、元のプロジェクトのモジュールコードを分離し、モジュール化し、フロントエンドのインターフェースを変更しました。このプロジェクトも MIT ライセンスに従っています。_
<p align="center">
<a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE">
<img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license">
</a>
<a href="https://github.com/MartialBE/one-api/releases/latest">
<img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://github.com/users/MartialBE/packages/container/package/one-api">
<img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
</a>
<a href="https://goreportcard.com/report/github.com/MartialBE/one-api">
<img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard">
</a>
</p>
**オリジナルバージョンと混合しないでください。チャンネル ID が異なるため、データの混乱を引き起こす可能性があります**
## スクリーンショット
![dashboard](https://github.com/MartialBE/one-api/assets/42402987/c7f95d64-e7e3-4d0f-8ad8-36d6740da8db)
![topup](https://github.com/MartialBE/one-api/assets/42402987/4bc9dbfd-84f6-4700-9ea5-308c09230c7a)
_以下は元の項目の説明です_
---
_✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_
</div>
<p align="center">
<a href="#deployment">デプロイチュートリアル</a>
·
<a href="#usage">使用方法</a>
·
<a href="https://github.com/songquanpeng/one-api/issues">フィードバック</a>
·
<a href="#screenshots">スクリーンショット</a>
·
<a href="https://openai.justsong.cn/">ライブデモ</a>
·
<a href="#faq">FAQ</a>
·
<a href="#related-projects">関連プロジェクト</a>
·
<a href="https://iamazing.cn/page/reward">寄付</a>
</p>
> **警告**: この README は ChatGPT によって翻訳されています。翻訳ミスを発見した場合は遠慮なく PR を投稿してください。
> **警告** 英語版の Docker イメージは `justsong/one-api-en` です。
> **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。
## 特徴
1. 複数の大型モデルをサポート:
- [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
- [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
- [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
- [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
- [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
- [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。
3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。
4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。
5. トークンの有効期限や使用回数を設定できる**トークン管理**に対応しています。
6. **バウチャー管理**に対応しており、バウチャーの一括生成やエクスポートが可能です。バウチャーは口座残高の補充に利用できます。
7. **チャンネル管理**に対応し、チャンネルの一括作成が可能。
8. グループごとに異なるレートを設定するための**ユーザーグループ**と**チャンネルグループ**をサポートしています。
9. チャンネル**モデルリスト設定**に対応。
10. **クォータ詳細チェック**をサポート。
11. **ユーザー招待報酬**をサポートします。
12. 米ドルでの残高表示が可能。
13. 新規ユーザー向けのお知らせ公開、リチャージリンク設定、初期残高設定に対応。
14. 豊富な**カスタマイズ**オプションを提供します:
1. システム名、ロゴ、フッターのカスタマイズが可能。
2. HTML と Markdown コードを使用したホームページとアバウトページのカスタマイズ、または iframe を介したスタンドアロンウェブページの埋め込みをサポートしています。
15. システム・アクセストークンによる管理 API アクセスをサポートする。
16. Cloudflare Turnstile によるユーザー認証に対応。
17. ユーザー管理と複数のユーザーログイン/登録方法をサポート:
- 電子メールによるログイン/登録とパスワードリセット。
- [GitHub OAuth](https://github.com/settings/applications/new)。
- WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。
18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。
## デプロイメント
### Docker デプロイメント
デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`
コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`
`-p 3000:3000` の最初の `3000` はホストのポートで、必要に応じて変更できます。
データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。
Nginx リファレンス設定:
```
server{
server_name openai.justsong.cn; # ドメイン名は適宜変更
location / {
client_max_body_size 64m;
proxy_http_version 1.1;
proxy_pass http://localhost:3000; # それに応じてポートを変更
proxy_set_header Host $host;
proxy_set_header X-Forwarded-For $remote_addr;
proxy_cache_bypass $http_upgrade;
proxy_set_header Accept-Encoding gzip;
proxy_read_timeout 300s; # GPT-4 はより長いタイムアウトが必要
}
}
```
次に、Let's Encrypt certbot を使って HTTPS を設定します:
```bash
# Ubuntu に certbot をインストール:
sudo snap install --classic certbot
sudo ln -s /snap/bin/certbot /usr/bin/certbot
# 証明書の生成と Nginx 設定の変更
sudo certbot --nginx
# プロンプトに従う
# Nginx を再起動
sudo service nginx restart
```
初期アカウントのユーザー名は `root` で、パスワードは `123456` です。
### マニュアルデプロイ
1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする:
```shell
git clone https://github.com/songquanpeng/one-api.git
# フロントエンドのビルド
cd one-api/web
npm install
npm run build
# バックエンドのビルド
cd ..
go mod download
go build -ldflags "-s -w" -o one-api
```
2. 実行:
```shell
chmod u+x one-api
./one-api --port 3000 --log-dir ./logs
```
3. [http://localhost:3000/](http://localhost:3000/) にアクセスし、ログインする。初期アカウントのユーザー名は `root`、パスワードは `123456` である。
より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。
### マルチマシンデプロイ
1. すべてのサーバに同じ `SESSION_SECRET` を設定する。
2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。
3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。
4. データベースから定期的に設定を同期するサーバーには `SYNC_FREQUENCY` を設定する。
5. マスター以外のノードでは、オプションで `FRONTEND_BASE_URL` を設定して、ページ要求をマスターサーバーにリダイレクトすることができます。
6. マスター以外のノードには Redis を個別にインストールし、`REDIS_CONN_STRING` を設定して、キャッシュの有効期限が切れていないときにデータベースにゼロレイテンシーでアクセスできるようにする。
7. メインサーバーでもデータベースへのアクセスが高レイテンシになる場合は、Redis を有効にし、`SYNC_FREQUENCY` を設定してデータベースから定期的に設定を同期する必要がある。
Please refer to the [environment variables](#environment-variables) section for details on using environment variables.
### コントロールパネル(例: Baotaへの展開
詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。
配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。
### サードパーティプラットフォームへのデプロイ
<details>
<summary><strong>Sealos へのデプロイ</strong></summary>
<div>
> Sealos は、高い同時実行性、ダイナミックなスケーリング、数百万人のユーザーに対する安定した運用をサポートしています。
> 下のボタンをクリックすると、ワンクリックで展開できます。👇
[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
</div>
</details>
<details>
<summary><strong>Zeabur へのデプロイ</strong></summary>
<div>
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. まず、コードをフォークする。
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
3. 新しいプロジェクトを作成します。Service -> Add Service で Marketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
4. 接続パラメータをコピーし、`` create database `one-api` `` を実行してデータベースを作成する。
5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。
6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。
7. 再デプロイを選択します。
8. Domains タブで、"my-one-api" のような適切なドメイン名の接頭辞を選択する。最終的なドメイン名は "my-one-api.zeabur.app" となります。独自のドメイン名を CNAME することもできます。
9. デプロイが完了するのを待ち、生成されたドメイン名をクリックして One API にアクセスします。
</div>
</details>
## コンフィグ
システムは箱から出してすぐに使えます。
環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。
システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。
## 使用方法
`Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。
アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。
OpenAI API が使用されている場所では、API Base に One API のデプロイアドレスを設定することを忘れないでください(例: `https://openai.justsong.cn`。API Key は One API で生成されたトークンでなければなりません。
具体的な API Base のフォーマットは、使用しているクライアントに依存することに注意してください。
```mermaid
graph LR
A(ユーザ)
A --->|リクエスト| B(One API)
B -->|中継リクエスト| C(OpenAI)
B -->|中継リクエスト| D(Azure)
B -->|中継リクエスト| E(その他のダウンストリームチャンネル)
```
現在のリクエストにどのチャネルを使うかを指定するには、トークンの後に チャネル ID を追加します: 例えば、`Authorization: Bearer ONE_API_KEY-CHANNEL_ID` のようにします。
チャンネル ID を指定するためには、トークンは管理者によって作成される必要があることに注意してください。
もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。
### 環境変数
1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。
- 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。
- 例: `SESSION_SECRET=random_string`
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
- 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
- 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
- 例: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
- 例: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
- 例: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
- 例: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
- 例: `POLLING_INTERVAL=5`
### コマンドラインパラメータ
1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。
- 例: `--port 3000`
2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。
- 例: `--log-dir ./logs`
3. `--version`: システムのバージョン番号を表示して終了する。
4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。
## スクリーンショット
![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png)
![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png)
## FAQ
1. ルマとは何かどのように計算されますかOne API にはノルマ計算の問題はありますか?
- ノルマ = グループ倍率 _ モデル倍率 _ (プロンプトトークンの数 + 完了トークンの数 \* 完了倍率)
- 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。
- ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。
2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか?
- トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。
- トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。
3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか?
- ユーザーとチャンネルグループの設定を確認してください。
- チャンネルモデルの設定も確認してください。
4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value"
- このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。
- ほとんどの場合、デプロイサイトの IP かプロキシのノードが CloudFlare によってブロックされています。
5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch"
- デプロイ時に `BASE_URL` を設定しないでください。
- インターフェイスアドレスと API Key が正しいか再確認してください。
## 関連プロジェクト
[FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム
## 注
本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。
このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。
このプロジェクトを基にした派生プロジェクトについても同様です。
帰属表示を含めたくない場合は、事前に許可を得なければなりません。
MIT ライセンスによると、このプロジェクトを利用するリスクと責任は利用者が負うべきであり、このオープンソースプロジェクトの開発者は責任を負いません。

335
README.md
View File

@@ -1,47 +1,38 @@
<p align="right">
<strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a>
<strong>中文</strong> | <a href="./README.en.md">English</a>
</p>
<p align="center">
<a href="https://github.com/MartialBE/one-api"><img src="https://raw.githubusercontent.com/MartialBE/one-api/main/web/src/assets/images/logo.svg" width="150" height="150" alt="one-api logo"></a>
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p>
<div align="center">
# One API
_本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开发而来的,主要将原项目中的模块代码分离,模块化,并修改了前端界面。本项目同样遵循 MIT 协议。_
<p align="center">
<a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE">
<img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license">
</a>
<a href="https://github.com/MartialBE/one-api/releases/latest">
<img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://github.com/users/MartialBE/packages/container/package/one-api">
<img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
</a>
<a href="https://goreportcard.com/report/github.com/MartialBE/one-api">
<img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard">
</a>
</p>
**请不要和原版混用,因为 channel id 不同的原因,会导致数据错乱**
# 截图展示
![dashboard](https://github.com/MartialBE/one-api/assets/42402987/c7f95d64-e7e3-4d0f-8ad8-36d6740da8db)
![topup](https://github.com/MartialBE/one-api/assets/42402987/4bc9dbfd-84f6-4700-9ea5-308c09230c7a)
_以下为原项目说明_
---
_✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_
_✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用✨_
</div>
<p align="center">
<a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE">
<img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license">
</a>
<a href="https://github.com/songquanpeng/one-api/releases/latest">
<img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://hub.docker.com/repository/docker/justsong/one-api">
<img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull">
</a>
<a href="https://github.com/songquanpeng/one-api/releases/latest">
<img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://goreportcard.com/report/github.com/songquanpeng/one-api">
<img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard">
</a>
</p>
<p align="center">
<a href="https://github.com/songquanpeng/one-api#部署">部署教程</a>
·
@@ -60,82 +51,65 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
<a href="https://iamazing.cn/page/reward">赞赏支持</a>
</p>
> [!NOTE]
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
>
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> **Note**:本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> [!WARNING]
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> [!WARNING]
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
## 功能
1. 支持多种大模型:
- [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
- [x] [Anthropic Claude 系列模型](https://anthropic.com)
- [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
- [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
- [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
- [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
- [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
- [x] [360 智脑](https://ai.360.cn)
- [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)
3. 支持通过**负载均衡**的方式访问多个渠道
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果
5. 支持**多机部署**[详见此处](#多机部署)
6. 支持**令牌管理**设置令牌的过期时间和额度
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值
8. 支持**通道管理**,批量创建通道
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率
10. 支持渠道**设置模型列表**。
11. 支持**查看额度明细**
12. 支持**用户邀请奖励**
13. 支持以美元为单位显示额度
14. 支持发布公告,设置充值链接,设置新用户初始额度
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功
16. 支持失败自动重试。
17. 支持绘图接口。
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置,
1. 支持多种 API 访问渠道:
+ [x] OpenAI 官方通道(支持配置镜像)
+ [x] **Azure OpenAI API**
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
2. 支持通过**负载均衡**的方式访问多个渠道。
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果
4. 支持**多机部署**[详见此处](#多机部署)
5. 支持**令牌管理**,设置令牌的过期时间和额度
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值
7. 支持**通道管理**批量创建通道
8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率
9. 支持渠道**设置模型列表**
10. 支持**查看额度明细**
11. 支持**用户邀请奖励**。
12. 支持以美元为单位显示额度
13. 支持发布公告,设置充值链接,设置新用户初始额度
14. 支持模型映射,重定向用户的请求模型
15. 支持失败自动重试
16. 支持绘图接口
17. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
- 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
- [GitHub 开放授权](https://github.com/settings/applications/new)。
- 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
18. 支持通过系统访问令牌访问管理 API。
19. 支持 Cloudflare Turnstile 用户校验。
20. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。
22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
## 部署
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
# 例如:
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
```
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
如果你的并发量较大,推荐设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
Nginx 的参考配置:
`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
Nginx 的参考配置:
```
server{
server_name openai.justsong.cn; # 请根据实际情况修改你的域名
@@ -154,7 +128,6 @@ server{
```
之后使用 Let's Encrypt 的 certbot 配置 HTTPS
```bash
# Ubuntu 安装 certbot
sudo snap install --classic certbot
@@ -168,22 +141,8 @@ sudo service nginx restart
初始账号用户名为 `root`,密码为 `123456`
### 基于 Docker Compose 进行部署
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
```shell
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
docker-compose up -d
# 查看部署状态
docker-compose ps
```
### 手动部署
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell
git clone https://github.com/songquanpeng/one-api.git
@@ -196,8 +155,7 @@ docker-compose ps
cd ..
go mod download
go build -ldflags "-s -w" -o one-api
```
````
2. 运行:
```shell
chmod u+x one-api
@@ -208,7 +166,6 @@ docker-compose ps
更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
### 多机部署
1. 所有服务器 `SESSION_SECRET` 设置一样的值。
2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite所有服务器连接同一个数据库。
3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。
@@ -226,11 +183,9 @@ docker-compose ps
如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。
### 部署第三方服务配合 One API 使用
> 欢迎 PR 添加更多示例。
#### ChatGPT Next Web
项目主页https://github.com/Yidadaa/ChatGPT-Next-Web
```bash
@@ -240,7 +195,6 @@ docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web
注意修改端口号之后在页面上设置接口地址例如https://openai.justsong.cn/ )和 API Key 即可。
#### ChatGPT Web
项目主页https://github.com/Chanzhaoyu/chatgpt-web
```bash
@@ -249,25 +203,14 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
#### QChatGPT - QQ 机器人
项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的 key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
### 部署到第三方平台
<details>
<summary><strong>部署到 Sealos </strong></summary>
<div>
> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩
> Sealos 可视化部署,仅需 1 分钟
点击以下按钮一键部署(部署后访问出现 404 请等待 3~5 分钟):
[![Deploy-on-Sealos.svg](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。
</div>
</details>
@@ -276,14 +219,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
<summary><strong>部署到 Zeabur</strong></summary>
<div>
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
2. 进入 [Zeabur](https://zeabur.com/),登录,进入控制台。
3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口
4. 复制链接参数,运行 `` create database `one-api` `` 创建数据库。
4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。
5. 然后在 Service -> Add Service选择 Git第一次使用需要先授权选择你 fork 的仓库。
6. Deploy 会自动开始,先取消。进入下方 Variable添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `<username>:<password>@tcp(<addr>:<port>)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。
7. 选择 Redeploy。
@@ -293,19 +234,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
</div>
</details>
<details>
<summary><strong>部署到 Render</strong></summary>
<div>
> Render 提供免费额度,绑卡后可以进一步提升额度
Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashboard.render.com
</div>
</details>
## 配置
系统本身开箱即用。
你可以通过设置环境变量或者命令行参数进行配置。
@@ -315,7 +244,6 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashbo
**Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。
## 使用方法
在`渠道`页面中添加你的 API Key之后在`令牌`页面中新增访问令牌。
之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。
@@ -324,21 +252,13 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashbo
注意,具体的 API Base 的格式取决于你所使用的客户端。
例如对于 OpenAI 的官方库:
```bash
OPENAI_API_KEY="sk-xxxxxx"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
```
```mermaid
graph LR
A(用户)
A --->|使用 One API 分发的 key 进行请求| B(One API)
A --->|请求| B(One API)
B -->|中继请求| C(OpenAI)
B -->|中继请求| D(Azure)
B -->|中继请求| E(其他 OpenAI API 格式下游渠道)
B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
B -->|中继请求| E(其他下游渠道)
```
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
@@ -347,106 +267,67 @@ graph LR
不加的话将会使用负载均衡的方式使用多个渠道。
### 环境变量
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
- 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
- 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
- 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL
- 例子:
- MySQL`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
- PostgreSQL`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈)
- 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表
- 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。
- 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。
- 请根据你的数据库配置修改下列参数(或者保持默认值):
- `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。
- `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
- 如果报错 `Error 1040: Too many connections`,请适当减小该值。
- `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
+ 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 8.0 版本
+ 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+ 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。
+ 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。
+ 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
- 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`
- 例子:`MEMORY_CACHE_ENABLED=true`
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒
- 例子:`SYNC_FREQUENCY=60`
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`
- 例子:`NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新
- 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查
- 例子:`CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
- 例子:`POLLING_INTERVAL=5`
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
- 例子:`BATCH_UPDATE_ENABLED=true`
- 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
- 例子:`BATCH_UPDATE_INTERVAL=5`
13. 请求频率限制:
- `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
- `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
- `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
- `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
16. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步
+ 例子:`SYNC_FREQUENCY=60`
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`
+ 例子:`NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔
+ 例子:`POLLING_INTERVAL=5`
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
- 例子:`--port 3000`
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下
- 例子:`--log-dir ./logs`
+ 例子:`--port 3000`
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存
+ 例子:`--log-dir ./logs`
3. `--version`: 打印系统版本号并退出。
4. `--help`: 查看命令的使用帮助和参数说明。
## 演示
### 在线演示
注意,该演示站不提供对外服务:
https://openai.justsong.cn
### 截图展示
![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png)
![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png)
## 常见问题
1. 额度是什么怎么计算的One API 的额度计算有问题?
- 额度 = 分组倍率 _ 模型倍率 _ (提示 token 数 + 补全 token 数 \* 补全倍率)
- 其中补全倍率对于 GPT3.5 固定为 1.33GPT4 为 2与官方保持一致。
- 如果是非流模式,官方接口会返回消耗的总 token但是你要注意提示和补全的消耗倍率不一样。
- 注意One API 的默认倍率就是官方倍率,是已经调整过的。
+ 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率)
+ 其中补全倍率对于 GPT3.5 固定为 1.33GPT4 为 2与官方保持一致。
+ 如果是非流模式,官方接口会返回消耗的总 token但是你要注意提示和补全的消耗倍率不一样。
2. 账户额度足够为什么提示额度不足?
- 请检查你的令牌额度是否足够,这个和账户额度是分开的。
- 令牌额度仅供用户设置最大使用量,用户可自由设置。
+ 请检查你的令牌额度是否足够,这个和账户额度是分开的。
+ 令牌额度仅供用户设置最大使用量,用户可自由设置。
3. 提示无可用渠道?
- 请检查的用户分组和渠道分组设置。
- 以及渠道的模型设置。
+ 请检查的用户分组和渠道分组设置。
+ 以及渠道的模型设置。
4. 渠道测试报错:`invalid character '<' looking for beginning of value`
- 这是因为返回值不是合法的 JSON而是一个 HTML 页面。
- 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。
+ 这是因为返回值不是合法的 JSON而是一个 HTML 页面。
+ 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。
5. ChatGPT Next Web 报错:`Failed to fetch`
- 部署的时候不要设置 `BASE_URL`。
- 检查你的接口地址和 API Key 有没有填对。
- 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
+ 部署的时候不要设置 `BASE_URL`。
+ 检查你的接口地址和 API Key 有没有填对。
6. 报错:`当前分组负载已饱和,请稍后再试`
- 上游通道 429 了。
7. 升级之后我的数据会丢失吗?
- 如果使用 MySQL不会。
- 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
8. 升级之前数据库需要做变更吗?
- 一般情况下不需要,系统将在初始化的时候自动调整。
- 如果需要的话,我会在更新日志中说明,并给出脚本。
+ 上游通道 429 了。
## 相关项目
- [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
- [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
[FastGPT](https://github.com/c121914yu/FastGPT): 三分钟搭建 AI 知识库
## 注意

View File

@@ -1,257 +0,0 @@
package common
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/types"
"strconv"
"time"
"github.com/gin-gonic/gin"
)
var HttpClient *http.Client
func init() {
if RelayTimeout == 0 {
HttpClient = &http.Client{}
} else {
HttpClient = &http.Client{
Timeout: time.Duration(RelayTimeout) * time.Second,
}
}
}
type Client struct {
requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder
}
func NewClient() *Client {
return &Client{
requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
},
}
}
type requestOptions struct {
body any
header http.Header
}
type requestOption func(*requestOptions)
type Stringer interface {
GetString() *string
}
func WithBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}
func WithHeader(header map[string]string) requestOption {
return func(args *requestOptions) {
for k, v := range header {
args.header.Set(k, v)
}
}
}
func WithContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}
type RequestError struct {
HTTPStatusCode int
Err error
}
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
// Default Options
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
if err != nil {
return nil, err
}
return req, nil
}
func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
resp, err := HttpClient.Do(req)
if err != nil {
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if !outputResp {
defer resp.Body.Close()
}
// 处理响应
if IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp)
}
// 解析响应
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// 将响应体重新写入 resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = DecodeResponse(resp.Body, response)
}
if err != nil {
return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
if outputResp {
return resp, nil
}
return nil, nil
}
type GeneralErrorResponse struct {
Error types.OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
// 处理错误响应
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
// var errorResponse types.OpenAIErrorResponse
var errorResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
resp, err := HttpClient.Do(req)
if err != nil {
return
}
return resp.Body, nil
}
func IsFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}
func DecodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return DecodeString(body, result)
}
if stringer, ok := v.(Stringer); ok {
return DecodeString(body, stringer.GetString())
}
return json.NewDecoder(body).Decode(v)
}
func DecodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -21,9 +21,12 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var SQLitePath = "one-api.db"
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
@@ -35,26 +38,11 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var DiscordOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
@@ -66,6 +54,12 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var DiscordClientId = ""
var DiscordClientSecret = ""
var DiscordGuildId = ""
var DiscordAllowJoiningGuild = "false"
var DiscordBotToken = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -78,7 +72,6 @@ var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
@@ -91,17 +84,6 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const (
RoleGuestUser = 0
RoleCommonUser = 1
@@ -119,10 +101,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = 180
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 100)
GlobalWebRateLimitNum = 60
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@@ -158,8 +140,7 @@ const (
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
ChannelStatusDisabled = 2 // also don't use 0
)
const (
@@ -177,18 +158,10 @@ const (
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeAzureSpeech = 24
ChannelTypeGemini = 25
// Reserve engineering for public projects
ChannelTypeChatGPTWeb = 14 // Chanzhaoyu/chatgpt-web
ChannelTypeChatbotUI = 15 // mckaywrigley/chatbot-ui
)
var ChannelBaseURLs = []string{
@@ -206,31 +179,8 @@ var ChannelBaseURLs = []string{
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
"", //25
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeImagesEdits
RelayModeImagesVariations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
// Reserve engineering for public projects
"", // 14 // Chanzhaoyu/chatgpt-web
"", // 15 // mckaywrigley/chatbot-ui
}

View File

@@ -1,7 +0,0 @@
package common
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -1,13 +1,11 @@
package common
import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
@@ -15,32 +13,15 @@ func SendEmail(subject string, receiver string, content string) error {
SMTPFrom = SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
receiver, SystemName, SMTPFrom, encodedSubject, content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,

View File

@@ -1,71 +0,0 @@
package common
import (
"fmt"
"io"
"mime/multipart"
"path"
)
type FormBuilder interface {
CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
}
type DefaultFormBuilder struct {
writer *multipart.Writer
}
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
return &DefaultFormBuilder{
writer: multipart.NewWriter(body),
}
}
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error {
file, err := fileHeader.Open()
if err != nil {
return err
}
defer file.Close()
return fb.createFormFile(fieldname, file, fileHeader.Filename)
}
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
}
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
if filename == "" {
return fmt.Errorf("filename cannot be empty")
}
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}
return nil
}
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}
func (fb *DefaultFormBuilder) Close() error {
return fb.writer.Close()
}
func (fb *DefaultFormBuilder) FormDataContentType() string {
return fb.writer.FormDataContentType()
}

View File

@@ -2,12 +2,9 @@ package common
import (
"bytes"
"fmt"
"io"
"one-api/types"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"io"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
@@ -19,43 +16,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(v)
err = json.Unmarshal(requestBody, &v)
if err != nil {
if errs, ok := err.(validator.ValidationErrors); ok {
// 返回第一个错误字段的名称
return fmt.Errorf("field %s is required", errs[0].Field())
}
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
return StringErrorWrapper(err.Error(), code, statusCode)
}
func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
openAIError := types.OpenAIError{
Message: err,
Type: "one_api_error",
Code: code,
}
return &types.OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func AbortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
LogError(c.Request.Context(), message)
}

View File

@@ -1,64 +0,0 @@
package image
import (
"bytes"
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"net/http"
"regexp"
"strings"
"sync"
_ "golang.org/x/image/webp"
)
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body)
if err != nil {
return
}
return img.Width, img.Height, nil
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
var readerPool = sync.Pool{
New: func() interface{} {
return &bytes.Reader{}
},
}
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
if err != nil {
return 0, 0, err
}
reader := readerPool.Get().(*bytes.Reader)
defer readerPool.Put(reader)
reader.Reset(decoded)
img, _, err := image.DecodeConfig(reader)
if err != nil {
return 0, 0, err
}
return img.Width, img.Height, nil
}
func GetImageSize(image string) (width int, height int, err error) {
if strings.HasPrefix(image, "data:image/") {
return GetImageSizeFromBase64(image)
}
return GetImageSizeFromUrl(image)
}

View File

@@ -1,154 +0,0 @@
package image_test
import (
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strconv"
"strings"
"testing"
img "one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
)
type CountingReader struct {
reader io.Reader
BytesRead int
}
func (r *CountingReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.BytesRead += n
return n, err
}
var (
cases = []struct {
url string
format string
width int
height int
}{
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
}
)
func TestDecode(t *testing.T) {
// Bytes read: varies sometimes
// jpeg: 1063892
// png: 294462
// webp: 99529
// gif: 956153
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 4096
// png: 4096
// webp: 4096
// gif: 4096
// jpeg#01: 4096
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestBase64(t *testing.T) {
// Bytes read:
// jpeg: 1063892
// png: 294462
// webp: 99072
// gif: 953856
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 1536
// png: 768
// webp: 768
// gif: 1536
// jpeg#01: 3840
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestGetImageSize(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@@ -6,15 +6,13 @@ import (
"log"
"os"
"path/filepath"
"github.com/joho/godotenv"
)
var (
Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
LogDir = flag.String("log-dir", "", "specify the log directory")
)
func printHelp() {
@@ -25,11 +23,6 @@ func printHelp() {
}
func init() {
// 加载.env文件
err := godotenv.Load()
if err != nil {
SysLog("failed to load .env file: " + err.Error())
}
flag.Parse()
if *PrintVersion {
@@ -43,12 +36,8 @@ func init() {
}
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")
}

16
common/ip-gen.go Normal file
View File

@@ -0,0 +1,16 @@
package common
import (
"fmt"
"math/rand"
)
func GenerateIP() string {
// Generate a random number between 20 and 240
segment2 := rand.Intn(221) + 20
segment3 := rand.Intn(256)
segment4 := rand.Intn(256)
ipAddress := fmt.Sprintf("104.%d.%d.%d", segment2, segment3, segment4)
return ipAddress
}

View File

@@ -1,47 +1,29 @@
package common
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"os"
"path/filepath"
"sync"
"time"
)
const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
func SetupGinLog() {
if *LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
commonLogPath := filepath.Join(*LogDir, "common.log")
errorLogPath := filepath.Join(*LogDir, "error.log")
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
}
}
@@ -55,36 +37,6 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
SetupLogger()
}()
}
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

View File

@@ -1,15 +0,0 @@
package common
import (
"encoding/json"
)
type Marshaller interface {
Marshal(value any) ([]byte, error)
}
type JSONMarshaller struct{}
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

View File

@@ -1,41 +1,12 @@
package common
import (
"encoding/json"
"strings"
"time"
)
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
import "encoding/json"
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"gpt-4": 15,
"gpt-4-0314": 15,
@@ -43,15 +14,11 @@ var ModelRatio = map[string]float64{
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
@@ -59,11 +26,7 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"whisper-1": 10,
"davinci": 10,
"curie": 10,
"babbage": 10,
@@ -72,33 +35,7 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"dall-e": 8,
}
func ModelRatio2JSONString() string {
@@ -122,34 +59,3 @@ func GetModelRatio(name string) float64 {
}
return ratio
}
func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 1.333333
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
}
return 1
}

View File

@@ -1,59 +0,0 @@
package common
// type Quota struct {
// ModelName string
// ModelRatio float64
// GroupRatio float64
// Ratio float64
// UserQuota int
// }
// func CreateQuota(modelName string, userQuota int, group string) *Quota {
// modelRatio := GetModelRatio(modelName)
// groupRatio := GetGroupRatio(group)
// return &Quota{
// ModelName: modelName,
// ModelRatio: modelRatio,
// GroupRatio: groupRatio,
// Ratio: modelRatio * groupRatio,
// UserQuota: userQuota,
// }
// }
// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
// if ApproximateTokenEnabled {
// return int(float64(len(text)) * 0.38)
// }
// return len(tokenEncoder.Encode(text, nil, nil))
// }
// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
// tokenEncoder := q.getTokenEncoder(model)
// // Reference:
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// // https://github.com/pkoukk/tiktoken-go/issues/6
// //
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
// var tokensPerMessage int
// var tokensPerName int
// if model == "gpt-3.5-turbo-0301" {
// tokensPerMessage = 4
// tokensPerName = -1 // If there's a name, the role is omitted
// } else {
// tokensPerMessage = 3
// tokensPerName = 1
// }
// tokenNum := 0
// for _, message := range messages {
// tokenNum += tokensPerMessage
// tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
// tokenNum += q.getTokenNum(tokenEncoder, message.Role)
// if message.Name != nil {
// tokenNum += tokensPerName
// tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
// }
// }
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// return tokenNum
// }

View File

@@ -61,8 +61,3 @@ func RedisDel(key string) error {
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisDecrease(key string, value int64) error {
ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err()
}

View File

@@ -1,50 +0,0 @@
package common
import (
"bytes"
"io"
"net/http"
)
type RequestBuilder interface {
Build(method, url string, body any, header http.Header) (*http.Request, error)
}
type HTTPRequestBuilder struct {
marshaller Marshaller
}
func NewRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &JSONMarshaller{},
}
}
func (b *HTTPRequestBuilder) Build(
method string,
url string,
body any,
header http.Header,
) (req *http.Request, err error) {
var bodyReader io.Reader
if body != nil {
if v, ok := body.(io.Reader); ok {
bodyReader = v
} else {
var reqBytes []byte
reqBytes, err = b.marshaller.Marshal(body)
if err != nil {
return
}
bodyReader = bytes.NewBuffer(reqBytes)
}
}
req, err = http.NewRequest(method, url, bodyReader)
if err != nil {
return
}
if header != nil {
req.Header = header
}
return
}

View File

@@ -1,238 +0,0 @@
package common
import (
"errors"
"fmt"
"math"
"strings"
"one-api/common/image"
"one-api/types"
"github.com/pkoukk/tiktoken-go"
)
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
}
SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder
}
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
return defaultTokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
func CountTokenMessages(messages []types.ChatCompletionMessage, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
switch v := message.Content.(type) {
case string:
tokenNum += getTokenNum(tokenEncoder, v)
case []any:
for _, it := range v {
m := it.(map[string]any)
switch m["type"] {
case "text":
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
case "image_url":
imageUrl, ok := m["image_url"].(map[string]any)
if ok {
url := imageUrl["url"].(string)
detail := ""
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail)
if err != nil {
SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
}
}
}
}
tokenNum += getTokenNum(tokenEncoder, message.StringContent())
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
const (
lowDetailCost = 85
highDetailCostPerTile = 170
additionalCost = 85
)
// https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string) (_ int, err error) {
var fetchSize = true
var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
// According to the official guide, "low" disable the high-res model,
// and only receive low-res 512px x 512px version of the image, indicating
// that image is treated as low-res when size is smaller than 512px x 512px,
// then we can assume that image size larger than 512px x 512px is treated
// as high-res. Then we have the following logic:
// if detail == "" || detail == "auto" {
// width, height, err = image.GetImageSize(url)
// if err != nil {
// return 0, err
// }
// fetchSize = false
// // not sure if this is correct
// if width > 512 || height > 512 {
// detail = "high"
// } else {
// detail = "low"
// }
// }
// However, in my test, it seems to be always the same as "high".
// The following image, which is 125x50, is still treated as high-res, taken
// 255 tokens in the response of non-stream chat completion api.
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
if detail == "" || detail == "auto" {
// assume by test, not sure if this is correct
detail = "high"
}
switch detail {
case "low":
return lowDetailCost, nil
case "high":
if fetchSize {
width, height, err = image.GetImageSize(url)
if err != nil {
return 0, err
}
}
if width > 2048 || height > 2048 { // max(width, height) > 2048
ratio := float64(2048) / math.Max(float64(width), float64(height))
width = int(float64(width) * ratio)
height = int(float64(height) * ratio)
}
if width > 768 && height > 768 { // min(width, height) > 768
ratio := float64(768) / math.Min(float64(width), float64(height))
width = int(float64(width) * ratio)
height = int(float64(height) * ratio)
}
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
result := numSquares*highDetailCostPerTile + additionalCost
return result, nil
default:
return 0, errors.New("invalid detail option")
}
}
func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return CountTokenInput(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenInput(text, model)
}
return 0
}
func CountTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}
func CountTokenImage(input interface{}) (int, error) {
switch v := input.(type) {
case types.ImageRequest:
// 处理 ImageRequest
return calculateToken(v.Model, v.Size, v.N, v.Quality)
case types.ImageEditRequest:
// 处理 ImageEditsRequest
return calculateToken(v.Model, v.Size, v.N, "")
default:
return 0, errors.New("unsupported type")
}
}
func calculateToken(model string, size string, n int, quality string) (int, error) {
imageCostRatio, hasValidSize := DalleSizeRatios[model][size]
if hasValidSize {
if quality == "hd" && model == "dall-e-3" {
if size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return 0, errors.New("size not supported for this image model")
}
return int(imageCostRatio*1000) * n, nil
}

View File

@@ -7,7 +7,6 @@ import (
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
@@ -171,11 +170,6 @@ func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
@@ -183,27 +177,3 @@ func Max(a int, b int) int {
return b
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -3,7 +3,6 @@ package controller
import (
"one-api/common"
"one-api/model"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -12,36 +11,27 @@ func GetSubscription(c *gin.Context) {
var remainQuota int
var usedQuota int
var err error
var token *model.Token
var expiredTime int64
var expirationDate int64
tokenId := c.GetInt("token_id")
token, err := model.GetTokenById(tokenId)
expirationDate = token.ExpiredTime
if common.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
} else {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
if err != nil {
openAIError := types.OpenAIError{
Message: err.Error(),
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,
})
return
}
usedQuota, err = model.GetUserUsedQuota(userId)
}
if expiredTime <= 0 {
expiredTime = 0
}
if err != nil {
openAIError := types.OpenAIError{
openAIError := OpenAIError{
Message: err.Error(),
Type: "upstream_error",
Type: "one_api_error",
}
c.JSON(200, gin.H{
"error": openAIError,
@@ -62,9 +52,10 @@ func GetSubscription(c *gin.Context) {
SoftLimitUSD: amount,
HardLimitUSD: amount,
SystemHardLimitUSD: amount,
AccessUntil: expiredTime,
AccessUntil: expirationDate,
}
c.JSON(200, subscription)
return
}
func GetUsage(c *gin.Context) {
@@ -80,7 +71,7 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
openAIError := types.OpenAIError{
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
}
@@ -98,4 +89,5 @@ func GetUsage(c *gin.Context) {
TotalUsage: amount * 100,
}
c.JSON(200, usage)
return
}

View File

@@ -1,13 +1,13 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"one-api/common"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"strconv"
"time"
@@ -46,30 +46,217 @@ type OpenAIUsageResponse struct {
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
req, err := http.NewRequest("POST", "/balance", nil)
type OpenAISBUsageResponse struct {
Msg string `json:"msg"`
Data *struct {
Credit string `json:"credit"`
} `json:"data"`
}
type AIProxyUserOverviewResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
ErrorCode int `json:"error_code"`
Data struct {
TotalPoints float64 `json:"totalPoints"`
} `json:"data"`
}
type API2GPTUsageResponse struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
type APGC2DGPTUsageResponse struct {
//Grants interface{} `json:"grants"`
Object string `json:"object"`
TotalAvailable float64 `json:"total_available"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
}
// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
return h
}
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(method, url, nil)
if err != nil {
return nil, err
}
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("status code: %d", res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
err = res.Body.Close()
if err != nil {
return nil, err
}
return body, nil
}
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
setChannelToContext(c, channel)
req.Header.Set("Content-Type", "application/json")
provider := providers.GetProvider(channel.Type, c)
if provider == nil {
return 0, errors.New("provider not found")
response := OpenAICreditGrants{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}
balanceProvider, ok := provider.(providersBase.BalanceInterface)
if !ok {
return 0, errors.New("provider not implemented")
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := OpenAISBUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Data == nil {
return 0, errors.New(response.Msg)
}
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
return balanceProvider.Balance(channel)
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
url := "https://aiproxy.io/api/report/getUserOverview"
headers := http.Header{}
headers.Add("Api-Key", channel.Key)
body, err := GetResponseBody("GET", url, channel, headers)
if err != nil {
return 0, err
}
response := AIProxyUserOverviewResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if !response.Success {
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
}
channel.UpdateBalance(response.Data.TotalPoints)
return response.Data.TotalPoints, nil
}
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := API2GPTUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
channel.UpdateBalance(response.TotalRemaining)
return response.TotalRemaining, nil
}
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := APGC2DGPTUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.BaseURL == "" {
channel.BaseURL = baseURL
}
switch channel.Type {
case common.ChannelTypeOpenAI:
if channel.BaseURL != "" {
baseURL = channel.BaseURL
}
case common.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.BaseURL
case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel)
default:
return 0, errors.New("尚未实现")
}
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
subscription := OpenAISubscriptionResponse{}
err = json.Unmarshal(body, &subscription)
if err != nil {
return 0, err
}
now := time.Now()
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
endDate := now.Format("2006-01-02")
if !subscription.HasPaymentMethod {
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
}
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
usage := OpenAIUsageResponse{}
err = json.Unmarshal(body, &usage)
if err != nil {
return 0, err
}
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)
return balance, nil
}
func UpdateChannelBalance(c *gin.Context) {

View File

@@ -1,97 +1,278 @@
package controller
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/http/httptest"
"one-api/common"
"one-api/model"
"one-api/providers"
providers_base "one-api/providers/base"
"one-api/types"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) {
// 创建一个 http.Request
req, err := http.NewRequest("POST", "/v1/chat/completions", nil)
func formatFloat(input float64) float64 {
if input == float64(int64(input)) {
return input
}
return float64(int64(input*10)) / 10
}
func testChannel(channel *model.Channel, request ChatRequest) error {
switch channel.Type {
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
default:
request.Model = "gpt-3.5-turbo"
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else if channel.Type == common.ChannelTypeChatGPTWeb {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
}
} else if channel.Type == common.ChannelTypeChatbotUI {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
}
} else {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if channel.Type == common.ChannelTypeChatGPTWeb {
// Get system message from Message json, Role == "system"
var systemMessage Message
for _, message := range request.Messages {
if message.Role == "system" {
systemMessage = message
break
}
}
var prompt string
// Get all the Message, Roles from request.Messages, and format it into string by
// ||> role: content
for _, message := range request.Messages {
// Exclude system message
if message.Role == "system" {
continue
}
prompt += "||> " + message.Role + ": " + message.Content + "\n"
}
// Construct json data without adding escape character
map1 := make(map[string]interface{})
map1["prompt"] = prompt
map1["systemMessage"] = systemMessage.Content
if request.Temperature != 0 {
map1["temperature"] = formatFloat(request.Temperature)
}
if request.TopP != 0 {
map1["top_p"] = formatFloat(request.TopP)
}
// Convert map to json string
jsonData, err = json.Marshal(map1)
} else if channel.Type == common.ChannelTypeChatbotUI {
// Get system message from Message json, Role == "system"
var systemMessage string
for _, message := range request.Messages {
if message.Role == "system" {
systemMessage = message.Content
break
}
}
// Construct json data without adding escape character
map1 := make(map[string]interface{})
map1["prompt"] = systemMessage
map1["temperature"] = formatFloat(request.Temperature)
map1["key"] = ""
map1["messages"] = request.Messages
map1["model"] = map[string]interface{}{
"id": request.Model,
}
// Convert map to json string
jsonData, err = json.Marshal(map1)
//Print jsoinData to console
log.Println(string(jsonData))
}
if err != nil {
return err, nil
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
setChannelToContext(c, channel)
// 创建映射
channelTypeToModel := map[int]string{
common.ChannelTypePaLM: "PaLM-2",
common.ChannelTypeAnthropic: "claude-2",
common.ChannelTypeBaidu: "ERNIE-Bot",
common.ChannelTypeZhipu: "chatglm_lite",
common.ChannelTypeAli: "qwen-turbo",
common.ChannelType360: "360GPT_S2_V9",
common.ChannelTypeXunfei: "SparkDesk",
common.ChannelTypeTencent: "hunyuan",
common.ChannelTypeAzure: "gpt-3.5-turbo",
if channel.EnableIpRandomization {
// Generate random IP
ip := common.GenerateIP()
req.Header.Set("X-Forwarded-For", ip)
req.Header.Set("X-Real-IP", ip)
req.Header.Set("X-Client-IP", ip)
req.Header.Set("X-Forwarded-Host", ip)
req.Header.Set("X-Originating-IP", ip)
req.RemoteAddr = ip
req.Header.Set("X-Remote-IP", ip)
req.Header.Set("X-Remote-Addr", ip)
}
// 从映射中获取模型名称
model, ok := channelTypeToModel[channel.Type]
if !ok {
model = "gpt-3.5-turbo" // 默认值
}
request.Model = model
provider := providers.GetProvider(channel.Type, c)
if provider == nil {
return errors.New("channel not implemented"), nil
}
chatProvider, ok := provider.(providers_base.ChatInterface)
if !ok {
return errors.New("channel not implemented"), nil
}
modelMap, err := parseModelMapping(channel.GetModelMapping())
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, nil
}
if modelMap != nil && modelMap[request.Model] != "" {
request.Model = modelMap[request.Model]
return err
}
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
if openAIErrorWithStatusCode != nil {
return nil, &openAIErrorWithStatusCode.OpenAIError
if resp.StatusCode != http.StatusOK {
// Print the body in string
if resp.Body != nil {
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
return errors.New("error response: " + strconv.Itoa(resp.StatusCode) + " " + buf.String())
}
if Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil
return errors.New("error response: " + strconv.Itoa(resp.StatusCode))
}
return nil, nil
var streamResponseText = ""
scanner := bufio.NewScanner(resp.Body)
if channel.Type != common.ChannelTypeChatGPTWeb && channel.Type != common.ChannelTypeChatbotUI {
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
func buildTestRequest() *types.ChatCompletionRequest {
testRequest := &types.ChatCompletionRequest{
Messages: []types.ChatCompletionMessage{
{
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
}
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // must be something wrong!
continue
}
if channel.Type == common.ChannelTypeChatGPTWeb {
var chatResponse ChatGptWebChatResponse
err = json.Unmarshal([]byte(data), &chatResponse)
if err != nil {
// Print the body in string
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
return err
}
// if response role is assistant and contains delta, append the content to streamResponseText
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
for _, choice := range chatResponse.Detail.Choices {
streamResponseText += choice.Delta.Content
}
}
} else if channel.Type == common.ChannelTypeChatbotUI {
streamResponseText += data
} else if channel.Type != common.ChannelTypeChatGPTWeb {
// If data has event: event content inside, remove it, it can be prefix or inside the data
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back
data = strings.TrimPrefix(data, "event: event")
data = strings.TrimSuffix(data, "event: event")
// Remove everything, only keep `data: {...}` <--- this is the json
// Find the start and end indices of `data: {...}` substring
startIndex := strings.Index(data, "data:")
endIndex := strings.LastIndex(data, "}")
// If both indices are found and end index is greater than start index
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
// Extract the `data: {...}` substring
data = data[startIndex : endIndex+1]
}
// Trim whitespace and newlines from the modified data string
data = strings.TrimSpace(data)
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse ChatCompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
// Prinnt the body in string
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling stream response: " + err.Error() + " " + buf.String())
return err
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
}
}
}
defer resp.Body.Close()
// Check if streaming is complete and streamResponseText is populated
if streamResponseText == "" {
return errors.New("Streaming not complete")
}
return nil
}
func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
Stream: true,
}
testMessage := Message{
Role: "user",
Content: "You just need to output 'hi' next.",
},
},
Model: "",
MaxTokens: 1,
Stream: false,
Content: "Hello ChatGPT!",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
@@ -114,7 +295,7 @@ func TestChannel(c *gin.Context) {
}
testRequest := buildTestRequest()
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
err = testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -138,32 +319,20 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
// enable & notify
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
@@ -186,21 +355,19 @@ func testAllChannels(notify bool) error {
}
go func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest)
err := testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
disableChannel(channel.Id, channel.Name, err.Error())
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
disableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
@@ -230,6 +397,7 @@ func TestAllChannels(c *gin.Context) {
"success": true,
"message": "",
})
return
}
func AutomaticallyTestChannels(frequency int) {

View File

@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys))
channels := make([]model.Channel, 0)
for _, key := range keys {
if key == "" {
continue
@@ -127,23 +127,6 @@ func DeleteChannel(c *gin.Context) {
return
}
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": rows,
})
return
}
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)

237
controller/discord.go Normal file
View File

@@ -0,0 +1,237 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
disgoauth "github.com/realTristan/disgoauth"
)
type DiscordOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type DiscordUser struct {
Id string `json:"id"`
Username string `json:"username"`
}
func getDiscordUserInfoByCode(codeFromURLParamaters string, host string) (*DiscordUser, error) {
if codeFromURLParamaters == "" {
return nil, errors.New("Invalid parameter")
}
// Establish a new discord client
var dc *disgoauth.Client = disgoauth.Init(&disgoauth.Client{
ClientID: common.DiscordClientId,
ClientSecret: common.DiscordClientSecret,
RedirectURI: fmt.Sprintf("https://%s/oauth/discord", host),
Scopes: []string{disgoauth.ScopeIdentify, disgoauth.ScopeEmail, disgoauth.ScopeGuilds, disgoauth.ScopeGuildsJoin},
})
accessToken, _ := dc.GetOnlyAccessToken(codeFromURLParamaters)
// Get the authorized user's data using the above accessToken
userData, _ := disgoauth.GetUserData(accessToken)
// Create a new DiscordUser
var discordUser DiscordUser
// Decode the userData map[string]interface{} into the discordUser
// Convert the map to JSON
jsonData, _ := json.Marshal(userData)
// Convert the JSON to a struct
err := json.Unmarshal(jsonData, &discordUser)
if err != nil {
return nil, err
}
// Add guild member.
if common.DiscordGuildId != "" && discordUser.Id != "" && common.DiscordBotToken != "" && common.DiscordAllowJoiningGuild == "true" {
url := fmt.Sprintf("https://discord.com/api/guilds/%s/members/%s", common.DiscordGuildId, discordUser.Id)
// Set JSON
map1 := map[string]interface{}{
// accessToken remove "Bearer "
"access_token": string(accessToken[7:]),
}
// Convert map to JSON
jsonData, _ := json.Marshal(map1)
req, _ := http.NewRequest("PUT", url, bytes.NewBuffer(jsonData))
// Set Header
req.Header.Set("Authorization", fmt.Sprintf("Bot %s", common.DiscordBotToken))
req.Header.Set("Content-Type", "application/json")
// Create a new HTTP Client
client := &http.Client{}
resp, err := client.Do(req)
log.Print(resp.StatusCode)
if err != nil || (resp.StatusCode != 200 && resp.StatusCode != 201) {
// Print content
stringBuff := new(bytes.Buffer)
stringBuff.ReadFrom(resp.Body)
// Print error
fmt.Println("Error: ", stringBuff.String())
return nil, errors.New("You must join the discord server first or be verified member to be able to login!")
}
// Close the response body
defer resp.Body.Close()
}
if discordUser.Username == "" {
return nil, errors.New("Invalid return value, user field is empty, please try again later!")
}
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !common.DiscordOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
host := c.Request.Host
discordUser, err := getDiscordUserInfoByCode(code, host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
DiscordId: discordUser.Id,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
if discordUser.Username != "" {
user.DisplayName = discordUser.Username
} else {
user.DisplayName = "Discord User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !common.DiscordOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
DiscordId: discordUser.Id,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.DiscordId = discordUser.Id
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -5,13 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type GitHubOAuthResponse struct {
@@ -79,14 +80,6 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
GitHubBind(c)
@@ -213,22 +206,3 @@ func GitHubBind(c *gin.Context) {
})
return
}
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}

View File

@@ -1,10 +1,9 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {

View File

@@ -2,7 +2,6 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
@@ -19,21 +18,19 @@ func GetAllLogs(c *gin.Context) {
username := c.Query("username")
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetUserLogs(c *gin.Context) {
@@ -49,36 +46,34 @@ func GetUserLogs(c *gin.Context) {
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func SearchUserLogs(c *gin.Context) {
@@ -86,18 +81,17 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetLogsStat(c *gin.Context) {
@@ -107,10 +101,9 @@ func GetLogsStat(c *gin.Context) {
tokenName := c.Query("token_name")
username := c.Query("username")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": gin.H{
@@ -118,7 +111,6 @@ func GetLogsStat(c *gin.Context) {
//"token": tokenNum,
},
})
return
}
func GetLogsSelfStat(c *gin.Context) {
@@ -128,10 +120,9 @@ func GetLogsSelfStat(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(http.StatusOK, gin.H{
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": gin.H{
@@ -139,30 +130,4 @@ func GetLogsSelfStat(c *gin.Context) {
//"token": tokenNum,
},
})
return
}
func DeleteHistoryLogs(c *gin.Context) {
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
if targetTimestamp == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "target timestamp is required",
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
return
}

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
@@ -21,6 +20,10 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"discord_oauth": common.DiscordOAuthEnabled,
"discord_client_id": common.DiscordClientId,
"discord_guild_id": common.DiscordGuildId,
"discord_allow_joining_guild": common.DiscordAllowJoiningGuild,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
@@ -80,22 +83,6 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
if common.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
}
}
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
})
return
}
}
if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -146,8 +133,7 @@ func SendPasswordResetEmail(c *gin.Context) {
subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击<a href='%s'>此处</a>进行密码重置。</p>"+
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -2,7 +2,6 @@ package controller
import (
"fmt"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -56,66 +55,12 @@ func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "dall-e-2",
Id: "dall-e",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil,
},
{
Id: "whisper-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "whisper-1",
Parent: nil,
},
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Root: "dall-e",
Parent: nil,
},
{
@@ -163,24 +108,6 @@ func init() {
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-1106",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
@@ -235,24 +162,6 @@ func init() {
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "gpt-4-1106-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-preview",
Parent: nil,
},
{
Id: "gpt-4-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
@@ -344,228 +253,21 @@ func init() {
Parent: nil,
},
{
Id: "claude-instant-1",
Id: "ChatGLM",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
OwnedBy: "thudm",
Permission: permission,
Root: "claude-instant-1",
Root: "ChatGLM",
Parent: nil,
},
{
Id: "claude-2",
Id: "ChatGLM2",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
OwnedBy: "thudm",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "claude-2.1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.1",
Parent: nil,
},
{
Id: "claude-2.0",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.0",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus",
Parent: nil,
},
{
Id: "qwen-max",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-max",
Parent: nil,
},
{
Id: "qwen-max-longcontext",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-max-longcontext",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding-bert-512-v1",
Parent: nil,
},
{
Id: "embedding_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding_s1_v1",
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Parent: nil,
},
{
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "tencent",
Permission: permission,
Root: "hunyuan",
Root: "ChatGLM2",
Parent: nil,
},
}
@@ -587,7 +289,7 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := types.OpenAIError{
openAIError := OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",

View File

@@ -42,19 +42,19 @@ func UpdateOption(c *gin.Context) {
return
}
switch option.Key {
case "DiscordOAuthEnabled":
if option.Value == "true" && common.DiscordClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Discord OAuth请先填入 Discord Client ID 以及 Discord Client Secret",
})
return
}
case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
"message": "无法启用 GitHub OAuth请先填入 GitHub Client ID 以及 GitHub Client Secret",
})
return
}

View File

@@ -1,94 +0,0 @@
package controller
import (
"context"
"math"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayChat(c *gin.Context) {
var chatRequest types.ChatCompletionRequest
if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
channel, pass := fetchChannel(c, chatRequest.Model)
if pass {
return
}
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[chatRequest.Model] != "" {
chatRequest.Model = modelMap[chatRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeChatCompletions)
if pass {
return
}
chatProvider, ok := provider.(providersBase.ChatInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,94 +0,0 @@
package controller
import (
"context"
"math"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayCompletions(c *gin.Context) {
var completionRequest types.CompletionRequest
if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
channel, pass := fetchChannel(c, completionRequest.Model)
if pass {
return
}
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[completionRequest.Model] != "" {
completionRequest.Model = modelMap[completionRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeCompletions)
if pass {
return
}
completionProvider, ok := provider.(providersBase.CompletionInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,93 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func RelayEmbeddings(c *gin.Context) {
var embeddingsRequest types.EmbeddingRequest
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
embeddingsRequest.Model = c.Param("model")
}
if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
channel, pass := fetchChannel(c, embeddingsRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeEmbeddings)
if pass {
return
}
embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,106 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayImageEdits(c *gin.Context) {
var imageEditRequest types.ImageEditRequest
if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
if imageEditRequest.Prompt == "" {
common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required")
return
}
if imageEditRequest.Model == "" {
imageEditRequest.Model = "dall-e-2"
}
if imageEditRequest.Size == "" {
imageEditRequest.Size = "1024x1024"
}
channel, pass := fetchChannel(c, imageEditRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesEdits)
if pass {
return
}
imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens, err := common.CountTokenImage(imageEditRequest)
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,109 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayImageGenerations(c *gin.Context) {
var imageRequest types.ImageRequest
if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
channel, pass := fetchChannel(c, imageRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesGenerations)
if pass {
return
}
imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens, err := common.CountTokenImage(imageRequest)
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,101 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayImageVariations(c *gin.Context) {
var imageEditRequest types.ImageEditRequest
if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
if imageEditRequest.Model == "" {
imageEditRequest.Model = "dall-e-2"
}
if imageEditRequest.Size == "" {
imageEditRequest.Size = "1024x1024"
}
channel, pass := fetchChannel(c, imageEditRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesVariations)
if pass {
return
}
imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens, err := common.CountTokenImage(imageEditRequest)
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

181
controller/relay-image.go Normal file
View File

@@ -0,0 +1,181 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var imageRequest ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
}
// N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse ImageResponse
defer func() {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,93 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayModerations(c *gin.Context) {
var moderationRequest types.ModerationRequest
if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
if moderationRequest.Model == "" {
moderationRequest.Model = "text-moderation-stable"
}
channel, pass := fetchChannel(c, moderationRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
moderationRequest.Model = modelMap[moderationRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeModerations)
if pass {
return
}
moderationProvider, ok := provider.(providersBase.ModerationInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

59
controller/relay-palm.go Normal file
View File

@@ -0,0 +1,59 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
)
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
type PaLMChatRequest struct {
Prompt []Message `json:"prompt"`
Temperature float64 `json:"temperature"`
CandidateCount int `json:"candidateCount"`
TopP float64 `json:"topP"`
TopK int `json:"topK"`
}
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
type PaLMChatResponse struct {
Candidates []Message `json:"candidates"`
Messages []Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
}
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
for _, message := range openAIRequest.Messages {
var author string
if message.Role == "user" {
author = "0"
} else {
author = "1"
}
messages = append(messages, PaLMChatMessage{
Author: author,
Content: message.Content,
})
}
request := PaLMChatRequest{
Prompt: nil,
Temperature: openAIRequest.Temperature,
CandidateCount: openAIRequest.N,
TopP: openAIRequest.TopP,
TopK: openAIRequest.MaxTokens,
}
// TODO: forward request to PaLM & convert response
fmt.Print(request)
return nil
}

View File

@@ -1,89 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelaySpeech(c *gin.Context) {
var speechRequest types.SpeechAudioRequest
if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
channel, pass := fetchChannel(c, speechRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[speechRequest.Model] != "" {
speechRequest.Model = modelMap[speechRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioSpeech)
if pass {
return
}
speechProvider, ok := provider.(providersBase.SpeechInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := len(speechRequest.Input)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

609
controller/relay-text.go Normal file
View File

@@ -0,0 +1,609 @@
package controller
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
// Get token info
tokenInfo, err := model.GetTokenById(tokenId)
if err != nil {
return errorWrapper(err, "get_token_info_failed", http.StatusInternalServerError)
}
hasModelAvailable := func() bool {
for _, token := range strings.Split(tokenInfo.Models, ",") {
if token == textRequest.Model {
return true
}
}
return false
}()
if !hasModelAvailable {
return errorWrapper(errors.New("model not available for use"), "model_not_available_for_use", http.StatusBadRequest)
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
} else if channelType == common.ChannelTypeChatGPTWeb {
// remove /v1/chat/completions from request url
requestURL := strings.Split(requestURL, "/v1/chat/completions")[0]
fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL)
} else if channelType == common.ChannelTypeChatbotUI {
// remove /v1/chat/completions from request url
requestURL := strings.Split(requestURL, "/v1/chat/completions")[0]
fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL)
} else if channelType == common.ChannelTypePaLM {
err := relayPaLM(textRequest, c)
return err
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 10*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return errorWrapper(err, "read_request_body_failed", http.StatusInternalServerError)
}
var bodyMap map[string]interface{}
err = json.Unmarshal(bodyBytes, &bodyMap)
if err != nil {
return errorWrapper(err, "unmarshal_request_body_failed", http.StatusInternalServerError)
}
// Add "stream":true to body map if it doesn't exist
if _, exists := bodyMap["stream"]; !exists {
bodyMap["stream"] = true
}
// Marshal the body map back into JSON
bodyBytes, err = json.Marshal(bodyMap)
if err != nil {
return errorWrapper(err, "marshal_request_body_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(bodyBytes)
}
if channelType == common.ChannelTypeChatGPTWeb {
// Get system message from Message json, Role == "system"
var reqBody ChatRequest
var systemMessage Message
// Parse requestBody into systemMessage
err := json.NewDecoder(requestBody).Decode(&reqBody)
if err != nil {
return errorWrapper(err, "decode_request_body_failed", http.StatusInternalServerError)
}
for _, message := range reqBody.Messages {
if message.Role == "system" {
systemMessage = message
break
}
}
var prompt string
// Get all the Message, Roles from request.Messages, and format it into string by
// ||> role: content
for _, message := range reqBody.Messages {
// Exclude system message
if message.Role == "system" {
continue
}
prompt += "||> " + message.Role + ": " + message.Content + "\n"
}
// Construct json data without adding escape character
map1 := make(map[string]interface{})
map1["prompt"] = prompt + "\nResponse as assistant, but do not include the role in response."
map1["systemMessage"] = systemMessage.Content
if reqBody.Temperature != 0 {
map1["temperature"] = formatFloat(reqBody.Temperature)
}
if reqBody.TopP != 0 {
map1["top_p"] = formatFloat(reqBody.TopP)
}
// Convert map to json string
jsonData, err := json.Marshal(map1)
if err != nil {
return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError)
}
// Convert json string to io.Reader
requestBody = bytes.NewReader(jsonData)
} else if channelType == common.ChannelTypeChatbotUI {
// Get system message from Message json, Role == "system"
var reqBody ChatRequest
// Parse requestBody into systemMessage
err := json.NewDecoder(requestBody).Decode(&reqBody)
if err != nil {
return errorWrapper(err, "decode_request_body_failed", http.StatusInternalServerError)
}
// Get system message from Message json, Role == "system"
var systemMessage string
for _, message := range reqBody.Messages {
if message.Role == "system" {
systemMessage = message.Content
break
}
}
// Construct json data without adding escape character
map1 := make(map[string]interface{})
map1["prompt"] = systemMessage
map1["temperature"] = formatFloat(reqBody.Temperature)
map1["key"] = ""
map1["messages"] = reqBody.Messages
map1["model"] = map[string]interface{}{
"id": reqBody.Model,
}
// Convert map to json string
jsonData, err := json.Marshal(map1)
if err != nil {
return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError)
}
// Convert json string to io.Reader
requestBody = bytes.NewReader(jsonData)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if channelType == common.ChannelTypeAzure {
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
req.Header.Set("api-key", key)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
if c.GetBool("enable_ip_randomization") == true {
// Generate random IP
ip := common.GenerateIP()
req.Header.Set("X-Forwarded-For", ip)
req.Header.Set("X-Real-IP", ip)
req.Header.Set("X-Client-IP", ip)
req.Header.Set("X-Forwarded-Host", ip)
req.Header.Set("X-Originating-IP", ip)
req.RemoteAddr = ip
req.Header.Set("X-Remote-IP", ip)
req.Header.Set("X-Remote-Addr", ip)
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
// Print the body in string
if resp.Body != nil {
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
log.Printf("Error Channel (%s): %s", baseURL, buf.String())
return errorWrapper(err, "request_failed", resp.StatusCode)
}
return errorWrapper(err, "request_failed", resp.StatusCode)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse TextResponse
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/octet-stream")
var streamResponseText string
defer func() {
if consumeQuota {
quota := 0
completionRatio := 1.0
if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
completionRatio = 1.333333
}
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if isStream {
completionTokens = countTokenText(streamResponseText, textRequest.Model)
} else {
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
}
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
if isStream || channelType == common.ChannelTypeChatGPTWeb || channelType == common.ChannelTypeChatbotUI {
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
if channelType != common.ChannelTypeChatGPTWeb && channelType != common.ChannelTypeChatbotUI {
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
}
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // must be something wrong!
continue
}
if channelType == common.ChannelTypeChatGPTWeb {
var chatResponse ChatGptWebChatResponse
err = json.Unmarshal([]byte(data), &chatResponse)
if err != nil {
// Print the body in string
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
return
}
// if response role is assistant and contains delta, append the content to streamResponseText
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
for _, choice := range chatResponse.Detail.Choices {
streamResponseText += choice.Delta.Content
returnObj := map[string]interface{}{
"id": chatResponse.ID,
"object": chatResponse.Detail.Object,
"created": chatResponse.Detail.Created,
"model": chatResponse.Detail.Model,
"choices": []map[string]interface{}{
// set finish_reason to null in json
{
"finish_reason": nil,
"index": 0,
"delta": map[string]interface{}{
"content": choice.Delta.Content,
},
},
},
}
jsonData, _ := json.Marshal(returnObj)
dataChan <- "data: " + string(jsonData)
}
}
} else if channelType == common.ChannelTypeChatbotUI {
returnObj := map[string]interface{}{
"id": "chatcmpl-" + strconv.Itoa(int(time.Now().UnixNano())),
"object": "text_completion",
"created": time.Now().Unix(),
"model": textRequest.Model,
"choices": []map[string]interface{}{
// set finish_reason to null in json
{
"finish_reason": nil,
"index": 0,
"delta": map[string]interface{}{
"content": data,
},
},
},
}
jsonData, _ := json.Marshal(returnObj)
dataChan <- "data: " + string(jsonData)
} else {
// If data has event: event content inside, remove it, it can be prefix or inside the data
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back
data = strings.TrimPrefix(data, "event: event")
data = strings.TrimSuffix(data, "event: event")
// Remove everything, only keep `data: {...}` <--- this is the json
// Find the start and end indices of `data: {...}` substring
startIndex := strings.Index(data, "data:")
endIndex := strings.LastIndex(data, "}")
// If both indices are found and end index is greater than start index
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
// Extract the `data: {...}` substring
data = data[startIndex : endIndex+1]
}
// Trim whitespace and newlines from the modified data string
data = strings.TrimSpace(data)
}
if !strings.HasPrefix(data, "data:") {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
case RelayModeChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
case RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Text
}
}
}
}
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
} else {
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the client will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}
}

View File

@@ -1,89 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayTranscriptions(c *gin.Context) {
var audioRequest types.AudioRequest
if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
channel, pass := fetchChannel(c, audioRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranscription)
if pass {
return
}
transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := 0
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,89 +0,0 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
func RelayTranslations(c *gin.Context) {
var audioRequest types.AudioRequest
if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
channel, pass := fetchChannel(c, audioRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranslation)
if pass {
return
}
translationProvider, ok := provider.(providersBase.TranslationInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
return
}
// 获取Input Tokens
promptTokens := 0
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
}

View File

@@ -1,300 +1,93 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"github.com/pkoukk/tiktoken-go"
"one-api/common"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/types"
"reflect"
"strconv"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
func GetValidFieldName(err error, obj interface{}) string {
getObj := reflect.TypeOf(obj)
if errs, ok := err.(validator.ValidationErrors); ok {
for _, e := range errs {
if f, exist := getObj.Elem().FieldByName(e.Field()); exist {
return f.Name
}
}
}
return err.Error()
}
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) {
channelId, ok := c.Get("channelId")
if ok {
channel, pass = fetchChannelById(c, channelId.(int))
if pass {
return
func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
return tokenEncoder
}
}
channel, pass = fetchChannelByModel(c, modelName)
if pass {
return
}
setChannelToContext(c, channel)
return
}
func fetchChannelById(c *gin.Context, channelId any) (*model.Channel, bool) {
id, err := strconv.Atoi(channelId.(string))
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return nil, true
}
channel, err := model.GetChannelById(id, true)
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return nil, true
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
}
if channel.Status != common.ChannelStatusEnabled {
common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return nil, true
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
return channel, false
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if common.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) {
group := c.GetString("group")
channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
common.AbortWithMessage(c, http.StatusServiceUnavailable, message)
return nil, true
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
return channel, false
func countTokenInput(input any, model string) int {
switch input.(type) {
case string:
return countTokenText(input.(string), model)
case []string:
text := ""
for _, s := range input.([]string) {
text += s
}
return countTokenText(text, model)
}
return 0
}
func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.ProviderInterface, bool) {
provider := providers.GetProvider(channelType, c)
if provider == nil {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
return nil, true
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}
if !provider.SupportAPI(relayMode) {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API")
return nil, true
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return provider, false
}
func setChannelToContext(c *gin.Context, channel *model.Channel) {
// c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("api_key", channel.Key)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
}
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
return false
}
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func parseModelMapping(modelMapping string) (map[string]string, error) {
if modelMapping == "" || modelMapping == "{}" {
return nil, nil
}
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return nil, err
}
return modelMap, nil
}
type QuotaInfo struct {
modelName string
promptTokens int
preConsumedTokens int
modelRatio float64
groupRatio float64
ratio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
quotaInfo := &QuotaInfo{
modelName: modelName,
promptTokens: promptTokens,
userId: c.GetInt("id"),
channelId: c.GetInt("channel_id"),
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quotaInfo.initQuotaInfo(c.GetString("group"))
errWithCode := quotaInfo.preQuotaConsumption()
if errWithCode != nil {
return nil, errWithCode
}
return quotaInfo, nil
}
func (q *QuotaInfo) initQuotaInfo(groupName string) {
modelRatio := common.GetModelRatio(q.modelName)
groupRatio := common.GetGroupRatio(groupName)
preConsumedTokens := common.PreConsumedQuota
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
q.preConsumedTokens = preConsumedTokens
q.modelRatio = modelRatio
q.groupRatio = groupRatio
q.ratio = ratio
q.preConsumedQuota = preConsumedQuota
return
}
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < q.preConsumedQuota {
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*q.preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
q.preConsumedQuota = 0
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if q.preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
q.HandelStatus = true
}
return nil
}
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
completionRatio := common.GetCompletionRatio(q.modelName)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
if q.ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - q.preConsumedQuota
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(q.userId)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)
}
return nil
}

View File

@@ -4,14 +4,193 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Name *string `json:"name,omitempty"`
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
)
// https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens *int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
// -1.0 to 1.0
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
type ImageRequest struct {
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
}
type ChatCompletionsStreamResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type ChatGptWebDetail struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []ChatGptWebChoice `json:"choices"`
}
type ChatGptWebChoice struct {
Delta struct {
Content string `json:"content"`
Role string `json:"role"`
} `json:"delta"`
Index int `json:"index"`
Finish_Reason string `json:"finish_reason"`
}
type ChatGptWebChatResponse struct {
Role string `json:"role"`
ID string `json:"id"`
ParentMessageID string `json:"parentMessageId"`
Text string `json:"text"`
Delta string `json:"delta"`
Detail *ChatGptWebDetail `json:"detail"`
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
}
if err != nil {
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
}
}
}
func RelayNotImplemented(c *gin.Context) {
err := types.OpenAIError{
err := OpenAIError{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
@@ -23,41 +202,13 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
err := types.OpenAIError{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
err := OpenAIError{
Message: fmt.Sprintf("API not found: %s:%s", c.Request.Method, c.Request.URL.Path),
Type: "one_api_error",
Param: "",
Code: "",
Code: "api_not_found",
}
c.JSON(http.StatusNotFound, gin.H{
"error": err,
})
}
func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
}
}

View File

@@ -1,11 +1,12 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
func GetAllTokens(c *gin.Context) {
@@ -109,10 +110,10 @@ func AddToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
if len(token.Name) == 0 || len(token.Name) > 20 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称长",
"message": "令牌名称长度必须在1-20之间",
})
return
}
@@ -125,6 +126,7 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
Models: token.Models,
}
err = cleanToken.Insert()
if err != nil {
@@ -171,13 +173,6 @@ func UpdateToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
})
return
}
cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -210,6 +205,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.Models = token.Models
}
err = cleanToken.Update()
if err != nil {

View File

@@ -7,7 +7,6 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
@@ -249,30 +248,6 @@ func GetUser(c *gin.Context) {
return
}
func GetUserDashboard(c *gin.Context) {
id := c.GetInt("id")
// 获取7天前 00:00:00 和 今天23:59:59 的秒时间戳
now := time.Now()
toDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
endOfDay := toDay.Add(time.Hour * 24).Add(-time.Second).Unix()
startOfDay := toDay.AddDate(0, 0, -7).Unix()
dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法获取统计信息.",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dashboards,
})
}
func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
@@ -508,7 +483,7 @@ func DeleteSelf(c *gin.Context) {
if user.Role == common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除超级管理员账户",
"message": "无权删除超级管理员",
})
return
}

View File

@@ -9,21 +9,21 @@ services:
ports:
- "3000:3000"
volumes:
- ./data/oneapi:/data
- ./data:/data
- ./logs:/app/logs
environment:
- SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
depends_on:
- redis
- db
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
interval: 30s
timeout: 10s
retries: 3
@@ -32,18 +32,3 @@ services:
image: redis:latest
container_name: redis
restart: always
db:
image: mysql:8.2.0
restart: always
container_name: mysql
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
ports:
- '3306:3306'
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: oneapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: one-api # 自动创建数据库

34
english.dockerfile Normal file
View File

@@ -0,0 +1,34 @@
# Initial stage
FROM python:3.11 as translator
WORKDIR /app
COPY . .
RUN python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
# Node build stage
FROM node:18-alpine as nodeBuilder
WORKDIR /build
COPY ./web/package*.json ./
RUN npm ci
COPY --from=translator /app .
RUN cd web && VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
# Go build stage
FROM golang:1.20.5 AS goBuilder
ENV GO111MODULE=on \
CGO_ENABLED=1 \
GOOS=linux
WORKDIR /build
COPY go.mod .
COPY go.sum .
RUN go mod download
COPY --from=translator /app .
COPY --from=nodeBuilder /build/web/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
# Final stage
FROM alpine:latest
RUN apk update && apk upgrade && apk add --no-cache ca-certificates tzdata && update-ca-certificates 2>/dev/null || true
WORKDIR /data
COPY --from=goBuilder /build/one-api /
EXPOSE 3000
ENTRYPOINT ["/one-api"]

52
go.mod
View File

@@ -9,58 +9,56 @@ require (
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0
github.com/go-playground/validator/v10 v10.14.1
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.14.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
golang.org/x/crypto v0.11.0
gorm.io/driver/mysql v1.5.1
gorm.io/driver/sqlite v1.5.2
gorm.io/gorm v1.25.2
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/chenzhuoyu/iasm v0.9.0 // indirect
github.com/knz/go-libedit v1.10.1 // indirect
)
require (
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
github.com/bytedance/sonic v1.10.0-rc2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gomodule/redigo v2.0.0+incompatible // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/joho/godotenv v1.5.1
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/realTristan/disgoauth v1.0.2
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
golang.org/x/arch v0.4.0 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

105
go.sum
View File

@@ -1,17 +1,33 @@
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04=
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.9.2 h1:GDaNjuWSGu09guE9Oql0MSTNhNCLlWwO8y/xM5BzcbM=
github.com/bytedance/sonic v1.9.2/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.10.0-rc h1:3S5HeWxjX08CUqNrXtEittExpJsEKBNzrV5UnrzHxVQ=
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
github.com/bytedance/sonic v1.10.0-rc2 h1:oDfRZ+4m6AYCOC0GFeOCeYqvBmucy1isvouS2K0cPzo=
github.com/bytedance/sonic v1.10.0-rc2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
github.com/chenzhuoyu/iasm v0.9.0 h1:9fhXjVzq5hUy2gkhhgHl95zG2cEAhw9OSGs8toWWAwo=
github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -45,17 +61,25 @@ github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GO
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+jU0zvx4AqHGnv4k=
github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -65,16 +89,9 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -88,6 +105,10 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1 h1:0pHpWtx9vcvC0xGZqEQlQdfSQs7WRlAjuPvk3fOZDCo=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -119,11 +140,21 @@ github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pkoukk/tiktoken-go v0.1.4 h1:bniMzWdUvNO6YkRbASo2x5qJf2LAG/TIJojqz+Igm8E=
github.com/pkoukk/tiktoken-go v0.1.4/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/ravener/discord-oauth2 v0.0.0-20230514095040-ae65713199b3 h1:x3LgcvujjG+mx8PUMfPmwn3tcu2aA95uCB6ilGGObWk=
github.com/ravener/discord-oauth2 v0.0.0-20230514095040-ae65713199b3/go.mod h1:P/mZMYLZ87lqRSECEWsOqywGrO1hlZkk9RTwEw35IP4=
github.com/realTristan/disgoauth v1.0.2 h1:dfto2Kf1gFlZsf8XuwRNoemLgk+hGn/TJpSdtMrEh8E=
github.com/realTristan/disgoauth v1.0.2/go.mod h1:t72aRaWMq2gknUZcKONReJlEYFod5sHC86WCJ0X9GxA=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -140,6 +171,7 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -151,36 +183,56 @@ github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc=
golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/oauth2 v0.10.0 h1:zHCpF2Khkwy4mMB4bv0U37YtJdTGW8jI0glAApi0Kh8=
golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
@@ -197,12 +249,17 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw=
gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc=
gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -3,11 +3,6 @@
"%d 点额度": "%d point quota",
"尚未实现": "Not yet implemented",
"余额不足": "Insufficient balance",
"危险操作": "Hazardous operations",
"输入你的账户名": "Enter your account name",
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"通道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"通道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running",
@@ -39,8 +34,8 @@
"兑换码个数必须大于0": "The number of redemption codes must be greater than 0",
"一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100",
"通过令牌「%s」使用模型 %s 消耗 %s模型倍率 %.2f,分组倍率 %.2f": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)",
"当前分组上游负载已饱和,请稍后再试": "The current group load is saturated, please try again later",
"令牌名称过长": "Token name is too long",
"当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。": "The current group load is saturated, please try again later, or upgrade your account to improve service quality.",
"令牌名称长度必须在1-20之间": "The length of the token name must be between 1-20",
"令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.",
"令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota",
"管理员关闭了密码登录": "The administrator has turned off password login",
@@ -119,7 +114,6 @@
" 年 ": " y ",
"未测试": "Not tested",
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
@@ -140,7 +134,6 @@
"启用": "Enable",
"编辑": "Edit",
"添加新的渠道": "Add a new channel",
"测试所有通道": "Test all channels",
"测试所有已启用通道": "Test all enabled channels",
"更新所有已启用通道余额": "Update the balance of all enabled channels",
"刷新": "Refresh",
@@ -231,7 +224,7 @@
"已是最新版本": "Is the latest version",
"检查更新": "Check for updates",
"公告": "Announcement",
"在此输入新的公告内容,支持 Markdown & HTML 代码": "Enter the new announcement content here, supports Markdown & HTML code",
"在此输入新的公告内容": "Enter new announcement content here",
"保存公告": "Save Announcement",
"个性化设置": "Personalization Settings",
"系统名称": "System Name",
@@ -244,7 +237,8 @@
"保存首页内容": "Save Home Page Content",
"在此输入新的关于内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为关于页面": "Enter new about content here, supports Markdown & HTML code. If a link is entered, it will be used as the src attribute of the iframe, allowing you to set any webpage as the about page.",
"保存关于": "Save About",
"移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目": "Removal of One API copyright mark must first be authorized. Project maintenance requires a lot of effort. If this project is meaningful to you, please actively support it.",
"移除 One API": "Removal of One API",
"的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。": " copyright mark must first be authorized. Project maintenance requires a lot of effort. If this project is meaningful to you, please actively support it.",
"页脚": "Footer",
"在此输入新的页脚,留空则使用默认页脚,支持 HTML 代码": "Enter the new footer here, leave blank to use the default footer, supports HTML code.",
"设置页脚": "Set Footer",
@@ -434,7 +428,7 @@
"一分钟后过期": "Expires after one minute",
"创建新的令牌": "Create New Token",
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
"设为无限额度": "Set to unlimited quota",
"设为无限额度": "Set to unlimited quota",
"更新令牌信息": "Update Token Information",
"请输入充值码!": "Please enter the recharge code!",
"请输入名称": "Please enter a name",
@@ -450,6 +444,7 @@
"显示名称": "Display Name",
"请输入新的显示名称": "Please enter a new display name",
"已绑定的 GitHub 账户": "GitHub Account Bound",
"已绑定的 Discord 账户": "Discord Account Bound",
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only. Users need to bind through the relevant binding button on the personal settings page, and cannot be modified directly",
"已绑定的微信账户": "WeChat Account Bound",
"已绑定的邮箱账户": "Email Account Bound",
@@ -500,7 +495,6 @@
"参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)",
"模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
"取消无限额度": "Cancel unlimited quota",
"取消": "Cancel",
"请输入新的剩余额度": "Please enter the new remaining quota",
"请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code",
"请输入用户名": "Please enter username",
@@ -512,19 +506,36 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为通道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
"注意,": "Note that, ",
",图片演示。": "related image demo.",
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
"代理": "Proxy",
"此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
"按照如下格式输入:": "Enter in the following format:",
"模型版本": "Model version",
"请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
"点击查看": "click to view",
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
"允许通过 Discord 账户登录和注册": "Allow login and registration via Discord account",
"Discord 身份验证": "Discord Authentication",
"确认文字": "Confirmation Text",
"请输入 \"CONFIRM\" 以删除您的帐户。": "Please enter \"CONFIRM\" to delete your account.",
"请确认您要删除账户!": "Please confirm that you want to delete the account!",
"账户已删除!": "Account deleted!",
"您是否确认删除自己的帐户?": "Are you sure you want to delete your account?",
"配置 Discord OAuth App": "Configure Discord OAuth App",
"管理你的 Discord OAuth App": "Manage your Discord OAuth App",
"输入你注册的 Discord OAuth APP 的 ID": "Enter the ID of your registered Discord OAuth APP",
"保存 Discord OAuth 设置": "Save Discord OAuth Settings",
"删除个人账户": "Delete personal account",
"绑定 Discord 账号": "Bind Discord account",
"无权将其他用户权限等级提升到大于等于自己的权限等级": "You are not allowed to upgrade the permission level of other users to greater than or equal to your own permission level",
"无权删除超级管理员": "You are not allowed to delete super administrators",
"该 Discord 账户已被绑定": "The Discord account has been bound",
"管理员未开启通过 Discord 登录以及注册": "The administrator has not enabled login and registration via Discord",
"无法启用 Discord OAuth请先填入 Discord Client ID 以及 Discord Client Secret": "Unable to enable Discord OAuth, please fill in the Discord Client ID and Discord Client Secret first!",
"兑换失败,": "Redemption failed, ",
"请选择此密钥支持的模型": "Please select the models supported by this key",
"将IP随机地址传递给HTTP头": "Pass the IP random address to the HTTP header",
"失败重试次数": "Number of failed retries",
"消费": "Consumption",
"管理": "Management",
"系统": "System",
"未知": "Unknown",
"One API 会把请求体中的 model": "One API will take the model in the request body",
",因为": ", because",
"参数替换为你的部署名称(模型名称中的点会被剔除),": "Replace the parameter with your deployment name (dots in the model name will be removed), ",
"注意,此处生成的令牌用于系统管理,而非用于请求 OpenAI": "Note that the generated token here is used for system management, not for requesting OpenAI",
"相关的服务,请知悉。": "related services, please be aware.",
"填入": "Fill in"
}

40
main.go
View File

@@ -2,7 +2,6 @@ package main
import (
"embed"
"fmt"
"one-api/common"
"one-api/controller"
"one-api/middleware"
@@ -14,6 +13,7 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
)
//go:embed web/build
@@ -23,14 +23,13 @@ var buildFS embed.FS
var indexPage []byte
func main() {
common.SetupLogger()
godotenv.Load(".env")
common.SetupGinLog()
common.SysLog("One API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
if common.DebugEnabled {
common.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {
@@ -52,17 +51,17 @@ func main() {
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
if os.Getenv("SYNC_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
}
go model.SyncOptions(frequency)
if common.RedisEnabled {
go model.SyncChannelCache(frequency)
}
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
@@ -78,20 +77,13 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
common.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()
server.Use(gin.Recovery())
server := gin.Default()
// This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
server.Use(middleware.CORS())
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
server.Use(sessions.Sessions("session", store))

View File

@@ -91,26 +91,45 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
if !model.CacheIsUserEnabled(token.UserId) {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "用户已被封禁",
"type": "one_api_error",
},
})
c.Abort()
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "普通用户不支持指定渠道",
"type": "one_api_error",
},
})
c.Abort()
return
}
}

View File

@@ -1,16 +1,116 @@
package middleware
import (
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type ModelRequest struct {
Model string `json:"model"`
}
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return
}
if channel.Status != common.ChannelStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := "无可用渠道"
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.ModelMapping)
c.Set("enable_ip_randomization", channel.EnableIpRandomization)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure {
c.Set("api_version", channel.Other)
}
c.Next()
}
}

View File

@@ -1,25 +0,0 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
requestID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.Path,
)
}))
}

View File

@@ -1,26 +0,0 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
"type": "one_api_panic",
},
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -1,18 +0,0 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Next()
}
}

View File

@@ -1,17 +0,0 @@
package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
}

View File

@@ -10,25 +10,15 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
@@ -50,7 +40,6 @@ func (channel *Channel) AddAbilities() error {
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)
}

View File

@@ -6,33 +6,28 @@ import (
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
"time"
)
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
const (
TokenCacheSeconds = 60 * 60
UserId2GroupCacheSeconds = 60 * 60
UserId2QuotaCacheSeconds = 10 * 60
UserId2StatusCacheSeconds = 60 * 60
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
if !common.RedisEnabled {
err := DB.Where(keyCol+" = ?", key).First(&token).Error
err := DB.Where("`key` = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where(keyCol+" = ?", key).First(&token).Error
err := DB.Where("`key` = ?", key).First(&token).Error
if err != nil {
return nil, err
}
@@ -40,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
if err != nil {
common.SysError("Redis set token error: " + err.Error())
}
@@ -60,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
if err != nil {
return "", err
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
}
@@ -78,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
if err != nil {
common.SysError("Redis set user quota error: " + err.Error())
}
@@ -96,40 +91,27 @@ func CacheUpdateUserQuota(id int) error {
if err != nil {
return err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
return err
}
func CacheDecreaseUserQuota(id int, quota int) error {
if !common.RedisEnabled {
return nil
}
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
return err
}
func CacheIsUserEnabled(userId int) (bool, error) {
func CacheIsUserEnabled(userId int) bool {
if !common.RedisEnabled {
return IsUserEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
userEnabled, err := IsUserEnabled(userId)
if err != nil {
return false, err
status := common.UserStatusDisabled
if IsUserEnabled(userId) {
status = common.UserStatusEnabled
}
enabled = "0"
if userEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
enabled = fmt.Sprintf("%d", status)
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
return enabled == "1"
}
var group2model2channels map[string]map[string][]*Channel
@@ -164,17 +146,6 @@ func InitChannelCache() {
}
}
}
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
@@ -190,7 +161,7 @@ func SyncChannelCache(frequency int) {
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if !common.MemoryCacheEnabled {
if !common.RedisEnabled {
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
@@ -199,17 +170,6 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
idx := rand.Intn(len(channels))
return channels[idx], nil
}

View File

@@ -12,19 +12,21 @@ type Channel struct {
Key string `json:"key" gorm:"not null;index"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
Weight int `json:"weight"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
BaseURL string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
// Additional fields, default value is false
EnableIpRandomization bool `json:"enable_ip_randomization"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -39,11 +41,7 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -58,6 +56,17 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error
@@ -73,27 +82,6 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
}
return *channel.Priority
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
}
return *channel.ModelMapping
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
@@ -157,26 +145,8 @@ func UpdateChannelStatusById(id int, status int) {
}
func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
updateChannelUsedQuota(id, quota)
}
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
common.SysError("failed to update channel used quota: " + err.Error())
}
}
func DeleteChannelByStatus(status int64) (int64, error) {
result := DB.Where("status = ?", status).Delete(&Channel{})
return result.RowsAffected, result.Error
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -1,35 +1,22 @@
package model
import (
"context"
"fmt"
"one-api/common"
"gorm.io/gorm"
"one-api/common"
)
type Log struct {
Id int `json:"id;index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Id int `json:"id"`
UserId int `json:"user_id"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index"`
Type int `json:"type" gorm:"index"`
Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
Username string `json:"username" gorm:"index;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
ModelName string `json:"model_name" gorm:"index;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel" gorm:"index"`
}
type LogStatistic struct {
Day string `gorm:"column:day"`
ModelName string `gorm:"column:model_name"`
RequestCount int `gorm:"column:request_count"`
Quota int `gorm:"column:quota"`
PromptTokens int `gorm:"column:prompt_tokens"`
CompletionTokens int `gorm:"column:completion_tokens"`
}
const (
@@ -57,8 +44,7 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
if !common.LogConsumeEnabled {
return
}
@@ -73,15 +59,14 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
ChannelId: channelId,
}
err := DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
common.SysError("failed to record log: " + err.Error())
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
@@ -103,9 +88,6 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
@@ -143,8 +125,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select(assembleSumSelectStr("quota"))
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
tx := DB.Table("logs").Select("sum(quota)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -160,15 +142,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select(assembleSumSelectStr("prompt_tokens") + " + " + assembleSumSelectStr("completion_tokens"))
tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -187,46 +166,3 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx.Where("type = ?", LogTypeConsume).Scan(&token)
return token
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}
func SearchLogsByDayAndModel(user_id, start, end int) (LogStatistics []*LogStatistic, err error) {
groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
if common.UsingPostgreSQL {
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
}
err = DB.Raw(`
SELECT `+groupSelect+`,
model_name, count(1) as request_count,
sum(quota) as quota,
sum(prompt_tokens) as prompt_tokens,
sum(completion_tokens) as completion_tokens
FROM logs
WHERE type=2
AND user_id= ?
AND created_at BETWEEN ? AND ?
GROUP BY day, model_name
ORDER BY day, model_name
`, user_id, start, end).Scan(&LogStatistics).Error
fmt.Println(user_id, start, end)
return LogStatistics, err
}
func assembleSumSelectStr(selectStr string) string {
sumSelectStr := "%s(sum(%s),0)"
nullfunc := "ifnull"
if common.UsingPostgreSQL {
nullfunc = "coalesce"
}
sumSelectStr = fmt.Sprintf(sumSelectStr, nullfunc, selectStr)
return sumSelectStr
}

View File

@@ -1,15 +1,11 @@
package model
import (
"fmt"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"os"
"strings"
"time"
)
var DB *gorm.DB
@@ -37,55 +33,34 @@ func createRootAccountIfNeed() error {
return nil
}
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL
common.SysLog("using MySQL as database")
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
func CountTable(tableName string) (num int64) {
DB.Table(tableName).Count(&num)
return
}
func InitDB() (err error) {
db, err := chooseDB()
var db *gorm.DB
if os.Getenv("SQL_DSN") != "" {
// Use MySQL
common.SysLog("using MySQL as database")
db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} else {
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
common.SysLog("database connected")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
DB = db
sqlDB, err := DB.DB()
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
}
common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
err := db.AutoMigrate(&Channel{})
if err != nil {
return err
}

View File

@@ -30,18 +30,16 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
@@ -56,6 +54,11 @@ func InitOptionMap() {
common.OptionMap["ServerAddress"] = ""
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["DiscordClientId"] = ""
common.OptionMap["DiscordClientSecret"] = ""
common.OptionMap["DiscordGuildId"] = ""
common.OptionMap["DiscordBotToken"] = ""
common.OptionMap["DiscordAllowJoiningGuild"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
@@ -136,6 +139,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled":
common.EmailVerificationEnabled = boolValue
case "DiscordOAuthEnabled":
common.DiscordOAuthEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled":
@@ -144,12 +149,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":
@@ -161,8 +162,6 @@ func updateOptionMap(key string, value string) (err error) {
}
}
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
common.SMTPServer = value
case "SMTPPort":
@@ -180,6 +179,16 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "DiscordClientId":
common.DiscordClientId = value
case "DiscordGuildId":
common.DiscordGuildId = value
case "DiscordBotToken":
common.DiscordBotToken = value
case "DiscordAllowJoiningGuild":
common.DiscordAllowJoiningGuild = value
case "DiscordClientSecret":
common.DiscordClientSecret = value
case "Footer":
common.Footer = value
case "SystemName":

View File

@@ -3,9 +3,8 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"gorm.io/gorm"
"one-api/common"
)
type Redemption struct {
@@ -28,7 +27,7 @@ func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
}
func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) {
err = DB.Where("id = ? or name LIKE ?", common.String2Int(keyword), keyword+"%").Find(&redemptions).Error
err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error
return redemptions, err
}
@@ -51,27 +50,21 @@ func Redeem(key string, userId int) (quota int, err error) {
}
redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
err := DB.Where("`key` = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}
if redemption.Status != common.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
if err != nil {
return err
}
redemption.RedeemedTime = common.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
return redemption.SelectUpdate()
})
if err != nil {
return 0, errors.New("兑换失败," + err.Error())

View File

@@ -3,8 +3,9 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"gorm.io/gorm"
)
type Token struct {
@@ -19,6 +20,7 @@ type Token struct {
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Models string `json:"models"`
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@@ -39,35 +41,32 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
}
go func() {
token.AccessedTime = common.GetTimestamp()
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token" + err.Error())
}
}()
return token, nil
}
return nil, errors.New("无效的令牌")
@@ -102,7 +101,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error
return err
}
@@ -134,19 +133,10 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
return increaseTokenQuota(id, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(),
},
).Error
return err
@@ -156,19 +146,10 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
return decreaseTokenQuota(id, quota)
}
func decreaseTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(),
},
).Error
return err

View File

@@ -20,6 +20,7 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
@@ -43,8 +44,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
}
func SearchUsers(keyword string) (users []*User, err error) {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", common.String2Int(keyword), keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
return users, err
}
@@ -171,6 +171,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByDiscordId() error {
if user.DiscordId == "" {
return errors.New("Discord id 为空!")
}
DB.Where(User{DiscordId: user.DiscordId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -199,6 +207,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsDiscordIdAlreadyTaken(discordId string) bool {
return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -228,16 +240,17 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
func IsUserEnabled(userId int) (bool, error) {
func IsUserEnabled(userId int) bool {
if userId == 0 {
return false, errors.New("user id is empty")
return false
}
var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
if err != nil {
return false, err
common.SysError("no such user " + err.Error())
return false
}
return user.Status == common.UserStatusEnabled, nil
return user.Status == common.UserStatusEnabled
}
func ValidateAccessToken(token string) (user *User) {
@@ -268,12 +281,7 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
return group, err
}
@@ -281,14 +289,6 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
return increaseUserQuota(id, quota)
}
func increaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err
}
@@ -297,14 +297,6 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
return decreaseUserQuota(id, quota)
}
func decreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
@@ -315,19 +307,10 @@ func GetRootUserEmail() (email string) {
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
}
updateUserUsedQuotaAndRequestCount(id, quota, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
"request_count": gorm.Expr("request_count + ?", count),
"request_count": gorm.Expr("request_count + ?", 1),
},
).Error
if err != nil {
@@ -335,24 +318,6 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
}
}
func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username

View File

@@ -1,77 +0,0 @@
package model
import (
"one-api/common"
"sync"
"time"
)
const (
BatchUpdateTypeUserQuota = iota
BatchUpdateTypeTokenQuota
BatchUpdateTypeUsedQuota
BatchUpdateTypeChannelUsedQuota
BatchUpdateTypeRequestCount
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
var batchUpdateLocks []sync.Mutex
func init() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
}
}
func InitBatchUpdater() {
go func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
}
func addNewRecord(type_ int, id int, value int) {
batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok {
batchUpdateStores[type_][id] = value
} else {
batchUpdateStores[type_][id] += value
}
}
func batchUpdate() {
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
common.SysLog("batch update finished")
}

View File

@@ -1,30 +0,0 @@
package aigc2d
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}

View File

@@ -1,20 +0,0 @@
package aigc2d
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Aigc2dProviderFactory struct{}
func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Aigc2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"),
}
}
type Aigc2dProvider struct {
*openai.OpenAIProvider
}

View File

@@ -1,35 +0,0 @@
package aiproxy
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
)
func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := "https://aiproxy.io/api/report/getUserOverview"
headers := make(map[string]string)
headers["Api-Key"] = channel.Key
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response AIProxyUserOverviewResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
if !response.Success {
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
}
channel.UpdateBalance(response.Data.TotalPoints)
return response.Data.TotalPoints, nil
}

View File

@@ -1,20 +0,0 @@
package aiproxy
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type AIProxyProviderFactory struct{}
func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &AIProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"),
}
}
type AIProxyProvider struct {
*openai.OpenAIProvider
}

View File

@@ -1,10 +0,0 @@
package aiproxy
type AIProxyUserOverviewResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
ErrorCode int `json:"error_code"`
Data struct {
TotalPoints float64 `json:"totalPoints"`
} `json:"data"`
}

View File

@@ -1,41 +0,0 @@
package ali
import (
"fmt"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
// 定义供应商工厂
type AliProviderFactory struct{}
// 创建 AliProvider
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &AliProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
Context: c,
},
}
}
type AliProvider struct {
base.BaseProvider
}
// 获取请求头
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
if p.Context.GetString("plugin") != "" {
headers["X-DashScope-Plugin"] = p.Context.GetString("plugin")
}
return headers
}

View File

@@ -1,216 +0,0 @@
package ali
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
)
// 阿里云响应处理
func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}
return
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: aliResponse.Output.Text,
},
FinishReason: aliResponse.Output.FinishReason,
}
OpenAIResponse = types.ChatCompletionResponse{
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
},
}
return
}
// 获取聊天请求体
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Messages: messages,
},
}
}
// 聊天
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
headers["X-DashScope-SSE"] = "enable"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
usage, errWithCode = p.sendStreamRequest(req)
if errWithCode != nil {
return
}
if usage == nil {
usage = &types.Usage{
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
}
} else {
aliResponse := &AliChatResponse{}
errWithCode = p.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
}
}
return
}
// 阿里云响应转OpenAI响应
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := types.ChatCompletionStreamResponse{
ID: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}
// 发送流请求
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
usage = &types.Usage{}
// 发送请求
resp, err := common.HttpClient.Do(req)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if common.IsFailureStatusCode(resp) {
return nil, common.HandleErrorResp(resp)
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
lastResponseText := ""
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := p.streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return
}

View File

@@ -1,73 +0,0 @@
package ali
import (
"net/http"
"one-api/common"
"one-api/types"
)
// 嵌入请求处理
func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}
}
openAIEmbeddingResponse := &types.EmbeddingResponse{
Object: "list",
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
}
for _, item := range aliResponse.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return openAIEmbeddingResponse, nil
}
// 获取嵌入请求体
func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
aliEmbeddingResponse := &AliEmbeddingResponse{}
errWithCode = p.SendRequest(req, aliEmbeddingResponse, false)
if errWithCode != nil {
return
}
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
return usage, nil
}

View File

@@ -1,70 +0,0 @@
package ali
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}
type AliInput struct {
// Prompt string `json:"prompt"`
Messages []AliMessage `json:"messages"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}

View File

@@ -1,30 +0,0 @@
package api2d
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}

View File

@@ -1,21 +0,0 @@
package api2d
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Api2dProviderFactory struct{}
// 创建 Api2dProvider
func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Api2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
}
}
type Api2dProvider struct {
*openai.OpenAIProvider
}

View File

@@ -1,30 +0,0 @@
package api2gpt
import (
"errors"
"one-api/common"
"one-api/model"
"one-api/providers/base"
)
func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var response base.BalanceResponse
_, errWithCode := common.SendRequest(req, &response, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalRemaining, nil
}

View File

@@ -1,20 +0,0 @@
package api2gpt
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type Api2gptProviderFactory struct{}
func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Api2gptProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"),
}
}
type Api2gptProvider struct {
*openai.OpenAIProvider
}

View File

@@ -1,36 +0,0 @@
package azure
import (
"one-api/providers/base"
"one-api/providers/openai"
"github.com/gin-gonic/gin"
)
type AzureProviderFactory struct{}
// 创建 AzureProvider
func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &AzureProvider{
OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{
BaseURL: "",
Completions: "/completions",
ChatCompletions: "/chat/completions",
Embeddings: "/embeddings",
AudioTranscriptions: "/audio/transcriptions",
AudioTranslations: "/audio/translations",
ImagesGenerations: "/images/generations",
// ImagesEdit: "/images/edit",
// ImagesVariations: "/images/variations",
Context: c,
// AudioSpeech: "/audio/speech",
},
IsAzure: true,
},
}
}
type AzureProvider struct {
openai.OpenAIProvider
}

View File

@@ -1,102 +0,0 @@
package azure
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/providers/openai"
"one-api/types"
"time"
)
func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Status == "canceled" || c.Status == "failed" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: c.Error.Message,
Type: "one_api_error",
Code: c.Error.Code,
},
StatusCode: resp.StatusCode,
}
return
}
operation_location := resp.Header.Get("operation-location")
if operation_location == "" {
return nil, common.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError)
}
client := common.NewClient()
req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header))
if err != nil {
return nil, common.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError)
}
getImageAzureResponse := ImageAzureResponse{}
for i := 0; i < 3; i++ {
// 休眠 2 秒
time.Sleep(2 * time.Second)
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false)
fmt.Println("getImageAzureResponse", getImageAzureResponse)
if errWithCode != nil {
return
}
if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: c.Error.Message,
Type: "get_images_request_failed",
Code: c.Error.Code,
},
StatusCode: resp.StatusCode,
}
}
if getImageAzureResponse.Status == "succeeded" {
return getImageAzureResponse.Result, nil
}
}
return nil, common.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError)
}
func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Model == "dall-e-2" {
imageAzureResponse := &ImageAzureResponse{
Header: headers,
}
errWithCode = p.SendRequest(req, imageAzureResponse, false)
} else {
openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
}
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
}

View File

@@ -1,21 +0,0 @@
package azure
import "one-api/types"
type ImageAzureResponse struct {
ID string `json:"id,omitempty"`
Created int64 `json:"created,omitempty"`
Expires int64 `json:"expires,omitempty"`
Result types.ImageResponse `json:"result,omitempty"`
Status string `json:"status,omitempty"`
Error ImageAzureError `json:"error,omitempty"`
Header map[string]string `json:"header,omitempty"`
}
type ImageAzureError struct {
Code string `json:"code,omitempty"`
Target string `json:"target,omitempty"`
Message string `json:"message,omitempty"`
Details []string `json:"details,omitempty"`
InnerError any `json:"innererror,omitempty"`
}

View File

@@ -1,36 +0,0 @@
package azureSpeech
import (
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
// 定义供应商工厂
type AzureSpeechProviderFactory struct{}
// 创建 AliProvider
func (f AzureSpeechProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &AzureSpeechProvider{
BaseProvider: base.BaseProvider{
BaseURL: "",
AudioSpeech: "/cognitiveservices/v1",
Context: c,
},
}
}
type AzureSpeechProvider struct {
base.BaseProvider
}
// 获取请求头
func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Ocp-Apim-Subscription-Key"] = p.Context.GetString("api_key")
headers["Content-Type"] = "application/ssml+xml"
headers["User-Agent"] = "OneAPI"
// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3"
return headers
}

View File

@@ -1,88 +0,0 @@
package azureSpeech
import (
"bytes"
"fmt"
"net/http"
"one-api/common"
"one-api/types"
)
var outputFormatMap = map[string]string{
"mp3": "audio-16khz-128kbitrate-mono-mp3",
"opus": "audio-16khz-128kbitrate-mono-opus",
"aac": "audio-24khz-160kbitrate-mono-mp3",
"flac": "audio-48khz-192kbitrate-mono-mp3",
}
func CreateSSML(text string, name string, role string) string {
ssmlTemplate := `<speak version='1.0' xml:lang='en-US'>
<voice xml:lang='en-US' %s name='%s'>
%s
</voice>
</speak>`
roleAttribute := ""
if role != "" {
roleAttribute = fmt.Sprintf("role='%s'", role)
}
return fmt.Sprintf(ssmlTemplate, roleAttribute, name, text)
}
func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest) *bytes.Buffer {
voiceMap := map[string][]string{
"alloy": {"zh-CN-YunxiNeural"},
"echo": {"zh-CN-YunyangNeural"},
"fable": {"zh-CN-YunxiNeural", "Boy"},
"onyx": {"zh-CN-YunyeNeural"},
"nova": {"zh-CN-XiaochenNeural"},
"shimmer": {"zh-CN-XiaohanNeural"},
}
voice := ""
role := ""
if voiceMap[request.Voice] != nil {
voice = voiceMap[request.Voice][0]
if len(voiceMap[request.Voice]) > 1 {
role = voiceMap[request.Voice][1]
}
}
ssml := CreateSSML(request.Input, voice, role)
return bytes.NewBufferString(ssml)
}
func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
headers := p.GetRequestHeaders()
responseFormatr := outputFormatMap[request.ResponseFormat]
if responseFormatr == "" {
responseFormatr = outputFormatMap["mp3"]
}
headers["X-Microsoft-OutputFormat"] = responseFormatr
requestBody := p.getRequestBody(request)
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
errWithCode = p.SendRequestRaw(req)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
}

View File

@@ -1,129 +0,0 @@
package baidu
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
"one-api/providers/base"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// 定义供应商工厂
type BaiduProviderFactory struct{}
// 创建 BaiduProvider
func (f BaiduProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &BaiduProvider{
BaseProvider: base.BaseProvider{
BaseURL: "https://aip.baidubce.com",
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
Context: c,
},
}
}
var baiduTokenStore sync.Map
type BaiduProvider struct {
base.BaseProvider
}
// 获取完整请求 URL
func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string {
var modelNameMap = map[string]string{
"ERNIE-Bot": "completions",
"ERNIE-Bot-turbo": "eb-instant",
"ERNIE-Bot-4": "completions_pro",
"BLOOMZ-7B": "bloomz_7b1",
"Embedding-V1": "embedding-v1",
}
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
apiKey, err := p.getBaiduAccessToken()
if err != nil {
return ""
}
return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey)
}
// 获取请求头
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
return headers
}
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
apiKey := p.Context.GetString("api_key")
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken
if accessToken, ok = val.(BaiduAccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = p.getBaiduAccessTokenHelper(apiKey)
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := p.getBaiduAccessTokenHelper(apiKey)
if err != nil {
return "", err
}
if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
client := common.NewClient()
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
var headers = map[string]string{
"Content-Type": "application/json",
"Accept": "application/json",
}
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
if err != nil {
return nil, err
}
resp, err := common.HttpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var accessToken BaiduAccessToken
err = json.NewDecoder(resp.Body).Decode(&accessToken)
if err != nil {
return nil, err
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}

Some files were not shown because too many files have changed in this diff Show More