Compare commits

...

54 Commits

Author SHA1 Message Date
Wei Tingjiang
49ffb1c60d feat: enhance response handling to support gemini-2.0-thinking (#1995) 2024-12-22 18:25:44 +08:00
longkeyy
2f16649896 feat: update qwen model and price (#1966) 2024-12-22 18:22:57 +08:00
dependabot[bot]
af3aa57bd6 chore(deps): bump golang.org/x/crypto from 0.24.0 to 0.31.0 (#1976)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.24.0 to 0.31.0.
- [Commits](https://github.com/golang/crypto/compare/v0.24.0...v0.31.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-22 18:21:00 +08:00
Laisky.Cai
e9f117ff72 feat: add gemini-2.0-flash-exp and fix race condition in processChannelRelayError (#1983)
Some checks are pending
CI / Unit tests (push) Waiting to run
CI / commit_lint (push) Waiting to run
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-12-21 20:32:30 +08:00
Laisky.Cai
6bb5247bd6 feat: add support for new OpenAI models and update billing ratios (#1990) 2024-12-21 20:28:51 +08:00
Laisky.Cai
305ce14fe3 feat: support replicate chat models (#1989)
* feat: add Replicate adaptor and integrate into channel and API types

* feat: support llm chat on replicate
2024-12-21 14:41:19 +08:00
JustSong
36c8f4f15c docs: update readme
Some checks are pending
CI / Unit tests (push) Waiting to run
CI / commit_lint (push) Waiting to run
2024-12-21 00:33:20 +08:00
JustSong
45b51ea0ee feat: update feishu oauth login 2024-12-20 23:27:00 +08:00
Calcium-Ion
7c8628bd95 feat: support gzip decode (#1962)
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
2024-12-04 23:34:24 +08:00
JustSong
6ab87f8a08 feat: add warning in log when system prompt is reset
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
2024-11-10 17:18:46 +08:00
JustSong
833fa7ad6f feat: support set system_prompt for theme air & berry 2024-11-10 15:09:02 +08:00
JustSong
6eb0770a89 feat: support set system prompt for channel (close #1920) 2024-11-10 14:53:34 +08:00
JustSong
92cd46d64f feat: able to use ENFORCE_INCLUDE_USAGE to enforce include usage in response
Some checks are pending
CI / Unit tests (push) Waiting to run
CI / commit_lint (push) Waiting to run
2024-11-10 00:36:08 +08:00
lihangfu
2b2dc2c733 fix: update Spark Lite's domain to lite (#1896) 2024-11-09 23:55:55 +08:00
JustSong
a3d7df7f89 feat: update GeneralOpenAIRequest 2024-11-09 23:43:08 +08:00
wanthigh
c368232f50 fix: changeoptional field to pointer type (#1907)
* fix:修复在渠道配置中设置模型重定向时,temperature为0被忽略的问题

* fix: set optional fields to pointer type

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-11-09 23:31:46 +08:00
Laisky.Cai
cbfc983dc3 feat: add new claude models (#1910)
* feat: Add new models to ModelList in constants.go

* feat: update model lists and mappings for Claude 3.5 versions

---------

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-11-09 22:48:54 +08:00
Ryo Shen
8ec092ba44 feat: add support for xAI (#1915)
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
2024-11-07 23:52:38 +08:00
shaoyun
b0b88a79ff feat: added support for Claude 3.5 Haiku (#1912) 2024-11-07 23:51:17 +08:00
JustSong
7e51b04221 feat: able to hide test model selector and balance col
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
2024-10-27 18:31:43 +08:00
JustSong
f75a17f8eb feat: always return usage in stream mode 2024-10-27 17:58:44 +08:00
Wei Tingjiang
6f13a3bb3c feat: update Gemini adaptor to support custom response format (#1892) 2024-10-27 17:10:50 +08:00
shaoyun
f092eed1db feat: add support for Claude Sonnet 3.5 v2 (#1888) 2024-10-27 17:10:02 +08:00
longkeyy
629378691b feat: update groq model and price (#1864) 2024-10-27 17:07:24 +08:00
liangjs
3716e1b0e6 fix: use modelMap when testing a channel (#1855)
Co-authored-by: oliang <oliang@tencent.com>
2024-10-27 17:06:41 +08:00
Pan, Wen-Ming
a4d6e7a886 feat: add Vertex AI gemini-1.5-pro-002 and gemini-1.5-flash-002 (#1854) 2024-10-27 17:04:41 +08:00
千寻简
cb772e5d06 fix:unsuccessful lobechat redirection link (#1843) 2024-10-27 17:03:35 +08:00
lihangfu
e32cb0b844 feat: support SparkDesk-v3.5-32K (#1832)
Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-10-27 17:02:54 +08:00
抒情熊
fdd7bf41c0 feat: support multipart/form-data format request (#1690)
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
* "add parser multipart/form-data"

* chore: fix impl

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-22 17:32:47 +08:00
徐瑞东
29389ed44f fix: modify the type of token models to be text (#1761)
* fix: modify the type of token models to be text

* chore: update receiver name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-22 16:51:16 +08:00
byte911
88acc5a614 fix: return the usage info if not null (#1792)
Usage is missing.
2024-09-22 16:41:10 +08:00
TimeTrapzz
a21681096a feat: add siliconflow usage (#1798) 2024-09-22 16:31:26 +08:00
lihangfu
32f90a79a8 feat: support SparkDesk-v3.1-128K (#1732)
* feat: 支持SparkDesk-v3.1-128K以及hunyuan-vision

* feat: 支持SparkDesk-v3.1-128K以及hunyuan-vision

---------

Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-09-22 16:29:09 +08:00
OnEvent
99c8c77504 feat: add oidc support (#1725)
Some checks are pending
CI / Unit tests (push) Waiting to run
CI / commit_lint (push) Waiting to run
* feat: add the ui for configuring the third-party standard OAuth2.0/OIDC.

- update SystemSetting.js
- add setup ui
- add configuration

* feat: add the ui for "allow the OAuth 2.0 to login"

- update SystemSetting.js

* feat: add OAuth 2.0 web ui and its process functions

- update common.js
- update AuthLogin.js
- update config.js

* fix: missing "Userinfo" endpoint configuration entry, used by OAuth clients to request user information from the IdP.

- update config.js
- update SystemSetting.js

* feat: updated the icons for Lark and OIDC to match the style of the icons for WeChat, EMail, GitHub.

- update lark.svg
- new oidc.svg

* refactor: Changing OAuth 2.0 to OIDC

* feat: add OIDC login method

* feat: Add support for OIDC login to the backend

* fix: Change the AppId and AppSecret on the Web UI to the standard usage: ClientId, ClientSecret.

* feat: Support quick configuration of OIDC through Well-Known Discovery Endpoint

* feat: Standardize terminology, add well-known configuration

- Change the AppId and AppSecret on the Server End to the standard usage: ClientId, ClientSecret.
- add Well-Known configuration to store in database, no actual use in server end but store and display in web ui only
2024-09-21 23:03:20 +08:00
TAKO
649ecbf29c feat: support new openai models (4o 0806, chatgpt-4o-latest) (#1721)
* feat: support new model gpt-4o-2024-08-06

* feat: support new model chatgpt-4o-latest
2024-09-21 23:01:19 +08:00
qinguoyi
3a27c90910 fix: getTokenById return token nil, make panic (#1728)
* fix:getTokenById return token nil, make panic

* chore: remove useless err check

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 23:00:29 +08:00
千寻简
cba82404ae feat: add lobechat open link options (#1741)
Co-authored-by: Star <iii9777@163.com>
2024-09-21 22:49:31 +08:00
forrestlinfeng
c9ac670ba1 feat: update stepfun models (#1740)
Co-authored-by: chenlinfeng <chenlinfeng@step.ai>
2024-09-21 22:48:46 +08:00
leavegee
15f815c23c fix: fix ali embedding model always use v1 (#1747)
* fix:ali embedding model: v2 and v3

* chore: use ctxkey.RequestModel to eliminate hardcoding

---------

Co-authored-by: xuejia <gexuejia@djbx.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 22:40:06 +08:00
majian
89b63ca96f feat: ResponseFormat support json_schema (#1759)
* feat: responseFormat support json_schema

* chore: rename struct name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 22:35:24 +08:00
Ghostz
8cc54489b9 feat: update disabled channel (#1780)
* Update disabled channel

* Update manage.go

* Update manage.go

* chore: add missing space

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-09-21 22:31:53 +08:00
guogeer
58bf60805e fix: postgres use COALESCE replace null (#1793)
Co-authored-by: jinqi.guo <jinqi.guo@ubtrobot.com>
2024-09-21 22:13:31 +08:00
AJ's Life Journey
6714cf96d6 fix: Groq organization not auto-disabled when blocked (#1822) 2024-09-21 22:12:09 +08:00
longkeyy
f9774698e9 feat: synchronize with the official release of the groq model (#1677)
update groq add gemma2-9b-it llama3.1 family fixup price k/token -> m/token
2024-08-06 23:51:08 +08:00
TAKO
2af6f6a166 feat: add Cloudflare New Free Model Llama 3.1 8b (#1703) 2024-08-06 23:49:48 +08:00
MotorBottle
04bb3ef392 feat: add Max Tokens and Context Window Setting Options for Ollama Channel (#1694)
* Update main.go with max_tokens param

* Update model.go with max_tokens param

* Update model.go

* Update main.go

* Update main.go

* Adds num_ctx param for Ollama Channel

* Added num_ctx param for ollama adapter

* Added num_ctx param for ollama adapter

* Improved data process logic
2024-08-06 23:44:37 +08:00
longkeyy
b4bfa418a8 feat: update gemini model and price (#1705) 2024-08-06 23:43:33 +08:00
SLKun
e7e99e558a feat: update Ollama embedding API to latest version with multi-text embedding support (#1715) 2024-08-06 23:43:20 +08:00
Shenghang Tsai
402fcf7f79 feat: add SiliconFlow (#1717)
* Add SiliconFlow

* Update README.md

* Update README.md

* Update channel.constants.js

* Update ChannelConstants.js

* Update channel.constants.js

* Update ChannelConstants.js

* Update compatible.go

* Update README.md
2024-08-06 23:42:25 +08:00
Junyan Qin
36039e329e docs: update introduction for QChatGPT (#1707) 2024-08-06 23:33:43 +08:00
Laisky.Cai
c936198ac8 feat: add Proxy channel type and relay mode (#1678)
Add the Proxy channel type and relay mode to support proxying requests to custom upstream services.
2024-07-22 22:51:19 +08:00
TAKO
296ab013b8 feat: support gpt-4o mini (#1665)
* feat: support gpt-4o mini

* feat: fix gpt-4o mini image price
2024-07-22 22:44:08 +08:00
zijiren
5f03c856b4 feat: fast build linux/arm64 frontend (#1663)
* feat: fast build linux/arm64 frontend

* fix: dockerfile as replace to AS

* fix: trim space
2024-07-22 22:39:22 +08:00
igophper
39383e5532 fix: support embedding models for doubao (#1662)
Fixes #1594
2024-07-22 22:38:50 +08:00
122 changed files with 3024 additions and 748 deletions

View File

@@ -1,61 +0,0 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,4 +1,4 @@
name: Publish Docker image (amd64, English)
name: Publish Docker image (English)
on:
push:
@@ -34,6 +34,13 @@ jobs:
- name: Translate
run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
@@ -51,6 +58,7 @@ jobs:
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,10 +1,9 @@
name: Publish Docker image (arm64)
name: Publish Docker image
on:
push:
tags:
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:

1
.gitignore vendored
View File

@@ -10,3 +10,4 @@ data
/web/node_modules
cmd.md
.env
/one-api

View File

@@ -1,4 +1,4 @@
FROM node:16 as builder
FROM --platform=$BUILDPLATFORM node:16 AS builder
WORKDIR /web
COPY ./VERSION .

View File

@@ -89,6 +89,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
+ [x] [xAI](https://x.ai/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -113,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 支持使用飞书进行授权登录
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)[这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login)
+ 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
@@ -251,9 +253,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
#### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
运行期间可以通过`!model`命令查看、切换可用模型。
### 部署到第三方平台
<details>
@@ -398,6 +400,7 @@ graph LR
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage默认不开启可选值为 `true` 和 `false`。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -35,6 +35,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
@@ -70,6 +71,13 @@ var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -152,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)

View File

@@ -20,4 +20,5 @@ const (
BaseURL = "base_url"
AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
)

View File

@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(&v)
}
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}

View File

@@ -137,3 +137,23 @@ func String2Int(str string) int {
}
return num
}
func Float64PtrMax(p *float64, maxValue float64) *float64 {
if p == nil {
return nil
}
if *p > maxValue {
return &maxValue
}
return p
}
func Float64PtrMin(p *float64, minValue float64) *float64 {
if p == nil {
return nil
}
if *p < minValue {
return &minValue
}
return p
}

View File

@@ -3,9 +3,10 @@ package render
import (
"encoding/json"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"strings"
)
func StringData(c *gin.Context, str string) {

View File

@@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}

225
controller/auth/oidc.go Normal file
View File

@@ -0,0 +1,225 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.OidcClientId,
"client_secret": config.OidcClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId)
if err == nil {
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
}
} else {
userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId)

View File

@@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
TotalUsed float64 `json:"total_used"`
}
type SiliconFlowUsageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}
// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
@@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil
}
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
url := "https://api.siliconflow.cn/v1/user/info"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := SiliconFlowUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
}
balance, err := strconv.ParseFloat(response.Data.Balance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
@@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAPI2GPTBalance(channel)
case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
case channeltype.SiliconFlow:
return updateChannelSiliconFlowBalance(channel)
default:
return 0, errors.New("尚未实现")
}

View File

@@ -76,10 +76,10 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
if modelMap != nil && modelMap[modelName] != "" {
modelName = modelMap[modelName]
}
}
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)

View File

@@ -36,6 +36,12 @@ func GetStatus(c *gin.Context) {
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
"oidc": config.OidcEnabled,
"oidc_client_id": config.OidcClientId,
"oidc_well_known": config.OidcWellKnown,
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
"oidc_token_endpoint": config.OidcTokenEndpoint,
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
},
})
return

View File

@@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
fallthrough
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
case relaymode.Proxy:
err = controller.RelayProxyHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
}
@@ -58,7 +60,7 @@ func Relay(c *gin.Context) {
channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
@@ -85,12 +87,14 @@ func Relay(c *gin.Context) {
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
// BUG: bizErr is in race condition
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,
@@ -117,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
return true
}
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {

8
go.mod
View File

@@ -25,7 +25,7 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.24.0
golang.org/x/crypto v0.31.0
golang.org/x/image v0.18.0
google.golang.org/api v0.187.0
gorm.io/driver/mysql v1.5.6
@@ -99,9 +99,9 @@ require (
golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect

16
go.sum
View File

@@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
@@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) {
return
}
}
// set channel id for proxy relay
if channelId := c.Param("channelid"); channelId != "" {
c.Set(ctxkey.SpecificChannelId, channelId)
}
c.Next()
}
}

View File

@@ -12,7 +12,7 @@ import (
)
type ModelRequest struct {
Model string `json:"model"`
Model string `json:"model" form:"model"`
}
func Distribute() func(c *gin.Context) {
@@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name)
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
}
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

27
middleware/gzip.go Normal file
View File

@@ -0,0 +1,27 @@
package middleware
import (
"compress/gzip"
"github.com/gin-gonic/gin"
"io"
"net/http"
)
func GzipDecodeMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(c.Request.Body)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
defer gzipReader.Close()
// Replace the request body with the decompressed data
c.Request.Body = io.NopCloser(gzipReader)
}
// Continue processing the request
c.Next()
}
}

View File

@@ -37,6 +37,7 @@ type Channel struct {
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
}
type ChannelConfig struct {

View File

@@ -3,6 +3,7 @@ package model
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
@@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
if username != "" {
tx = tx.Where("username = ?", username)
}

View File

@@ -28,6 +28,7 @@ func InitOptionMap() {
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) {
config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue
case "OidcEnabled":
config.OidcEnabled = boolValue
case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
@@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) {
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "OidcClientId":
config.OidcClientId = value
case "OidcClientSecret":
config.OidcClientSecret = value
case "OidcWellKnown":
config.OidcWellKnown = value
case "OidcAuthorizationEndpoint":
config.OidcAuthorizationEndpoint = value
case "OidcTokenEndpoint":
config.OidcTokenEndpoint = value
case "OidcUserinfoEndpoint":
config.OidcUserinfoEndpoint = value
case "Footer":
config.Footer = value
case "SystemName":

View File

@@ -30,7 +30,7 @@ type Token struct {
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"default:''"` // allowed models
Models *string `json:"models" gorm:"type:text"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
@@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
return &token, err
}
func (token *Token) Insert() error {
func (t *Token) Insert() error {
var err error
err = DB.Create(token).Error
err = DB.Create(t).Error
return err
}
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
func (t *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
return err
}
func (token *Token) SelectUpdate() error {
func (t *Token) SelectUpdate() error {
// This can update zero values
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
}
func (token *Token) Delete() error {
func (t *Token) Delete() error {
var err error
err = DB.Delete(token).Error
err = DB.Delete(t).Error
return err
}
func (t *Token) GetModels() string {
if t == nil {
return ""
}
if t.Models == nil {
return ""
}
return *t.Models
}
func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 {
@@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if err != nil {
return err
}
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)

View File

@@ -39,6 +39,7 @@ type User struct {
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error {
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}

View File

@@ -1,10 +1,11 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
)
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
@@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
lowerMessage := strings.ToLower(err.Message)
if strings.Contains(lowerMessage, "your access was terminated") ||
strings.Contains(lowerMessage, "violation of our policies") ||
strings.Contains(lowerMessage, "your credit balance is too low") ||
strings.Contains(lowerMessage, "organization has been disabled") ||
strings.Contains(lowerMessage, "credit") ||
strings.Contains(lowerMessage, "balance") ||
strings.Contains(lowerMessage, "permission denied") ||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
strings.Contains(lowerMessage, "已欠费") {
return true
}
return false

View File

@@ -15,6 +15,8 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -58,6 +60,10 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &deepl.Adaptor{}
case apitype.VertexAI:
return &vertexai.Adaptor{}
case apitype.Proxy:
return &proxy.Adaptor{}
case apitype.Replicate:
return &replicate.Adaptor{}
}
return nil
}

View File

@@ -1,7 +1,23 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
"qwen-turbo", "qwen-turbo-latest",
"qwen-plus", "qwen-plus-latest",
"qwen-max", "qwen-max-latest",
"qwen-max-longcontext",
"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
"qwen-vl-ocr", "qwen-vl-ocr-latest",
"qwen-audio-turbo",
"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
"qwen2-audio-instruct", "qwen-audio-chat",
"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}

View File

@@ -3,6 +3,7 @@ package ali
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
@@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
if request.TopP >= 1 {
request.TopP = 0.9999
}
request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
return &ChatRequest{
Model: aliModel,
Input: Input{
@@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{
@@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
StatusCode: resp.StatusCode,
}, nil
}
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -16,13 +16,13 @@ type Input struct {
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}

View File

@@ -3,7 +3,11 @@ package anthropic
var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
"claude-3-5-haiku-20241022",
}

View File

@@ -48,8 +48,8 @@ type Request struct {
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`

View File

@@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
}
func awsModelID(requestModel string) (string, error) {

View File

@@ -11,8 +11,8 @@ type Request struct {
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`

View File

@@ -6,8 +6,8 @@ package aws
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
}
// Response is the response from AWS Llama3

View File

@@ -35,9 +35,9 @@ type Message struct {
type ChatRequest struct {
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
PenaltyScore *float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`

View File

@@ -1,6 +1,7 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",

View File

@@ -9,5 +9,5 @@ type Request struct {
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
}

View File

@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
K: textRequest.TopK,
Stream: textRequest.Stream,
FrequencyPenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.PresencePenalty,
Seed: int(textRequest.Seed),
}
if cohereRequest.Model == "" {

View File

@@ -10,15 +10,15 @@ type Request struct {
PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
Connectors []Connector `json:"connectors,omitempty"`
Documents []Document `json:"documents,omitempty"`
Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3
Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3
MaxTokens int `json:"max_tokens,omitempty"`
MaxInputTokens int `json:"max_input_tokens,omitempty"`
K int `json:"k,omitempty"` // 默认值为0
P float64 `json:"p,omitempty"` // 默认值为0.75
P *float64 `json:"p,omitempty"` // 默认值为0.75
Seed int `json:"seed,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
Tools []Tool `json:"tools,omitempty"`
ToolResults []ToolResult `json:"tool_results,omitempty"`
}

View File

@@ -7,8 +7,12 @@ import (
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:
}
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}

View File

@@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
defaultVersion := config.GeminiVersion
if meta.ActualModelName == "gemini-2.0-flash-exp" {
defaultVersion = "v1beta"
}
version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
action := ""
switch meta.Mode {
case relaymode.Embeddings:
@@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if meta.IsStream {
action = "streamGenerateContent?alt=sse"
}
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}

View File

@@ -3,6 +3,9 @@ package gemini
// https://ai.google.dev/models/gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
"gemini-pro", "gemini-1.0-pro",
"gemini-1.5-flash", "gemini-1.5-pro",
"text-embedding-004", "aqa",
"gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp",
}

View File

@@ -4,11 +4,12 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
@@ -28,6 +29,11 @@ const (
VisionMaxImageNum = 16
)
var mimeTypeMap = map[string]string{
"json_object": "application/json",
"text": "text/plain",
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{
@@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
@@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.ResponseFormat != nil {
if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
geminiRequest.GenerationConfig.ResponseMimeType = mimeType
}
if textRequest.ResponseFormat.JsonSchema != nil {
geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
}
}
if textRequest.Tools != nil {
functions := make([]model.Function, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
@@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(&candidate)
} else {
choice.Message.Content = candidate.Content.Parts[0].Text
var builder strings.Builder
for _, part := range candidate.Content.Parts {
if i > 0 {
builder.WriteString("\n")
}
builder.WriteString(part.Text)
}
choice.Message.Content = builder.String()
}
} else {
choice.Message.Content = ""

View File

@@ -65,8 +65,10 @@ type ChatTools struct {
}
type ChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`

View File

@@ -4,9 +4,24 @@ package groq
var ModelList = []string{
"gemma-7b-it",
"llama2-7b-2048",
"llama2-70b-4096",
"mixtral-8x7b-32768",
"llama3-8b-8192",
"gemma2-9b-it",
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"llama-3.2-11b-text-preview",
"llama-3.2-11b-vision-preview",
"llama-3.2-1b-preview",
"llama-3.2-3b-preview",
"llama-3.2-11b-vision-preview",
"llama-3.2-90b-text-preview",
"llama-3.2-90b-vision-preview",
"llama-guard-3-8b",
"llama3-70b-8192",
"llama3-8b-8192",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"llava-v1.5-7b-4096-preview",
"mixtral-8x7b-32768",
"distil-whisper-large-v3-en",
"whisper-large-v3",
"whisper-large-v3-turbo",
}

View File

@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
}
return fullRequestURL, nil
}

View File

@@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,
},
Stream: request.Stream,
}
@@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
common.SetEventStreamHeaders(c)
for scanner.Scan() {
data := strings.TrimPrefix(scanner.Text(), "}")
data = data + "}"
data := scanner.Text()
if strings.HasPrefix(data, "}") {
data = strings.TrimPrefix(data, "}") + "}"
}
var ollamaResponse ChatResponse
err := json.Unmarshal([]byte(data), &ollamaResponse)
@@ -158,7 +162,14 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "),
Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
}
}
@@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, 1),
Model: "text-embedding-v1",
Model: response.Model,
Usage: model.Usage{TotalTokens: 0},
}
for i, embedding := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: 0,
Embedding: response.Embedding,
Index: i,
Embedding: embedding,
})
}
return &openAIEmbeddingResponse
}

View File

@@ -2,11 +2,13 @@ package ollama
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type Message struct {
@@ -38,10 +40,14 @@ type ChatResponse struct {
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Input []string `json:"input"`
// Truncate bool `json:"truncate,omitempty"`
Options *Options `json:"options,omitempty"`
// KeepAlive string `json:"keep_alive,omitempty"`
}
type EmbeddingResponse struct {
Error string `json:"error,omitempty"`
Embedding []float64 `json:"embedding,omitempty"`
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings"`
}

View File

@@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
if request.Stream {
// always return usage in stream mode
if request.StreamOptions == nil {
request.StreamOptions = &model.StreamOptions{}
}
request.StreamOptions.IncludeUsage = true
}
return request, nil
}

View File

@@ -11,8 +11,10 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/adaptor/xai"
"github.com/songquanpeng/one-api/relay/channeltype"
)
@@ -30,6 +32,8 @@ var CompatibleChannels = []int{
channeltype.DeepSeek,
channeltype.TogetherAI,
channeltype.Novita,
channeltype.SiliconFlow,
channeltype.XAI,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -60,6 +64,10 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "doubao", doubao.ModelList
case channeltype.Novita:
return "novita", novita.ModelList
case channeltype.SiliconFlow:
return "siliconflow", siliconflow.ModelList
case channeltype.XAI:
return "xai", xai.ModelList
default:
return "openai", ModelList
}

View File

@@ -8,6 +8,9 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
@@ -17,4 +20,7 @@ var ModelList = []string{
"dall-e-2", "dall-e-3",
"whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
"o1", "o1-2024-12-17",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
}

View File

@@ -2,15 +2,16 @@ package openai
import (
"fmt"
"strings"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"strings"
)
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.CompletionTokens = CountTokenText(responseText, modelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}

View File

@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
// but for empty choice and no usage, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
render.StringData(c, data)

View File

@@ -110,7 +110,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail)
imageTokens, err := countImageTokens(url, detail, model)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
@@ -134,11 +134,15 @@ const (
lowDetailCost = 85
highDetailCostPerTile = 170
additionalCost = 85
// gpt-4o-mini cost higher than other model
gpt4oMiniLowDetailCost = 2833
gpt4oMiniHighDetailCost = 5667
gpt4oMiniAdditionalCost = 2833
)
// https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string) (_ int, err error) {
func countImageTokens(url string, detail string, model string) (_ int, err error) {
var fetchSize = true
var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
@@ -172,6 +176,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
}
switch detail {
case "low":
if strings.HasPrefix(model, "gpt-4o-mini") {
return gpt4oMiniLowDetailCost, nil
}
return lowDetailCost, nil
case "high":
if fetchSize {
@@ -191,6 +198,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
height = int(float64(height) * ratio)
}
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
if strings.HasPrefix(model, "gpt-4o-mini") {
return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
}
result := numSquares*highDetailCostPerTile + additionalCost
return result, nil
default:

View File

@@ -1,8 +1,16 @@
package openai
import "github.com/songquanpeng/one-api/relay/model"
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
)
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
Error := model.Error{
Message: err.Error(),
Type: "one_api_error",

View File

@@ -20,9 +20,9 @@ type Prompt struct {
type ChatRequest struct {
Prompt Prompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}

View File

@@ -0,0 +1,89 @@
package proxy
import (
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
const channelName = "proxy"
type Adaptor struct{}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("notimplement")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
for k, v := range resp.Header {
for _, vv := range v {
c.Writer.Header().Set(k, vv)
}
}
c.Writer.WriteHeader(resp.StatusCode)
if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil {
return nil, &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: gerr.Error(),
},
}
}
return nil, nil
}
func (a *Adaptor) GetModelList() (models []string) {
return nil
}
func (a *Adaptor) GetChannelName() string {
return channelName
}
// GetRequestURL remove static prefix, and return the real request url to the upstream service
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
for k, v := range c.Request.Header {
req.Header.Set(k, v[0])
}
// remove unnecessary headers
req.Header.Del("Host")
req.Header.Del("Content-Length")
req.Header.Del("Accept-Encoding")
req.Header.Del("Connection")
// set authorization header
req.Header.Set("Authorization", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.Errorf("not implement")
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -0,0 +1,136 @@
package replicate
import (
"fmt"
"io"
"net/http"
"slices"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
meta *meta.Meta
}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return DrawImageRequest{
Input: ImageInput{
Steps: 25,
Prompt: request.Prompt,
Guidance: 3,
Seed: int(time.Now().UnixNano()),
SafetyTolerance: 5,
NImages: 1, // replicate will always return 1 image
Width: 1440,
Height: 1440,
AspectRatio: "1:1",
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if !request.Stream {
// TODO: support non-stream mode
return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
}
// Build the prompt from OpenAI messages
var promptBuilder strings.Builder
for _, message := range request.Messages {
switch msgCnt := message.Content.(type) {
case string:
promptBuilder.WriteString(message.Role)
promptBuilder.WriteString(": ")
promptBuilder.WriteString(msgCnt)
promptBuilder.WriteString("\n")
default:
}
}
replicateRequest := ReplicateChatRequest{
Input: ChatInput{
Prompt: promptBuilder.String(),
MaxTokens: request.MaxTokens,
Temperature: 1.0,
TopP: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
},
}
// Map optional fields
if request.Temperature != nil {
replicateRequest.Input.Temperature = *request.Temperature
}
if request.TopP != nil {
replicateRequest.Input.TopP = *request.TopP
}
if request.PresencePenalty != nil {
replicateRequest.Input.PresencePenalty = *request.PresencePenalty
}
if request.FrequencyPenalty != nil {
replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
}
if request.MaxTokens > 0 {
replicateRequest.Input.MaxTokens = request.MaxTokens
} else if request.MaxTokens == 0 {
replicateRequest.Input.MaxTokens = 500
}
return replicateRequest, nil
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if !slices.Contains(ModelList, meta.OriginModelName) {
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
}
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
logger.Info(c, "send request to replicate")
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
case relaymode.ChatCompletions:
err, usage = ChatHandler(c, resp)
default:
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "replicate"
}

View File

@@ -0,0 +1,191 @@
package replicate
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
func ChatHandler(c *gin.Context, resp *http.Response) (
srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
respData := new(ChatResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()
if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}
taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}
taskData := new(ChatResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}
switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
default:
time.Sleep(time.Second * 3)
return errNextLoop
}
if taskData.URLs.Stream == "" {
return errors.New("stream url is empty")
}
// request stream url
responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
if err != nil {
return errors.Wrap(err, "chat stream handler")
}
ctxMeta := meta.GetByContext(c)
usage = openai.ResponseText2Usage(responseText,
ctxMeta.ActualModelName, ctxMeta.PromptTokens)
return nil
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue
}
return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
}
break
}
return nil, usage
}
const (
eventPrefix = "event: "
dataPrefix = "data: "
done = "[DONE]"
)
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
// request stream endpoint
streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
if err != nil {
return "", errors.Wrap(err, "new request to stream")
}
streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
streamReq.Header.Set("Accept", "text/event-stream")
streamReq.Header.Set("Cache-Control", "no-store")
resp, err := http.DefaultClient.Do(streamReq)
if err != nil {
return "", errors.Wrap(err, "do request to stream")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(resp.Body)
return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
// Handle comments starting with ':'
if strings.HasPrefix(line, ":") {
continue
}
// Parse SSE fields
if strings.HasPrefix(line, eventPrefix) {
event := strings.TrimSpace(line[len(eventPrefix):])
var data string
// Read the following lines to get data and id
for scanner.Scan() {
nextLine := scanner.Text()
if nextLine == "" {
break
}
if strings.HasPrefix(nextLine, dataPrefix) {
data = nextLine[len(dataPrefix):]
} else if strings.HasPrefix(nextLine, "id:") {
// id = strings.TrimSpace(nextLine[len("id:"):])
}
}
if event == "output" {
render.StringData(c, data)
responseText += data
} else if event == "done" {
render.Done(c)
doneRendered = true
break
}
}
}
if err := scanner.Err(); err != nil {
return "", errors.Wrap(err, "scan stream")
}
if !doneRendered {
render.Done(c)
}
return responseText, nil
}

View File

@@ -0,0 +1,58 @@
package replicate
// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
// -------------------------------------
// image model
// -------------------------------------
"black-forest-labs/flux-1.1-pro",
"black-forest-labs/flux-1.1-pro-ultra",
"black-forest-labs/flux-canny-dev",
"black-forest-labs/flux-canny-pro",
"black-forest-labs/flux-depth-dev",
"black-forest-labs/flux-depth-pro",
"black-forest-labs/flux-dev",
"black-forest-labs/flux-dev-lora",
"black-forest-labs/flux-fill-dev",
"black-forest-labs/flux-fill-pro",
"black-forest-labs/flux-pro",
"black-forest-labs/flux-redux-dev",
"black-forest-labs/flux-redux-schnell",
"black-forest-labs/flux-schnell",
"black-forest-labs/flux-schnell-lora",
"ideogram-ai/ideogram-v2",
"ideogram-ai/ideogram-v2-turbo",
"recraft-ai/recraft-v3",
"recraft-ai/recraft-v3-svg",
"stability-ai/stable-diffusion-3",
"stability-ai/stable-diffusion-3.5-large",
"stability-ai/stable-diffusion-3.5-large-turbo",
"stability-ai/stable-diffusion-3.5-medium",
// -------------------------------------
// language model
// -------------------------------------
"ibm-granite/granite-20b-code-instruct-8k",
"ibm-granite/granite-3.0-2b-instruct",
"ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k",
"meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3.1-405b-instruct",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct",
"mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1",
"mistralai/mixtral-8x7b-instruct-v0.1",
// -------------------------------------
// video model
// -------------------------------------
// "minimax/video-01", // TODO: implement the adaptor
}

View File

@@ -0,0 +1,222 @@
package replicate
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"image"
"image/png"
"io"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"golang.org/x/image/webp"
"golang.org/x/sync/errgroup"
)
// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// c.Writer.WriteHeader(resp.StatusCode)
// for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
// }
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// }
// defer resp.Body.Close()
// return nil, nil
// }
var errNextLoop = errors.New("next_loop")
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
respData := new(ImageResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()
if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}
taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}
taskData := new(ImageResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}
switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed: %s", taskData.Status)
default:
time.Sleep(time.Second * 3)
return errNextLoop
}
output, err := taskData.GetOutput()
if err != nil {
return errors.Wrap(err, "get output")
}
if len(output) == 0 {
return errors.New("response output is empty")
}
var mu sync.Mutex
var pool errgroup.Group
respBody := &openai.ImageResponse{
Created: taskData.CompletedAt.Unix(),
Data: []openai.ImageData{},
}
for _, imgOut := range output {
imgOut := imgOut
pool.Go(func() error {
// download image
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, imgOut, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
imgResp, err := http.DefaultClient.Do(downloadReq)
if err != nil {
return errors.Wrap(err, "download image")
}
defer imgResp.Body.Close()
if imgResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(imgResp.Body)
return errors.Errorf("bad status code [%d]%s",
imgResp.StatusCode, string(payload))
}
imgData, err := io.ReadAll(imgResp.Body)
if err != nil {
return errors.Wrap(err, "read image")
}
imgData, err = ConvertImageToPNG(imgData)
if err != nil {
return errors.Wrap(err, "convert image")
}
mu.Lock()
respBody.Data = append(respBody.Data, openai.ImageData{
B64Json: fmt.Sprintf("data:image/png;base64,%s",
base64.StdEncoding.EncodeToString(imgData)),
})
mu.Unlock()
return nil
})
}
if err := pool.Wait(); err != nil {
if len(respBody.Data) == 0 {
return errors.WithStack(err)
}
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
}
c.JSON(http.StatusOK, respBody)
return nil
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue
}
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
}
break
}
return nil, nil
}
// ConvertImageToPNG converts a WebP image to PNG format
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
// bypass if it's already a PNG image
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
return webpData, nil
}
// check if is jpeg, convert to png
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
img, _, err := image.Decode(bytes.NewReader(webpData))
if err != nil {
return nil, errors.Wrap(err, "decode jpeg")
}
var pngBuffer bytes.Buffer
if err := png.Encode(&pngBuffer, img); err != nil {
return nil, errors.Wrap(err, "encode png")
}
return pngBuffer.Bytes(), nil
}
// Decode the WebP image
img, err := webp.Decode(bytes.NewReader(webpData))
if err != nil {
return nil, errors.Wrap(err, "decode webp")
}
// Encode the image as PNG
var pngBuffer bytes.Buffer
if err := png.Encode(&pngBuffer, img); err != nil {
return nil, errors.Wrap(err, "encode png")
}
return pngBuffer.Bytes(), nil
}

View File

@@ -0,0 +1,159 @@
package replicate
import (
"time"
"github.com/pkg/errors"
)
// DrawImageRequest draw image by fluxpro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
Input ImageInput `json:"input"`
}
// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
Steps int `json:"steps" binding:"required,min=1"`
Prompt string `json:"prompt" binding:"required,min=5"`
ImagePrompt string `json:"image_prompt"`
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
Interval int `json:"interval" binding:"required,min=1,max=4"`
AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
Seed int `json:"seed"`
NImages int `json:"n_images" binding:"required,min=1,max=8"`
Width int `json:"width" binding:"required,min=256,max=1440"`
Height int `json:"height" binding:"required,min=256,max=1440"`
}
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
Input FluxInpaintingInput `json:"input"`
}
// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
Mask string `json:"mask" binding:"required"`
Image string `json:"image" binding:"required"`
Seed int `json:"seed"`
Steps int `json:"steps" binding:"required,min=1"`
Prompt string `json:"prompt" binding:"required,min=5"`
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
OutputFormat string `json:"output_format"`
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
PromptUnsampling bool `json:"prompt_unsampling"`
}
// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
CompletedAt time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
DataRemoved bool `json:"data_removed"`
Error string `json:"error"`
ID string `json:"id"`
Input DrawImageRequest `json:"input"`
Logs string `json:"logs"`
Metrics FluxMetrics `json:"metrics"`
// Output could be `string` or `[]string`
Output any `json:"output"`
StartedAt time.Time `json:"started_at"`
Status string `json:"status"`
URLs FluxURLs `json:"urls"`
Version string `json:"version"`
}
func (r *ImageResponse) GetOutput() ([]string, error) {
switch v := r.Output.(type) {
case string:
return []string{v}, nil
case []string:
return v, nil
case nil:
return nil, nil
case []interface{}:
// convert []interface{} to []string
ret := make([]string, len(v))
for idx, vv := range v {
if vvv, ok := vv.(string); ok {
ret[idx] = vvv
} else {
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
}
}
return ret, nil
default:
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
}
}
// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
ImageCount int `json:"image_count"`
PredictTime float64 `json:"predict_time"`
TotalTime float64 `json:"total_time"`
}
// FluxURLs is urls of ImageResponse
type FluxURLs struct {
Get string `json:"get"`
Cancel string `json:"cancel"`
}
type ReplicateChatRequest struct {
Input ChatInput `json:"input" form:"input" binding:"required"`
}
// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
TopK int `json:"top_k"`
TopP float64 `json:"top_p"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
MinTokens int `json:"min_tokens"`
Temperature float64 `json:"temperature"`
SystemPrompt string `json:"system_prompt"`
StopSequences string `json:"stop_sequences"`
PromptTemplate string `json:"prompt_template"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
}
// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
CompletedAt time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
DataRemoved bool `json:"data_removed"`
Error string `json:"error"`
ID string `json:"id"`
Input ChatInput `json:"input"`
Logs string `json:"logs"`
Metrics FluxMetrics `json:"metrics"`
// Output could be `string` or `[]string`
Output []string `json:"output"`
StartedAt time.Time `json:"started_at"`
Status string `json:"status"`
URLs ChatResponseUrl `json:"urls"`
Version string `json:"version"`
}
// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
Stream string `json:"stream"`
Get string `json:"get"`
Cancel string `json:"cancel"`
}

View File

@@ -0,0 +1,36 @@
package siliconflow
// https://docs.siliconflow.cn/docs/getting-started
var ModelList = []string{
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"THUDM/glm-4-9b-chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}

View File

@@ -1,7 +1,13 @@
package stepfun
var ModelList = []string{
"step-1-8k",
"step-1-32k",
"step-1-128k",
"step-1-256k",
"step-1-flash",
"step-2-16k",
"step-1v-8k",
"step-1v-32k",
"step-1-200k",
"step-1x-medium",
}

View File

@@ -5,4 +5,5 @@ var ModelList = []string{
"hunyuan-standard",
"hunyuan-standard-256K",
"hunyuan-pro",
"hunyuan-vision",
}

View File

@@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
Model: &request.Model,
Stream: &request.Stream,
Messages: messages,
TopP: &request.TopP,
Temperature: &request.Temperature,
TopP: request.TopP,
Temperature: request.Temperature,
}
}

View File

@@ -13,7 +13,12 @@ import (
)
var ModelList = []string{
"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
"claude-3-haiku@20240307",
"claude-3-sonnet@20240229",
"claude-3-opus@20240229",
"claude-3-5-sonnet@20240620",
"claude-3-5-sonnet-v2@20241022",
"claude-3-5-haiku@20241022",
}
const anthropicVersion = "vertex-2023-10-16"

View File

@@ -11,8 +11,8 @@ type Request struct {
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`

View File

@@ -15,7 +15,10 @@ import (
)
var ModelList = []string{
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
"gemini-pro", "gemini-pro-vision",
"gemini-1.5-pro-001", "gemini-1.5-flash-001",
"gemini-1.5-pro-002", "gemini-1.5-flash-002",
"gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp",
}
type Adaptor struct {

View File

@@ -0,0 +1,5 @@
package xai
var ModelList = []string{
"grok-beta",
}

View File

@@ -5,6 +5,8 @@ var ModelList = []string{
"SparkDesk-v1.1",
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.1-128K",
"SparkDesk-v3.5",
"SparkDesk-v3.5-32K",
"SparkDesk-v4.0",
}

View File

@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
}
func parseAPIVersionByModelName(modelName string) string {
parts := strings.Split(modelName, "-")
if len(parts) == 2 {
return parts[1]
index := strings.IndexAny(modelName, "-")
if index != -1 {
return modelName[index+1:]
}
return ""
}
@@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string {
func apiVersion2domain(apiVersion string) string {
switch apiVersion {
case "v1.1":
return "general"
return "lite"
case "v2.1":
return "generalv2"
case "v3.1":
return "generalv3"
case "v3.1-128K":
return "pro-128k"
case "v3.5":
return "generalv3.5"
case "v3.5-32K":
return "max-32k"
case "v4.0":
return "4.0Ultra"
}
@@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string {
}
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
var authUrl string
domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
switch apiVersion {
case "v3.1-128K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
break
case "v3.5-32K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
break
default:
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
}
return domain, authUrl
}

View File

@@ -20,7 +20,7 @@ type ChatRequest struct {
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`

View File

@@ -4,13 +4,13 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"math"
"net/http"
"strings"
)
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, err
default:
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
// TopP [0.0, 1.0]
request.TopP = helper.Float64PtrMax(request.TopP, 1)
request.TopP = helper.Float64PtrMin(request.TopP, 0)
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
// Temperature [0.0, 1.0]
request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil

View File

@@ -12,8 +12,8 @@ type Message struct {
type Request struct {
Prompt []Message `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}

View File

@@ -18,6 +18,8 @@ const (
Cloudflare
DeepL
VertexAI
Proxy
Replicate
Dummy // this one is only for count, do not add any channel after this
)

View File

@@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{
"720x1280": 1,
"1280x720": 1,
},
"step-1x-medium": {
"256x256": 1,
"512x512": 1,
"768x768": 1,
"1024x1024": 1,
"1280x800": 1,
"800x1280": 1,
},
}
var ImageGenerationAmounts = map[string][2]int{
@@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
"step-1x-medium": {1, 1},
}
var ImagePromptLengthLimitations = map[string]int{
@@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
"step-1x-medium": 4000,
}
var ImageOriginModelName = map[string]string{

View File

@@ -34,7 +34,11 @@ var ModelRatio = map[string]float64{
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
@@ -44,6 +48,12 @@ var ModelRatio = map[string]float64{
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"o1": 7.5, // $15.00 / 1M input tokens
"o1-2024-12-17": 7.5,
"o1-preview": 7.5, // $15.00 / 1M input tokens
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5, // $3.00 / 1M input tokens
"o1-mini-2024-09-12": 1.5,
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2,
@@ -75,8 +85,10 @@ var ModelRatio = map[string]float64{
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-5-haiku-20241022": 1.0 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB,
@@ -96,12 +108,15 @@ var ModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.0-pro": 1,
"gemini-1.5-pro": 1,
"gemini-1.5-pro-001": 1,
"gemini-1.5-flash": 1,
"gemini-1.5-flash-001": 1,
"gemini-2.0-flash-exp": 1,
"gemini-2.0-flash-thinking-exp": 1,
"aqa": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
@@ -113,19 +128,86 @@ var ModelRatio = map[string]float64{
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"qwen-turbo": 1.4286, // ¥0.02 / 1k tokens
"qwen-turbo-latest": 1.4286,
"qwen-plus": 1.4286,
"qwen-plus-latest": 1.4286,
"qwen-max": 1.4286,
"qwen-max-latest": 1.4286,
"qwen-max-longcontext": 1.4286,
"qwen-vl-max": 1.4286,
"qwen-vl-max-latest": 1.4286,
"qwen-vl-plus": 1.4286,
"qwen-vl-plus-latest": 1.4286,
"qwen-vl-ocr": 1.4286,
"qwen-vl-ocr-latest": 1.4286,
"qwen-audio-turbo": 1.4286,
"qwen-math-plus": 1.4286,
"qwen-math-plus-latest": 1.4286,
"qwen-math-turbo": 1.4286,
"qwen-math-turbo-latest": 1.4286,
"qwen-coder-plus": 1.4286,
"qwen-coder-plus-latest": 1.4286,
"qwen-coder-turbo": 1.4286,
"qwen-coder-turbo-latest": 1.4286,
"qwq-32b-preview": 1.4286,
"qwen2.5-72b-instruct": 1.4286,
"qwen2.5-32b-instruct": 1.4286,
"qwen2.5-14b-instruct": 1.4286,
"qwen2.5-7b-instruct": 1.4286,
"qwen2.5-3b-instruct": 1.4286,
"qwen2.5-1.5b-instruct": 1.4286,
"qwen2.5-0.5b-instruct": 1.4286,
"qwen2-72b-instruct": 1.4286,
"qwen2-57b-a14b-instruct": 1.4286,
"qwen2-7b-instruct": 1.4286,
"qwen2-1.5b-instruct": 1.4286,
"qwen2-0.5b-instruct": 1.4286,
"qwen1.5-110b-chat": 1.4286,
"qwen1.5-72b-chat": 1.4286,
"qwen1.5-32b-chat": 1.4286,
"qwen1.5-14b-chat": 1.4286,
"qwen1.5-7b-chat": 1.4286,
"qwen1.5-1.8b-chat": 1.4286,
"qwen1.5-0.5b-chat": 1.4286,
"qwen-72b-chat": 1.4286,
"qwen-14b-chat": 1.4286,
"qwen-7b-chat": 1.4286,
"qwen-1.8b-chat": 1.4286,
"qwen-1.8b-longcontext-chat": 1.4286,
"qwen2-vl-7b-instruct": 1.4286,
"qwen2-vl-2b-instruct": 1.4286,
"qwen-vl-v1": 1.4286,
"qwen-vl-chat-v1": 1.4286,
"qwen2-audio-instruct": 1.4286,
"qwen-audio-chat": 1.4286,
"qwen2.5-math-72b-instruct": 1.4286,
"qwen2.5-math-7b-instruct": 1.4286,
"qwen2.5-math-1.5b-instruct": 1.4286,
"qwen2-math-72b-instruct": 1.4286,
"qwen2-math-7b-instruct": 1.4286,
"qwen2-math-1.5b-instruct": 1.4286,
"qwen2.5-coder-32b-instruct": 1.4286,
"qwen2.5-coder-14b-instruct": 1.4286,
"qwen2.5-coder-7b-instruct": 1.4286,
"qwen2.5-coder-3b-instruct": 1.4286,
"qwen2.5-coder-1.5b-instruct": 1.4286,
"qwen2.5-coder-0.5b-instruct": 1.4286,
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"ali-stable-diffusion-xl": 8,
"ali-stable-diffusion-v1.5": 8,
"wanx-v1": 8,
"text-embedding-v3": 0.05,
"text-embedding-v2": 0.05,
"text-embedding-async-v2": 0.05,
"text-embedding-async-v1": 0.05,
"ali-stable-diffusion-xl": 8.00,
"ali-stable-diffusion-v1.5": 8.00,
"wanx-v1": 8.00,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
@@ -156,20 +238,35 @@ var ModelRatio = map[string]float64{
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
"llama3-70b-8192": 0.59 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"llama3-8b-8192": 0.05 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
"llama2-70b-4096": 0.64 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"gemma-7b-it": 0.07 / 1000000 * USD,
"gemma2-9b-it": 0.20 / 1000000 * USD,
"llama-3.1-70b-versatile": 0.59 / 1000000 * USD,
"llama-3.1-8b-instant": 0.05 / 1000000 * USD,
"llama-3.2-11b-text-preview": 0.05 / 1000000 * USD,
"llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD,
"llama-3.2-1b-preview": 0.05 / 1000000 * USD,
"llama-3.2-3b-preview": 0.05 / 1000000 * USD,
"llama-3.2-90b-text-preview": 0.59 / 1000000 * USD,
"llama-guard-3-8b": 0.05 / 1000000 * USD,
"llama3-70b-8192": 0.59 / 1000000 * USD,
"llama3-8b-8192": 0.05 / 1000000 * USD,
"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD,
"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
// stepfun todo
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
// https://platform.stepfun.com/docs/pricing/details
"step-1-8k": 0.005 / 1000 * RMB,
"step-1-32k": 0.015 / 1000 * RMB,
"step-1-128k": 0.040 / 1000 * RMB,
"step-1-256k": 0.095 / 1000 * RMB,
"step-1-flash": 0.001 / 1000 * RMB,
"step-2-16k": 0.038 / 1000 * RMB,
"step-1v-8k": 0.005 / 1000 * RMB,
"step-1v-32k": 0.015 / 1000 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
@@ -187,6 +284,52 @@ var ModelRatio = map[string]float64{
"deepl-zh": 25.0 / 1000 * USD,
"deepl-en": 25.0 / 1000 * USD,
"deepl-ja": 25.0 / 1000 * USD,
// https://console.x.ai/
"grok-beta": 5.0 / 1000 * USD,
// replicate charges based on the number of generated images
// https://replicate.com/pricing
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
"black-forest-labs/flux-canny-dev": 0.025 * USD,
"black-forest-labs/flux-canny-pro": 0.05 * USD,
"black-forest-labs/flux-depth-dev": 0.025 * USD,
"black-forest-labs/flux-depth-pro": 0.05 * USD,
"black-forest-labs/flux-dev": 0.025 * USD,
"black-forest-labs/flux-dev-lora": 0.032 * USD,
"black-forest-labs/flux-fill-dev": 0.04 * USD,
"black-forest-labs/flux-fill-pro": 0.05 * USD,
"black-forest-labs/flux-pro": 0.055 * USD,
"black-forest-labs/flux-redux-dev": 0.025 * USD,
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
"black-forest-labs/flux-schnell": 0.003 * USD,
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
"ideogram-ai/ideogram-v2": 0.08 * USD,
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
"recraft-ai/recraft-v3": 0.04 * USD,
"recraft-ai/recraft-v3-svg": 0.08 * USD,
"stability-ai/stable-diffusion-3": 0.035 * USD,
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
// replicate chat models
"ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD,
"ibm-granite/granite-3.0-2b-instruct": 0.030 * USD,
"ibm-granite/granite-3.0-8b-instruct": 0.050 * USD,
"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
"meta/llama-2-13b": 0.100 * USD,
"meta/llama-2-13b-chat": 0.100 * USD,
"meta/llama-2-70b": 0.650 * USD,
"meta/llama-2-70b-chat": 0.650 * USD,
"meta/llama-2-7b": 0.050 * USD,
"meta/llama-2-7b-chat": 0.050 * USD,
"meta/meta-llama-3.1-405b-instruct": 9.500 * USD,
"meta/meta-llama-3-70b": 0.650 * USD,
"meta/meta-llama-3-70b-instruct": 0.650 * USD,
"meta/meta-llama-3-8b": 0.050 * USD,
"meta/meta-llama-3-8b-instruct": 0.050 * USD,
"mistralai/mistral-7b-instruct-v0.2": 0.050 * USD,
"mistralai/mistral-7b-v0.1": 0.050 * USD,
"mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD,
}
var CompletionRatio = map[string]float64{
@@ -195,8 +338,10 @@ var CompletionRatio = map[string]float64{
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
var (
DefaultModelRatio map[string]float64
DefaultCompletionRatio map[string]float64
)
func init() {
DefaultModelRatio = make(map[string]float64)
@@ -308,6 +453,9 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
return 4
}
if strings.HasPrefix(name, "gpt-4-turbo") ||
strings.HasPrefix(name, "gpt-4o") ||
strings.HasSuffix(name, "preview") {
@@ -315,6 +463,13 @@ func GetCompletionRatio(name string, channelType int) float64 {
}
return 2
}
// including o1, o1-preview, o1-mini
if strings.HasPrefix(name, "o1") {
return 4
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-3") {
return 5
}
@@ -330,6 +485,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "deepseek-") {
return 2
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.64
@@ -343,6 +499,37 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 3
case "command-r-plus":
return 5
case "grok-beta":
return 3
// Replicate Models
// https://replicate.com/pricing
case "ibm-granite/granite-20b-code-instruct-8k":
return 5
case "ibm-granite/granite-3.0-2b-instruct":
return 8.333333333333334
case "ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k":
return 5
case "meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct":
return 5
case "meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct":
return 2.750 / 0.650 // ≈4.230769
case "meta/meta-llama-3.1-405b-instruct":
return 1
case "mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1":
return 5
case "mistralai/mixtral-8x7b-instruct-v0.1":
return 1.000 / 0.300 // ≈3.333333
}
return 1
}

View File

@@ -44,5 +44,9 @@ const (
Doubao
Novita
VertextAI
Proxy
SiliconFlow
XAI
Replicate
Dummy
)

View File

@@ -37,6 +37,10 @@ func ToAPIType(channelType int) int {
apiType = apitype.DeepL
case VertextAI:
apiType = apitype.VertexAI
case Replicate:
apiType = apitype.Replicate
case Proxy:
apiType = apitype.Proxy
}
return apiType

View File

@@ -44,6 +44,10 @@ var ChannelBaseURLs = []string{
"https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
"", // 42
"", // 43
"https://api.siliconflow.cn", // 44
"https://api.x.ai", // 45
"https://api.replicate.com/v1/models/", // 46
}
func init() {

View File

@@ -1,5 +1,6 @@
package role
const (
System = "system"
Assistant = "assistant"
)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/songquanpeng/one-api/relay/constant/role"
"math"
"net/http"
"strings"
@@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
@@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
if err != nil {
logger.Error(ctx, "error update user quota cache: "+err.Error())
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
var extraLog string
if systemPromptReset {
extraLog = " (注意系统提示词已被重置)"
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
@@ -142,15 +147,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
}
return true
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode != http.StatusOK &&
// replicate return 201 to create a task
resp.StatusCode != http.StatusCreated {
return true
}
if meta.ChannelType == channeltype.DeepL {
// skip stream check for deepl
return false
}
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
// Even if stream mode is enabled, replicate will first return a task info in JSON format,
// requiring the client to request the stream endpoint in the task info
meta.ChannelType != channeltype.Replicate {
return true
}
return false
}
func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
if prompt == "" {
return false
}
if len(request.Messages) == 0 {
return false
}
if request.Messages[0].Role == role.System {
request.Messages[0].Content = prompt
logger.Infof(ctx, "rewrite system prompt")
return true
}
request.Messages = append([]relaymodel.Message{{
Role: role.System,
Content: prompt,
}}, request.Messages...)
logger.Infof(ctx, "add system prompt")
return true
}

View File

@@ -22,7 +22,7 @@ import (
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
return 1
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
adaptor.Init(meta)
// these adaptors need to convert the request
switch meta.ChannelType {
case channeltype.Ali:
fallthrough
case channeltype.Baidu:
fallthrough
case channeltype.Zhipu:
case channeltype.Zhipu,
channeltype.Ali,
channeltype.Replicate,
channeltype.Baidu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
var quota int64
switch meta.ChannelType {
case channeltype.Replicate:
// replicate always return 1 image
quota = int64(ratio * imageCostRatio * 1000)
default:
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
}
if userQuota-quota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
defer func(ctx context.Context) {
if resp != nil && resp.StatusCode != http.StatusOK {
if resp != nil &&
resp.StatusCode != http.StatusCreated && // replicate returns 201
resp.StatusCode != http.StatusOK {
return
}

41
relay/controller/proxy.go Normal file
View File

@@ -0,0 +1,41 @@
// Package controller is a package for handling the relay controller
package controller
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// RelayProxyHelper is a helper function to proxy the request to the upstream service
func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
resp, err := adaptor.DoRequest(c, meta, c.Request.Body)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
return nil
}

View File

@@ -4,12 +4,14 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
@@ -31,10 +33,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// set system prompt if not empty
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -55,31 +58,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta)
// get request body
var requestBody io.Reader
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
}
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
@@ -100,6 +82,29 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return respErr
}
// post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
return nil
}
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}

View File

@@ -18,6 +18,7 @@ type Meta struct {
UserId int
Group string
ModelMapping map[string]string
// BaseURL is the proxy url set in the channel config
BaseURL string
APIKey string
APIType int
@@ -29,6 +30,7 @@ type Meta struct {
ActualModelName string
RequestURLPath string
PromptTokens int // only for DoResponse
SystemPrompt string
}
func GetByContext(c *gin.Context) *Meta {
@@ -45,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta {
BaseURL: c.GetString(ctxkey.BaseURL),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(),
SystemPrompt: c.GetString(ctxkey.SystemPrompt),
}
cfg, ok := c.Get(ctxkey.Config)
if ok {

View File

@@ -3,4 +3,5 @@ package model
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
ContentTypeInputAudio = "input_audio"
)

View File

@@ -2,33 +2,69 @@ package model
type ResponseFormat struct {
Type string `json:"type,omitempty"`
JsonSchema *JSONSchema `json:"json_schema,omitempty"`
}
type JSONSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema map[string]interface{} `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
type Audio struct {
Voice string `json:"voice,omitempty"`
Format string `json:"format,omitempty"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
type GeneralOpenAIRequest struct {
// https://platform.openai.com/docs/api-reference/chat/create
Messages []Message `json:"messages,omitempty"`
Model string `json:"model,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
Store *bool `json:"store,omitempty"`
Metadata any `json:"metadata,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Prediction any `json:"prediction,omitempty"`
Audio *Audio `json:"audio,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
ServiceTier *string `json:"service_tier,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
User string `json:"user,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
Functions any `json:"functions,omitempty"`
User string `json:"user,omitempty"`
Prompt any `json:"prompt,omitempty"`
// https://platform.openai.com/docs/api-reference/embeddings/create
Input any `json:"input,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Instruction string `json:"instruction,omitempty"`
// https://platform.openai.com/docs/api-reference/images/create
Prompt any `json:"prompt,omitempty"`
Quality *string `json:"quality,omitempty"`
Size string `json:"size,omitempty"`
Style *string `json:"style,omitempty"`
// Others
Instruction string `json:"instruction,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -11,4 +11,6 @@ const (
AudioSpeech
AudioTranscription
AudioTranslation
// Proxy is a special relay mode for proxying requests to custom upstream
Proxy
)

View File

@@ -24,6 +24,8 @@ func GetByPath(path string) int {
relayMode = AudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = AudioTranslation
} else if strings.HasPrefix(path, "/v1/oneapi/proxy") {
relayMode = Proxy
}
return relayMode
}

View File

@@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)

View File

@@ -9,6 +9,7 @@ import (
func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
router.Use(middleware.GzipDecodeMiddleware())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
@@ -19,6 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.Any("/oneapi/proxy/:channelid/*target", controller.Relay)
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay)

View File

@@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken';
const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
];
const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
];
function renderTimestamp(timestamp) {
@@ -60,7 +62,12 @@ const TokensTable = () => {
onOpenLink('next-mj');
}
},
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
];
const columns = [
@@ -177,6 +184,11 @@ const TokensTable = () => {
node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
onOpenLink('opencat', record.key);
}
},
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
]
}
@@ -382,6 +394,9 @@ const TokensTable = () => {
case 'next-mj':
url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
break;
case 'lobechat':
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
break;
default:
if (!chatLink) {
showError('管理员未设置聊天链接');

View File

@@ -1,10 +1,13 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
{ key: 33, text: 'AWS', value: 33, color: 'black' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
{ key: 28, text: 'Mistral AI', value: 28, color: 'orange' },
{ key: 41, text: 'Novita', value: 41, color: 'purple' },
{ key: 40, text: '字节跳动豆包', value: 40, color: 'blue' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
@@ -17,6 +20,18 @@ export const CHANNEL_OPTIONS = [
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 32, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 34, text: 'Coze', value: 34, color: 'blue' },
{ key: 35, text: 'Cohere', value: 35, color: 'blue' },
{ key: 36, text: 'DeepSeek', value: 36, color: 'black' },
{ key: 37, text: 'Cloudflare', value: 37, color: 'orange' },
{ key: 38, text: 'DeepL', value: 38, color: 'black' },
{ key: 39, text: 'together.ai', value: 39, color: 'blue' },
{ key: 42, text: 'VertexAI', value: 42, color: 'blue' },
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@@ -43,6 +43,7 @@ const EditChannel = (props) => {
base_url: '',
other: '',
model_mapping: '',
system_prompt: '',
models: [],
auto_ban: 1,
groups: ['default']
@@ -63,7 +64,7 @@ const EditChannel = (props) => {
let localModels = [];
switch (value) {
case 14:
localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"];
break;
case 11:
localModels = ['PaLM-2'];
@@ -78,7 +79,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0'];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -304,7 +305,7 @@ const EditChannel = (props) => {
width={isMobile() ? '100%' : 600}
>
<Spin spinning={loading}>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>类型</Typography.Text>
</div>
<Select
@@ -313,12 +314,12 @@ const EditChannel = (props) => {
optionList={CHANNEL_OPTIONS}
value={inputs.type}
onChange={value => handleInputChange('type', value)}
style={{width: '50%'}}
style={{ width: '50%' }}
/>
{
inputs.type === 3 && (
<>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Banner type={"warning"} description={
<>
注意<strong>模型部署名称必须和模型名称保持一致</strong> One API
@@ -329,7 +330,7 @@ const EditChannel = (props) => {
}>
</Banner>
</div>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>AZURE_OPENAI_ENDPOINT</Typography.Text>
</div>
<Input
@@ -342,7 +343,7 @@ const EditChannel = (props) => {
value={inputs.base_url}
autoComplete='new-password'
/>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>默认 API 版本</Typography.Text>
</div>
<Input
@@ -361,7 +362,7 @@ const EditChannel = (props) => {
{
inputs.type === 8 && (
<>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>Base URL</Typography.Text>
</div>
<Input
@@ -376,7 +377,7 @@ const EditChannel = (props) => {
</>
)
}
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>名称</Typography.Text>
</div>
<Input
@@ -389,7 +390,7 @@ const EditChannel = (props) => {
value={inputs.name}
autoComplete='new-password'
/>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>分组</Typography.Text>
</div>
<Select
@@ -410,7 +411,7 @@ const EditChannel = (props) => {
{
inputs.type === 18 && (
<>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>模型版本</Typography.Text>
</div>
<Input
@@ -428,7 +429,7 @@ const EditChannel = (props) => {
{
inputs.type === 21 && (
<>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>知识库 ID</Typography.Text>
</div>
<Input
@@ -444,7 +445,7 @@ const EditChannel = (props) => {
</>
)
}
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>模型</Typography.Text>
</div>
<Select
@@ -460,7 +461,7 @@ const EditChannel = (props) => {
autoComplete='new-password'
optionList={modelOptions}
/>
<div style={{lineHeight: '40px', marginBottom: '12px'}}>
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Space>
<Button type='primary' onClick={() => {
handleInputChange('models', basicModels);
@@ -483,7 +484,7 @@ const EditChannel = (props) => {
}}
/>
</div>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>模型重定向</Typography.Text>
</div>
<TextArea
@@ -496,6 +497,19 @@ const EditChannel = (props) => {
value={inputs.model_mapping}
autoComplete='new-password'
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>系统提示词</Typography.Text>
</div>
<TextArea
placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
name='system_prompt'
onChange={value => {
handleInputChange('system_prompt', value)
}}
autosize
value={inputs.system_prompt}
autoComplete='new-password'
/>
<Typography.Text style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
@@ -507,7 +521,7 @@ const EditChannel = (props) => {
}>
填入模板
</Typography.Text>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>密钥</Typography.Text>
</div>
{
@@ -521,7 +535,7 @@ const EditChannel = (props) => {
handleInputChange('key', value)
}}
value={inputs.key}
style={{minHeight: 150, fontFamily: 'JetBrains Mono, Consolas'}}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
:
@@ -537,7 +551,7 @@ const EditChannel = (props) => {
autoComplete='new-password'
/>
}
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>组织</Typography.Text>
</div>
<Input
@@ -549,7 +563,7 @@ const EditChannel = (props) => {
}}
value={inputs.openai_organization}
/>
<div style={{marginTop: 10, display: 'flex'}}>
<div style={{ marginTop: 10, display: 'flex' }}>
<Space>
<Checkbox
name='auto_ban'
@@ -568,7 +582,7 @@ const EditChannel = (props) => {
{
!isEdit && (
<div style={{marginTop: 10, display: 'flex'}}>
<div style={{ marginTop: 10, display: 'flex' }}>
<Space>
<Checkbox
checked={batch}
@@ -584,7 +598,7 @@ const EditChannel = (props) => {
{
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
<>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>代理</Typography.Text>
</div>
<Input
@@ -603,7 +617,7 @@ const EditChannel = (props) => {
{
inputs.type === 22 && (
<>
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>私有部署地址</Typography.Text>
</div>
<Input

Some files were not shown because too many files have changed in this diff Show More