Compare commits

...

18 Commits

Author SHA1 Message Date
Qiying Wang
f636c50c84 fix: duplicate [DONE] (#1629) 2024-07-09 22:43:59 +08:00
Qiying Wang
720fe2dfeb feat: refactor AwsClaude to Aws to support both llama3 and claude (#1601)
* feat: refactor AwsClaude to Aws to support both llama3 and claude

* fix: aws llama3 ratio
2024-07-06 13:19:41 +08:00
Jason
e090e76c86 feat: add Novita AI as model provider (#1609) 2024-07-06 13:16:46 +08:00
open source
6a941748f8 feat: add initial root access token (#1598)
Signed-off-by: xiaobo <peterwillcn@gmail.com>
2024-07-06 13:15:17 +08:00
open source
46a0773580 fix: update readme docs (#1599)
Signed-off-by: xiaobo <peterwillcn@gmail.com>
2024-07-06 13:14:32 +08:00
zijiren
ffdb0b0c81 fix: use musl libc (#1597) 2024-07-06 13:14:07 +08:00
zijiren
efd30a40b3 feat: cloudflare support native openai api (#1596) 2024-07-06 13:12:30 +08:00
Qiying Wang
d7a78f3397 feat: support test specific model (#1600) 2024-07-05 18:05:16 +08:00
Leo Q
273be55797 feat(ui): show available models for air theme (#1595)
* feat(ui): air 主题显示可用模型

* chore: 改为全角括号
2024-07-04 08:35:41 +08:00
Leo Q
ec6ad24810 feat: support smtp without auth (#1101) 2024-07-03 22:23:49 +08:00
LinZeliang
c4fe57c165 feat: support one or more log file (#1400)
Co-authored-by: Laisky.Cai <github@laisky.com>
2024-07-03 20:53:29 +08:00
igophper
274fcf3d76 refactor: init db (#1590)
Co-authored-by: 江杭辉 <jianghanghui@k.app>
2024-07-03 20:50:40 +08:00
Mikey
0fc07ea558 feat: add support for Claude 3 tool use (function calling) (#1587)
* feat: add tool support for AWS & Claude

* fix: add {} for openai compatibility in streaming tool_use
2024-07-02 00:12:01 +08:00
Leo Q
1ce1e529ee ci: skip archive, upload directly (#1586) 2024-07-02 00:05:47 +08:00
Darkside
d936817de9 docs: add related projects (#1562)
Co-authored-by: 成达 <chengda.615@bytedance.com>
2024-06-30 19:57:30 +08:00
igophper
fecaece71b fix: fix size not support during image generation (#1564)
Fixes #1224, #1068
2024-06-30 19:52:33 +08:00
Shi Jilin
c135d74f13 feat: support Spark4.0 Ultra (#1575)
* fix: fix SparkDesk Function Call (修复 Spark Pro/Max函数调用只会返回普通对话回答而不是Function Call回答的问题

* feat: support Spark4.0 Ultra
2024-06-30 19:38:02 +08:00
lihangfu
d0369b114f feat: support spark4.0 ultra (#1569)
* feat: 支持v3最新协议的腾讯混元(#1452)

* feat: 支持Spark4.0 Ultra

---------

Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-06-30 19:37:07 +08:00
50 changed files with 1318 additions and 466 deletions

View File

@@ -36,21 +36,9 @@ jobs:
# in the next step as well as the next job.
- name: Test
run: go test -cover -coverprofile=coverage.txt ./...
- name: Archive code coverage results
uses: actions/upload-artifact@v4
- uses: codecov/codecov-action@v4
with:
name: code-coverage
path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step
code_coverage:
name: "Code coverage report"
runs-on: ubuntu-latest
needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job
steps:
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
token: ${{ secrets.CODECOV_TOKEN }}
commit_lint:
runs-on: ubuntu-latest

View File

@@ -16,7 +16,9 @@ WORKDIR /web/air
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2
FROM golang:alpine AS builder2
RUN apk add --no-cache g++
ENV GO111MODULE=on \
CGO_ENABLED=1 \
@@ -27,7 +29,7 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine

View File

@@ -101,7 +101,7 @@ Nginx reference configuration:
```
server{
server_name openai.justsong.cn; # Modify your domain name accordingly
location / {
client_max_body_size 64m;
proxy_http_version 1.1;
@@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`.
1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source:
```shell
git clone https://github.com/songquanpeng/one-api.git
# Build the frontend
cd one-api/web/default
npm install
npm run build
# Build the backend
cd ../..
go mod download
@@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'.
7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60`
7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5`
12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'.
+Example: ` BATCH_UPDATE_ENABLED=true`
+If you encounter an issue with too many database connections, you can try enabling this option.
13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'.
+Example: ` BATCH_UPDATE_INTERVAL=5`
14. Request frequency limit:
+ `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180.
+ `GLOBAL_WEL_RATE_LIMIT`: Global web speed limit, the maximum number of requests within three minutes per IP, default to 60.
15. Encoder cache settings:
+`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the encoding of some common word elements online, such as' gpt-3.5 turbo '. In some unstable network environments or offline situations, it may cause startup problems. This directory can be configured to cache data and can be migrated to an offline environment.
+`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as' TIKTOKEN-CACHE-DIR ', but its priority is not as high as it.
16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout time set.
17. `RELAY_PROXY`: After setting up, use this proxy to request APIs.
18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds.
19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this agent to request content uploaded by users, such as images.
20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'.
21. `GEMINI_SAFETY_SETTING`: Gemini's security settings are set to 'BLOCK-NONE' by default.
22. `GEMINI_VERSION`: The Gemini version used by the One API, which defaults to 'v1'.
23. `THE`: The system's theme setting, default to 'default', specific optional values refer to [here] (./web/README. md).
24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'.
25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'.
26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'.
27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time.
28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time.
### Command Line Parameters
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.
@@ -287,7 +312,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Double-check that your interface address and API Key are correct.
## Related Projects
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller.
* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization.
## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.

View File

@@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
> [!NOTE]
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
>
>
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> [!WARNING]
@@ -88,6 +88,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+ [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -144,7 +145,7 @@ Nginx 的参考配置:
```
server{
server_name openai.justsong.cn; # 请根据实际情况修改你的域名
location / {
client_max_body_size 64m;
proxy_http_version 1.1;
@@ -189,12 +190,12 @@ docker-compose ps
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell
git clone https://github.com/songquanpeng/one-api.git
# 构建前端
cd one-api/web/default
npm install
npm run build
# 构建后端
cd ../..
go mod download
@@ -321,7 +322,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashbo
例如对于 OpenAI 的官方库:
```bash
OPENAI_API_KEY="sk-xxxxxx"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
```
```mermaid
@@ -369,33 +370,34 @@ graph LR
+ 例子:`NODE_TYPE=slave`
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
11. 例子:`CHANNEL_TEST_FREQUENCY=1440`
12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+例子:`CHANNEL_TEST_FREQUENCY=1440`
11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
15. 请求频率限制:
14. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
16. 编码器缓存设置:
15. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
18. `RELAY_PROXY`:设置后使用该代理来请求 API。
19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
21. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
22. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
23. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`。
24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
17. `RELAY_PROXY`:设置后使用该代理来请求 API。
18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。
19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。
20. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
21. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
22. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`。
23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -448,6 +450,8 @@ https://openai.justsong.cn
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。
* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。
## 注意

View File

@@ -143,8 +143,13 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

View File

@@ -27,7 +27,12 @@ var setupLogOnce sync.Once
func SetupLogger() {
setupLogOnce.Do(func() {
if LogDir != "" {
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
var logPath string
if config.OnlyOneLogFile {
logPath = filepath.Join(LogDir, "oneapi.log")
} else {
logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
}
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")

View File

@@ -6,11 +6,16 @@ import (
"encoding/base64"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"net"
"net/smtp"
"strings"
"time"
)
func shouldAuth() bool {
return config.SMTPAccount != "" || config.SMTPToken != ""
}
func SendEmail(subject string, receiver string, content string) error {
if receiver == "" {
return fmt.Errorf("receiver is empty")
@@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error {
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: config.SMTPServer,
if config.SMTPPort == 465 || !shouldAuth() {
// need advanced client
var conn net.Conn
var err error
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: config.SMTPServer,
}
conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
} else {
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
@@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error {
return err
}
defer client.Close()
if err = client.Auth(auth); err != nil {
return err
if shouldAuth() {
if err = client.Auth(auth); err != nil {
return err
}
}
if err = client.Mail(config.SMTPFrom); err != nil {
return err

View File

@@ -14,6 +14,7 @@ import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
@@ -27,15 +28,15 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
)
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
if model == "" {
model = "gpt-3.5-turbo"
}
testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2,
Stream: false,
Model: "gpt-3.5-turbo",
Model: model,
}
testMessage := relaymodel.Message{
Role: "user",
@@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
return testRequest
}
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
@@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
var modelName string
modelList := adaptor.GetModelList()
modelName := request.Model
modelMap := channel.GetModelMapping()
if len(modelList) != 0 {
modelName = modelList[0]
}
if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
@@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
modelName = modelMap[modelName]
}
}
request := buildTestRequest()
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return err, nil
@@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) {
})
return
}
model := c.Query("model")
testRequest := buildTestRequest(model)
tik := time.Now()
err, _ = testChannel(channel)
err, _ = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
milliseconds = 0
}
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
@@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) {
"success": false,
"message": err.Error(),
"time": consumedTime,
"model": model,
})
return
}
@@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) {
"success": true,
"message": "",
"time": consumedTime,
"model": model,
})
return
}
@@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel)
testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else {

2
go.mod
View File

@@ -68,7 +68,7 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect

4
go.sum
View File

@@ -110,8 +110,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

22
main.go
View File

@@ -27,27 +27,19 @@ func main() {
common.Init()
logger.SetupLogger()
logger.SysLogf("One API %s started", common.Version)
if os.Getenv("GIN_MODE") != "debug" {
if os.Getenv("GIN_MODE") != gin.DebugMode {
gin.SetMode(gin.ReleaseMode)
}
if config.DebugEnabled {
logger.SysLog("running in debug mode")
}
var err error
// Initialize SQL Database
model.DB, err = model.InitDB("SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize database: " + err.Error())
}
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
model.InitDB()
model.InitLogDB()
var err error
err = model.CreateRootAccountIfNeed()
if err != nil {
logger.FatalLog("database init error: " + err.Error())

View File

@@ -1,6 +1,7 @@
package model
import (
"database/sql"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
@@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error {
if err != nil {
return err
}
accessToken := random.GetUUID()
if config.InitialRootAccessToken != "" {
accessToken = config.InitialRootAccessToken
}
rootUser := User{
Username: "root",
Password: hashedPassword,
Role: RoleRootUser,
Status: UserStatusEnabled,
DisplayName: "Root User",
AccessToken: random.GetUUID(),
AccessToken: accessToken,
Quota: 500000000000000,
}
DB.Create(&rootUser)
@@ -60,90 +65,156 @@ func CreateRootAccountIfNeed() error {
}
func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
dsn := os.Getenv(envName)
switch {
case strings.HasPrefix(dsn, "postgres://"):
// Use PostgreSQL
return openPostgreSQL(dsn)
case dsn != "":
// Use MySQL
logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
return openMySQL(dsn)
default:
// Use SQLite
return openSQLite()
}
// Use SQLite
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
}
func openPostgreSQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName)
if err == nil {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
func openMySQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
if !config.IsMasterNode {
return db, err
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return nil, err
}
logger.SysLog("database migrated")
return db, err
} else {
logger.FatalLog(err)
func openSQLite() (*gorm.DB, error) {
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
func InitDB() {
var err error
DB, err = chooseDB("SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize database: " + err.Error())
return
}
return db, err
sqlDB := setDBConns(DB)
if !config.IsMasterNode {
return
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
if err = migrateDB(); err != nil {
logger.FatalLog("failed to migrate database: " + err.Error())
return
}
logger.SysLog("database migrated")
}
func migrateDB() error {
var err error
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Token{}); err != nil {
return err
}
if err = DB.AutoMigrate(&User{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Option{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Redemption{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Ability{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Log{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
return nil
}
func InitLogDB() {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
logger.SysLog("using secondary database for table logs")
var err error
LOG_DB, err = chooseDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
return
}
setDBConns(LOG_DB)
if !config.IsMasterNode {
return
}
logger.SysLog("secondary database migration started")
err = migrateLOGDB()
if err != nil {
logger.FatalLog("failed to migrate secondary database: " + err.Error())
return
}
logger.SysLog("secondary database migrated")
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
logger.FatalLog("failed to connect database: " + err.Error())
return nil
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
return sqlDB
}
func closeDB(db *gorm.DB) error {

View File

@@ -29,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
return "stop"
case "max_tokens":
return "length"
case "tool_use":
return "tool_calls"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := Request{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
@@ -42,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
Tools: claudeTools,
}
if len(claudeTools) > 0 {
claudeToolChoice := struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
if function, ok := choice["function"].(map[string]any); ok {
claudeToolChoice.Type = "tool"
claudeToolChoice.Name = function["name"].(string)
}
} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
if toolChoiceType == "any" {
claudeToolChoice.Type = toolChoiceType
}
}
claudeRequest.ToolChoice = claudeToolChoice
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
@@ -64,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
if message.Role == "tool" {
claudeMessage.Role = "user"
content.Type = "tool_result"
content.Content = content.Text
content.Text = ""
content.ToolUseId = message.ToolCallId
}
claudeMessage.Content = append(claudeMessage.Content, content)
for i := range message.ToolCalls {
inputParam := make(map[string]any)
_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
claudeMessage.Content = append(claudeMessage.Content, Content{
Type: "tool_use",
Id: message.ToolCalls[i].Id,
Name: message.ToolCalls[i].Function.Name,
Input: inputParam,
})
}
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
@@ -97,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
var response *Response
var responseText string
var stopReason string
tools := make([]model.Tool, 0)
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, model.Tool{
Id: claudeResponse.ContentBlock.Id,
Type: "function",
Function: model.Function{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, model.Tool{
Function: model.Function{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
}
case "message_delta":
if claudeResponse.Usage != nil {
@@ -120,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
@@ -136,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]model.Tool, 0)
for _, v := range claudeResponse.Content {
if v.Type == "tool_use" {
args, _ := json.Marshal(v.Input)
tools = append(tools, model.Tool{
Id: v.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: model.Function{
Name: v.Name,
Arguments: string(args),
},
})
}
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
Role: "assistant",
Content: responseText,
Name: nil,
ToolCalls: tools,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
@@ -176,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
var usage model.Usage
var modelName string
var id string
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
for scanner.Scan() {
data := scanner.Text()
@@ -196,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
continue
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
continue
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
}
if response == nil {
continue
@@ -207,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Id = id
response.Model = modelName
response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())

View File

@@ -16,6 +16,12 @@ type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
}
type Message struct {
@@ -23,6 +29,18 @@ type Message struct {
Content []Content `json:"content"`
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -33,6 +51,8 @@ type Request struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
//Metadata `json:"metadata,omitempty"`
}
@@ -61,6 +81,7 @@ type Response struct {
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
PartialJson string `json:"partial_json,omitempty"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}

View File

@@ -1,17 +1,16 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -19,18 +18,52 @@ import (
var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct {
meta *meta.Meta
awsClient *bedrockruntime.Client
awsAdapter utils.AwsAdapter
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
a.awsClient = bedrockruntime.New(bedrockruntime.Options{
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
a.awsAdapter = adaptor
return adaptor.ConvertRequest(c, relayMode, request)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if a.awsAdapter == nil {
return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
}
return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}
func (a *Adaptor) GetModelList() (models []string) {
for model := range adaptors {
models = append(models, model)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
@@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
@@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, a.awsClient)
} else {
err, usage = Handler(c, a.awsClient, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}

View File

@@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"net/http"
@@ -16,23 +15,17 @@ import (
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{
var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
@@ -43,7 +36,7 @@ var awsModelIDMap = map[string]string{
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
@@ -53,7 +46,7 @@ func awsModelID(requestModel string) (string, error) {
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
@@ -64,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
return utils.WrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
@@ -107,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
@@ -118,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
return utils.WrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
@@ -126,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
@@ -143,6 +136,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
@@ -163,8 +158,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
}
if response == nil {
return true
@@ -172,6 +178,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
response.Id = id
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())

View File

@@ -9,9 +9,12 @@ type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@@ -0,0 +1,37 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
llamaReq := ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, llamaReq)
return llamaReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@@ -0,0 +1,231 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"text/template"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/random"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// Only support llama-3-8b and llama-3-70b instruction models
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
"llama3-8b-8192": "meta.llama3-8b-instruct-v1:0",
"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
// promptTemplate with range
const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|>
`
var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate))
func RenderPrompt(messages []relaymodel.Message) string {
var buf bytes.Buffer
err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages})
if err != nil {
logger.SysError("error rendering prompt messages: " + err.Error())
}
return buf.String()
}
func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
llamaRequest := Request{
MaxGenLen: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
}
if llamaRequest.MaxGenLen == 0 {
llamaRequest.MaxGenLen = 2048
}
prompt := RenderPrompt(textRequest.Messages)
llamaRequest.Prompt = prompt
return &llamaRequest
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
var llamaResponse Response
err = json.Unmarshal(awsResp.Body, &llamaResponse)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := ResponseLlama2OpenAI(&llamaResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: llamaResponse.PromptTokenCount,
CompletionTokens: llamaResponse.GenerationTokenCount,
TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
var responseText string
if len(llamaResponse.Generation) > 0 {
responseText = llamaResponse.Generation
}
choice := openai.TextResponseChoice{
Index: 0,
Message: relaymodel.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: llamaResponse.StopReason,
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
var llamaResp StreamResponse
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
if llamaResp.PromptTokenCount > 0 {
usage.PromptTokens = llamaResp.PromptTokenCount
}
if llamaResp.StopReason == "stop" {
usage.CompletionTokens = llamaResp.GenerationTokenCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
response := StreamResponseLlama2OpenAI(&llamaResp)
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}
func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = llamaResponse.Generation
choice.Delta.Role = "assistant"
finishReason := llamaResponse.StopReason
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse
}

View File

@@ -0,0 +1,45 @@
package aws_test
import (
"testing"
aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/stretchr/testify/assert"
)
func TestRenderPrompt(t *testing.T) {
messages := []relaymodel.Message{
{
Role: "user",
Content: "What's your name?",
},
}
prompt := aws.RenderPrompt(messages)
expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
messages = []relaymodel.Message{
{
Role: "system",
Content: "Your name is Kat. You are a detective.",
},
{
Role: "user",
Content: "What's your name?",
},
{
Role: "assistant",
Content: "Kat",
},
{
Role: "user",
Content: "What's your job?",
},
}
prompt = aws.RenderPrompt(messages)
expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
}

View File

@@ -0,0 +1,29 @@
package aws
// Request is the request to AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
}
// Response is the response from AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Response struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}
// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
type StreamResponse struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}

View File

@@ -0,0 +1,39 @@
package aws
import (
claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)
type AwsModelType int
const (
AwsClaude AwsModelType = iota + 1
AwsLlama3
)
var (
adaptors = map[string]AwsModelType{}
)
func init() {
for model := range claude.AwsModelIDMap {
adaptors[model] = AwsClaude
}
for model := range llama3.AwsModelIDMap {
adaptors[model] = AwsLlama3
}
}
func GetAdaptor(model string) utils.AwsAdapter {
adaptorType := adaptors[model]
switch adaptorType {
case AwsClaude:
return &claude.Adaptor{}
case AwsLlama3:
return &llama3.Adaptor{}
default:
return nil
}
}

View File

@@ -0,0 +1,51 @@
package utils
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type AwsAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
type Adaptor struct {
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}

View File

@@ -0,0 +1,16 @@
package utils
import (
"net/http"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: err.Error(),
},
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
@@ -28,7 +29,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
switch relayMode {
case relaymode.Completions:
return ConvertCompletionsRequest(*request), nil
case relaymode.ChatCompletions, relaymode.Embeddings:
return request, nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {

View File

@@ -3,11 +3,13 @@ package cloudflare
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
@@ -16,57 +18,23 @@ import (
"github.com/songquanpeng/one-api/relay/model"
)
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
var promptBuilder strings.Builder
for _, message := range textRequest.Messages {
promptBuilder.WriteString(message.StringContent())
promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
}
func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
p, _ := textRequest.Prompt.(string)
return &Request{
Prompt: p,
MaxTokens: textRequest.MaxTokens,
Prompt: promptBuilder.String(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cloudflareResponse.Result.Response,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = cloudflareResponse.Response
choice.Delta.Role = "assistant"
openaiResponse := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
Created: helper.GetTimestamp(),
}
return &openaiResponse
}
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
responseModel := c.GetString("original_model")
responseModel := c.GetString(ctxkey.OriginalModel)
var responseText string
for scanner.Scan() {
@@ -77,22 +45,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
var cloudflareResponse StreamResponse
err := json.Unmarshal([]byte(data), &cloudflareResponse)
if data == "[DONE]" {
break
}
var response openai.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
if response == nil {
continue
for _, v := range response.Choices {
v.Delta.Role = "assistant"
responseText += v.Delta.StringContent()
}
responseText += cloudflareResponse.Response
response.Id = id
response.Model = responseModel
response.Model = modelName
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
@@ -123,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cloudflareResponse Response
err = json.Unmarshal(responseBody, &cloudflareResponse)
var response openai.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse)
fullTextResponse.Model = modelName
usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens)
fullTextResponse.Usage = *usage
fullTextResponse.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(fullTextResponse)
response.Model = modelName
var responseText string
for _, v := range response.Choices {
responseText += v.Message.Content.(string)
}
usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -1,25 +1,13 @@
package cloudflare
import "github.com/songquanpeng/one-api/relay/model"
type Request struct {
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
type Result struct {
Response string `json:"response"`
}
type Response struct {
Result Result `json:"result"`
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
}
type StreamResponse struct {
Response string `json:"response"`
Messages []model.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}

View File

@@ -0,0 +1,19 @@
package novita
// https://novita.ai/llm-api
var ModelList = []string{
"meta-llama/llama-3-8b-instruct",
"meta-llama/llama-3-70b-instruct",
"nousresearch/hermes-2-pro-llama-3-8b",
"nousresearch/nous-hermes-llama2-13b",
"mistralai/mistral-7b-instruct",
"cognitivecomputations/dolphin-mixtral-8x22b",
"sao10k/l3-70b-euryale-v2.1",
"sophosympatheia/midnight-rose-70b",
"gryphe/mythomax-l2-13b",
"Nous-Hermes-2-Mixtral-8x7B-DPO",
"lzlv_70b",
"teknium/openhermes-2.5-mistral-7b",
"microsoft/wizardlm-2-8x22b",
}

View File

@@ -0,0 +1,15 @@
package novita
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
}

View File

@@ -3,17 +3,19 @@ package openai
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
type Adaptor struct {
@@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return minimax.GetRequestURL(meta)
case channeltype.Doubao:
return doubao.GetRequestURL(meta)
case channeltype.Novita:
return novita.GetRequestURL(meta)
default:
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/channeltype"
@@ -28,6 +29,7 @@ var CompatibleChannels = []int{
channeltype.StepFun,
channeltype.DeepSeek,
channeltype.TogetherAI,
channeltype.Novita,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -56,6 +58,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "together.ai", togetherai.ModelList
case channeltype.Doubao:
return "doubao", doubao.ModelList
case channeltype.Novita:
return "novita", novita.ModelList
default:
return "openai", ModelList
}

View File

@@ -4,11 +4,12 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
@@ -31,6 +32,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() {
data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format
@@ -41,6 +43,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data)
doneRendered = true
continue
}
switch relayMode {
@@ -81,7 +84,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
if !doneRendered {
render.Done(c)
}
err := resp.Body.Close()
if err != nil {

View File

@@ -6,4 +6,5 @@ var ModelList = []string{
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.5",
"SparkDesk-v4.0",
}

View File

@@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
if strings.HasPrefix(domain, "generalv3") {
if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
functions := make([]model.Function, len(request.Tools))
for i, tool := range request.Tools {
functions[i] = tool.Function
@@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3"
case "v3.5":
return "generalv3.5"
case "v4.0":
return "4.0Ultra"
}
return "general" + apiVersion
}

View File

@@ -2,6 +2,7 @@ package ratio
import (
"encoding/json"
"fmt"
"strings"
"github.com/songquanpeng/one-api/common/logger"
@@ -125,6 +126,7 @@ var ModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
@@ -168,6 +170,9 @@ var ModelRatio = map[string]float64{
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
// https://cohere.com/pricing
"command": 0.5,
"command-nightly": 0.5,
@@ -184,7 +189,11 @@ var ModelRatio = map[string]float64{
"deepl-ja": 25.0 / 1000 * USD,
}
var CompletionRatio = map[string]float64{}
var CompletionRatio = map[string]float64{
// aws llama3
"llama3-8b-8192(33)": 0.0006 / 0.0003,
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
@@ -233,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}
func GetModelRatio(name string) float64 {
func GetModelRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
ratio, ok = DefaultModelRatio[name]
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := ModelRatio[model]; ok {
return ratio
}
if !ok {
logger.SysError("model ratio not found: " + name)
return 30
if ratio, ok := DefaultModelRatio[model]; ok {
return ratio
}
return ratio
if ratio, ok := ModelRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[name]; ok {
return ratio
}
logger.SysError("model ratio not found: " + name)
return 30
}
func CompletionRatio2JSONString() string {
@@ -264,7 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := CompletionRatio[model]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[model]; ok {
return ratio
}
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}

View File

@@ -42,5 +42,6 @@ const (
DeepL
TogetherAI
Doubao
Novita
Dummy
)

View File

@@ -42,6 +42,7 @@ var ChannelBaseURLs = []string{
"https://api-free.deepl.com", // 38
"https://api.together.xyz", // 39
"https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
}
func init() {

View File

@@ -7,6 +7,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
@@ -21,9 +25,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := billingratio.GetModelRatio(audioModel)
modelRatio := billingratio.GetModelRatio(audioModel, channelType)
groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio
var quota int64

View File

@@ -4,6 +4,10 @@ import (
"context"
"errors"
"fmt"
"math"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
@@ -16,9 +20,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
"strings"
)
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
@@ -40,78 +41,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil
}
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// model validation
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
}
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
case relaymode.ChatCompletions:
@@ -167,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
return
}
var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))

View File

@@ -6,7 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
@@ -16,17 +20,86 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func isWithinRange(element string, value int) bool {
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
return false
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
min := billingratio.ImageGenerationAmounts[element][0]
max := billingratio.ImageGenerationAmounts[element][1]
return value >= min && value <= max
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func isValidImagePromptLength(model string, promptLength int) bool {
maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
return !ok || promptLength <= maxPromptLength
}
func isWithinRange(element string, value int) bool {
amounts, ok := billingratio.ImageGenerationAmounts[element]
return !ok || (value >= amounts[0] && value <= amounts[1])
}
func getImageSizeRatio(model string, size string) float64 {
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
return ratio
}
return 1
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// model validation
if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -94,7 +167,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billingratio.GetModelRatio(imageModel)
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

View File

@@ -4,6 +4,9 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
@@ -14,8 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
@@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model)
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
// pre-consume quota

View File

@@ -1,10 +1,11 @@
package model
type Message struct {
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"`
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
func (m Message) IsStringContent() bool {

View File

@@ -2,13 +2,13 @@ package model
type Tool struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
Function Function `json:"function"`
}
type Function struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty
Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response
}

View File

@@ -47,7 +47,7 @@ const PersonalSetting = () => {
const [countdown, setCountdown] = useState(30);
const [affLink, setAffLink] = useState('');
const [systemToken, setSystemToken] = useState('');
// const [models, setModels] = useState([]);
const [models, setModels] = useState([]);
const [openTransfer, setOpenTransfer] = useState(false);
const [transferAmount, setTransferAmount] = useState(0);
@@ -72,7 +72,7 @@ const PersonalSetting = () => {
console.log(userState);
}
);
// loadModels().then();
loadModels().then();
getAffLink().then();
setTransferAmount(getQuotaPerUnit());
}, []);
@@ -127,16 +127,16 @@ const PersonalSetting = () => {
}
};
// const loadModels = async () => {
// let res = await API.get(`/api/user/models`);
// const { success, message, data } = res.data;
// if (success) {
// setModels(data);
// console.log(data);
// } else {
// showError(message);
// }
// };
const loadModels = async () => {
let res = await API.get(`/api/user/available_models`);
const { success, message, data } = res.data;
if (success) {
setModels(data);
console.log(data);
} else {
showError(message);
}
};
const handleAffLinkClick = async (e) => {
e.target.select();
@@ -344,7 +344,7 @@ const PersonalSetting = () => {
}
>
<Typography.Title heading={6}>调用信息</Typography.Title>
{/* <Typography.Title heading={6}>可用模型</Typography.Title>
<p>可用模型可点击复制</p>
<div style={{ marginTop: 10 }}>
<Space wrap>
{models.map((model) => (
@@ -355,7 +355,7 @@ const PersonalSetting = () => {
</Tag>
))}
</Space>
</div> */}
</div>
</Card>
{/* <Card
footer={

View File

@@ -78,7 +78,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'];
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];

View File

@@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = {
},
33: {
key: 33,
text: 'AWS Claude',
text: 'AWS',
value: 33,
color: 'primary'
},
@@ -161,6 +161,12 @@ export const CHANNEL_OPTIONS = {
value: 39,
color: 'primary'
},
41: {
key: 41,
text: 'Novita',
value: 41,
color: 'purple'
},
8: {
key: 8,
text: '自定义渠道',

View File

@@ -91,7 +91,7 @@ const typeConfig = {
other: '版本号'
},
input: {
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
},
prompt: {
key: '按照如下格式输入APPID|APISecret|APIKey',

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import {
API,
@@ -70,13 +70,33 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/?p=${startIdx}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setChannels(data);
} else {
let newChannels = [...channels];
newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
setChannels(newChannels);
}
let localChannels = data.map((channel) => {
if (channel.models === '') {
channel.models = [];
channel.test_model = "";
} else {
channel.models = channel.models.split(',');
if (channel.models.length > 0) {
channel.test_model = channel.models[0];
}
channel.model_options = channel.models.map((model) => {
return {
key: model,
text: model,
value: model,
}
})
console.log('channel', channel)
}
return channel;
});
if (startIdx === 0) {
setChannels(localChannels);
} else {
let newChannels = [...channels];
newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels);
setChannels(newChannels);
}
} else {
showError(message);
}
@@ -225,19 +245,31 @@ const ChannelsTable = () => {
setSearching(false);
};
const testChannel = async (id, name, idx) => {
const res = await API.get(`/api/channel/test/${id}/`);
const { success, message, time } = res.data;
const switchTestModel = async (idx, model) => {
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].test_model = model;
setChannels(newChannels);
};
const testChannel = async (id, name, idx, m) => {
const res = await API.get(`/api/channel/test/${id}?model=${m}`);
const { success, message, time, model } = res.data;
if (success) {
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`渠道 ${name} 测试成功,模型 ${model}耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
}
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
};
const testChannels = async (scope) => {
@@ -405,6 +437,7 @@ const ChannelsTable = () => {
>
优先级
</Table.HeaderCell>
<Table.HeaderCell>测试模型</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row>
</Table.Header>
@@ -459,13 +492,24 @@ const ChannelsTable = () => {
basic
/>
</Table.Cell>
<Table.Cell>
<Dropdown
placeholder='请选择测试模型'
selection
options={channel.model_options}
defaultValue={channel.test_model}
onChange={(event, data) => {
switchTestModel(idx, data.value);
}}
/>
</Table.Cell>
<Table.Cell>
<div>
<Button
size={'small'}
positive
onClick={() => {
testChannel(channel.id, channel.name, idx);
testChannel(channel.id, channel.name, idx, channel.test_model);
}}
>
测试

View File

@@ -1,11 +1,12 @@
export const CHANNEL_OPTIONS = [
{key: 1, text: 'OpenAI', value: 1, color: 'green'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'black'},
{key: 33, text: 'AWS Claude', value: 33, color: 'black'},
{key: 33, text: 'AWS', value: 33, color: 'black'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'},
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange'},
{key: 24, text: 'Google Gemini', value: 24, color: 'orange'},
{key: 28, text: 'Mistral AI', value: 28, color: 'orange'},
{key: 41, text: 'Novita', value: 41, color: 'purple'},
{key: 40, text: '字节跳动豆包', value: 40, color: 'blue'},
{key: 15, text: '百度文心千帆', value: 15, color: 'blue'},
{key: 17, text: '阿里通义千问', value: 17, color: 'orange'},