Compare commits


2 Commits

Author | SHA1 | Message | Date
Ian Li | 43368e68c8 | Merge 3462bd538c into c4fe57c165 | 2024-07-03 21:04:26 +08:00
Ian Li | 3462bd538c | feat: support LOGIN as SMTP authentication method. | 2024-03-16 16:29:10 +08:00
42 changed files with 302 additions and 892 deletions

View File

@@ -16,9 +16,7 @@ WORKDIR /web/air
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang:alpine AS builder2
RUN apk add --no-cache g++
ENV GO111MODULE=on \
CGO_ENABLED=1 \
@@ -29,7 +27,7 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /web/build ./web/build
RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine

View File

@@ -245,41 +245,16 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
6. `MEMORY_CACHE_ENABLED`: Enabling memory caching can cause a certain delay in updating user quotas. Optional values are `true` and `false`; if not set, it defaults to `false`.
7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60`
8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave`
9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5`
12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. Optional values are `true` and `false`; if not set, it defaults to `false`.
+ Example: `BATCH_UPDATE_ENABLED=true`
+ If you encounter an issue with too many database connections, you can try enabling this option.
13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch update aggregation, in seconds; defaults to `5`.
+ Example: `BATCH_UPDATE_INTERVAL=5`
14. Request rate limits:
+ `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests per IP within three minutes; defaults to `180`.
+ `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests per IP within three minutes; defaults to `60`.
15. Encoder cache settings:
+ `TIKTOKEN_CACHE_DIR`: By default, the program downloads the token encodings for some common models (such as `gpt-3.5-turbo`) over the network at startup. In unstable network environments or offline deployments, this can break startup; configure this directory to cache the data, which can then be migrated to an offline environment.
+ `DATA_GYM_CACHE_DIR`: Currently serves the same purpose as `TIKTOKEN_CACHE_DIR`, but with lower priority.
16. `RELAY_TIMEOUT`: Relay timeout, in seconds; no timeout is set by default.
17. `RELAY_PROXY`: When set, this proxy is used to request APIs.
18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to request user-uploaded content, such as images.
20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds; defaults to `3000`.
21. `GEMINI_SAFETY_SETTING`: Gemini safety setting; defaults to `BLOCK_NONE`.
22. `GEMINI_VERSION`: The Gemini version used by One API; defaults to `v1`.
23. `THEME`: The system's theme; defaults to `default`. See [here](./web/README.md) for the available values.
24. `ENABLE_METRIC`: Whether to disable channels based on request success rate; disabled by default. Optional values are `true` and `false`.
25. `METRIC_QUEUE_SIZE`: Size of the request success rate statistics queue; defaults to `10`.
26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold; defaults to `0.8`.
27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be created automatically when the system starts for the first time.
28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token with the value of the environment variable will be created automatically for the root user when the system starts for the first time.
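Most of these variables are read once at startup with a parse-or-default helper (the config hunk further down shows the pattern, e.g. `env.Int("METRIC_FAIL_CHAN_SIZE", 128)`). A minimal standalone sketch of that lookup, using only the standard library — `intEnv` here is an illustrative stand-in, not the project's actual `env` package:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// intEnv mirrors the env.Int helper pattern: return the parsed value
// when the variable is set and valid, otherwise fall back to the default.
func intEnv(key string, defaultValue int) int {
	if v, ok := os.LookupEnv(key); ok {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return defaultValue
}

func main() {
	// SYNC_FREQUENCY: sync period in seconds; 0 here stands for "no sync".
	syncFrequency := intEnv("SYNC_FREQUENCY", 0)
	// BATCH_UPDATE_INTERVAL: batch aggregation window in seconds, default 5.
	batchInterval := intEnv("BATCH_UPDATE_INTERVAL", 5)
	fmt.Println(syncFrequency, batchInterval)
}
```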
### Command Line Parameters
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.

View File

@@ -88,7 +88,6 @@ _✨ Access all large models through the standard OpenAI API format, out of the box ✨_
+ [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+ [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
2. Supports configuring mirrors and numerous [third-party proxy services](https://iamazing.cn/page/openai-api-third-party-services).
3. Supports accessing multiple channels through **load balancing**.
4. Supports **stream mode**, enabling a typewriter effect through streaming.
@@ -370,34 +369,33 @@ graph LR
+ Example: `NODE_TYPE=slave`
9. `CHANNEL_UPDATE_FREQUENCY`: When set, channel balances are updated periodically, with the unit in minutes. If not set, no update happens.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`: When set, channels are tested periodically, with the unit in minutes. If not set, no test happens.
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
11. `POLLING_INTERVAL`: The request interval, in seconds, when batch-updating channel balances and testing availability. Defaults to no interval.
+ Example: `POLLING_INTERVAL=5`
12. `BATCH_UPDATE_ENABLED`: Enables batch database update aggregation, which can cause a certain delay in updating user quotas. Optional values are `true` and `false`; if not set, it defaults to `false`.
+ Example: `BATCH_UPDATE_ENABLED=true`
+ If you run into too many database connections, try enabling this option.
13. `BATCH_UPDATE_INTERVAL=5`: The batch update aggregation interval, in seconds; defaults to `5`.
+ Example: `BATCH_UPDATE_INTERVAL=5`
14. Request rate limits:
+ `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests per IP within three minutes; defaults to `180`.
+ `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests per IP within three minutes; defaults to `60`.
15. Encoder cache settings:
+ `TIKTOKEN_CACHE_DIR`: By default, the program downloads the token encodings for some common models (such as `gpt-3.5-turbo`) over the network at startup. In unstable network environments or offline deployments, this can break startup; configure this directory to cache the data, which can then be migrated to an offline environment.
+ `DATA_GYM_CACHE_DIR`: Currently serves the same purpose as `TIKTOKEN_CACHE_DIR`, but with lower priority.
16. `RELAY_TIMEOUT`: Relay timeout, in seconds; no timeout is set by default.
17. `RELAY_PROXY`: When set, this proxy is used to request APIs.
18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to request user-uploaded content, such as images.
20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds; defaults to `3000`.
21. `GEMINI_SAFETY_SETTING`: Gemini safety setting; defaults to `BLOCK_NONE`.
22. `GEMINI_VERSION`: The Gemini version used by One API; defaults to `v1`.
23. `THEME`: The system's theme; defaults to `default`. See [here](./web/README.md) for the available values.
24. `ENABLE_METRIC`: Whether to disable channels based on request success rate; disabled by default. Optional values are `true` and `false`.
25. `METRIC_QUEUE_SIZE`: Size of the request success rate statistics queue; defaults to `10`.
26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold; defaults to `0.8`.
27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be created automatically when the system starts for the first time.
28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token with the value of the environment variable will be created automatically for the root user when the system starts for the first time.
### Command Line Parameters
1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`.

View File

@@ -63,6 +63,7 @@ var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var SMTPAuthLoginEnabled = false
var GitHubClientId = ""
var GitHubClientSecret = ""
@@ -143,8 +144,6 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1")

View File

@@ -4,16 +4,38 @@ import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"net"
"net/smtp"
"strings"
"time"
)
func shouldAuth() bool {
return config.SMTPAccount != "" || config.SMTPToken != ""
type loginAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &loginAuth{username, password}
}
func (a *loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte(a.username), nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown command from server during login auth")
}
}
return nil, nil
}
func SendEmail(subject string, receiver string, content string) error {
@@ -46,24 +68,21 @@ func SendEmail(subject string, receiver string, content string) error {
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
var auth smtp.Auth
if config.SMTPAuthLoginEnabled {
auth = LoginAuth(config.SMTPAccount, config.SMTPToken)
} else {
auth = smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
}
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
if config.SMTPPort == 465 || !shouldAuth() {
// need advanced client
var conn net.Conn
var err error
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: config.SMTPServer,
}
conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
} else {
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: config.SMTPServer,
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
@@ -72,10 +91,8 @@ func SendEmail(subject string, receiver string, content string) error {
return err
}
defer client.Close()
if shouldAuth() {
if err = client.Auth(auth); err != nil {
return err
}
if err = client.Auth(auth); err != nil {
return err
}
if err = client.Mail(config.SMTPFrom); err != nil {
return err

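SMTP AUTH LOGIN is a challenge-response dialogue: after the client announces `LOGIN`, the server prompts `Username:` and `Password:` in turn, and `net/smtp` base64-encodes each reply. A minimal, self-contained usage sketch of the adapter introduced in this hunk, with a hypothetical server and credentials:

```go
package main

import (
	"errors"
	"log"
	"net/smtp"
)

// loginAuth mirrors the adapter added above: it answers the server's
// "Username:" and "Password:" challenges during AUTH LOGIN.
type loginAuth struct{ username, password string }

func LoginAuth(username, password string) smtp.Auth {
	return &loginAuth{username, password}
}

func (a *loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
	return "LOGIN", []byte(a.username), nil
}

func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
	if !more {
		return nil, nil
	}
	switch string(fromServer) {
	case "Username:":
		return []byte(a.username), nil
	case "Password:":
		return []byte(a.password), nil
	}
	return nil, errors.New("unknown command from server during login auth")
}

func main() {
	// Hypothetical server and credentials, for illustration only.
	auth := LoginAuth("user@example.com", "app-password")
	msg := []byte("Subject: test\r\n\r\nhello via AUTH LOGIN\r\n")
	// smtp.SendMail drives Start/Next against the server's challenges.
	err := smtp.SendMail("smtp.example.com:587", auth, "user@example.com",
		[]string{"rcpt@example.com"}, msg)
	if err != nil {
		log.Fatal(err)
	}
}
```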
View File

@@ -14,7 +14,6 @@ import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
@@ -28,15 +27,15 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
)
func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
if model == "" {
model = "gpt-3.5-turbo"
}
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2,
Model: model,
Stream: false,
Model: "gpt-3.5-turbo",
}
testMessage := relaymodel.Message{
Role: "user",
@@ -46,7 +45,7 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
return testRequest
}
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
@@ -69,8 +68,12 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := request.Model
var modelName string
modelList := adaptor.GetModelList()
modelMap := channel.GetModelMapping()
if len(modelList) != 0 {
modelName = modelList[0]
}
if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
@@ -80,8 +83,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
modelName = modelMap[modelName]
}
}
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return err, nil
@@ -135,15 +139,10 @@ func TestChannel(c *gin.Context) {
})
return
}
model := c.Query("model")
testRequest := buildTestRequest(model)
tik := time.Now()
err, _ = testChannel(channel, testRequest)
err, _ = testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
milliseconds = 0
}
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
@@ -151,7 +150,6 @@ func TestChannel(c *gin.Context) {
"success": false,
"message": err.Error(),
"time": consumedTime,
"model": model,
})
return
}
@@ -159,7 +157,6 @@ func TestChannel(c *gin.Context) {
"success": true,
"message": "",
"time": consumedTime,
"model": model,
})
return
}
@@ -190,12 +187,11 @@ func testChannels(notify bool, scope string) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now()
testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest)
err, openaiErr := testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else {

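One side of this hunk picks the test model inside `testChannel`, falling back through the adaptor's model list, the channel's configured `Models` field, and the channel's model mapping. A condensed, self-contained sketch of that selection chain (simplified from the hunk; inputs are plain strings here):

```go
package main

import (
	"fmt"
	"strings"
)

// pickTestModel condenses the fallback chain from the hunk above:
// adaptor default -> first model configured on the channel -> mapped name.
func pickTestModel(adaptorModels []string, channelModels string, modelMap map[string]string) string {
	var name string
	if len(adaptorModels) != 0 {
		name = adaptorModels[0]
	}
	if name == "" || !strings.Contains(channelModels, name) {
		if parts := strings.Split(channelModels, ","); len(parts) > 0 {
			name = parts[0]
		}
	}
	if mapped, ok := modelMap[name]; ok && mapped != "" {
		name = mapped
	}
	return name
}

func main() {
	fmt.Println(pickTestModel([]string{"gpt-3.5-turbo"}, "gpt-4,gpt-4o",
		map[string]string{"gpt-4": "gpt-4-0613"}))
	// gpt-3.5-turbo is not in the channel list -> falls back to gpt-4,
	// which the channel's model mapping rewrites to gpt-4-0613.
}
```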
go.mod
View File

@@ -68,7 +68,7 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect

go.sum
View File

@@ -110,8 +110,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View File

@@ -330,6 +330,7 @@
"通常和邮箱地址保持一致": "Usually consistent with the email address",
"SMTP 访问凭证": "SMTP Access Credential",
"敏感信息不会发送到前端显示": "Sensitive information will not be displayed in the frontend",
"使用 SMTP LOGIN 认证方式": "Use LOGIN as SMTP authentication method",
"保存 SMTP 设置": "Save SMTP Settings",
"配置 GitHub OAuth App": "Configure GitHub OAuth App",
"用以支持通过 GitHub 进行登录注册": "To support login & registration via GitHub",

View File

@@ -30,17 +30,13 @@ func CreateRootAccountIfNeed() error {
if err != nil {
return err
}
accessToken := random.GetUUID()
if config.InitialRootAccessToken != "" {
accessToken = config.InitialRootAccessToken
}
rootUser := User{
Username: "root",
Password: hashedPassword,
Role: RoleRootUser,
Status: UserStatusEnabled,
DisplayName: "Root User",
AccessToken: accessToken,
AccessToken: random.GetUUID(),
Quota: 500000000000000,
}
DB.Create(&rootUser)

View File

@@ -45,6 +45,7 @@ func InitOptionMap() {
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
config.OptionMap["SMTPAccount"] = ""
config.OptionMap["SMTPToken"] = ""
config.OptionMap["SMTPAuthLoginEnabled"] = strconv.FormatBool(config.SMTPAuthLoginEnabled)
config.OptionMap["Notice"] = ""
config.OptionMap["About"] = ""
config.OptionMap["HomePageContent"] = ""
@@ -150,6 +151,8 @@ func updateOptionMap(key string, value string) (err error) {
config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
config.DisplayTokenStatEnabled = boolValue
case "SMTPAuthLoginEnabled":
config.SMTPAuthLoginEnabled = boolValue
}
}
switch key {

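Together, these two hunks show the pattern for wiring a new boolean option through the option system: seed a stringified default into `OptionMap`, then map the persisted string back onto the typed config field when an option changes. A stripped-down sketch of that round trip (names are simplified stand-ins for the real config package):

```go
package main

import (
	"fmt"
	"strconv"
)

var optionMap = map[string]string{}
var smtpAuthLoginEnabled = false

func initOptionMap() {
	// Persisted options are stored as strings, so booleans are formatted.
	optionMap["SMTPAuthLoginEnabled"] = strconv.FormatBool(smtpAuthLoginEnabled)
}

func updateOption(key, value string) {
	boolValue := value == "true"
	switch key {
	case "SMTPAuthLoginEnabled":
		smtpAuthLoginEnabled = boolValue
	}
	optionMap[key] = value
}

func main() {
	initOptionMap()
	updateOption("SMTPAuthLoginEnabled", "true")
	fmt.Println(smtpAuthLoginEnabled) // true
}
```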
View File

@@ -1,16 +1,17 @@
package aws
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -18,52 +19,18 @@ import (
var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct {
awsAdapter utils.AwsAdapter
Meta *meta.Meta
AwsClient *bedrockruntime.Client
meta *meta.Meta
awsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
a.meta = meta
a.awsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
a.awsAdapter = adaptor
return adaptor.ConvertRequest(c, relayMode, request)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if a.awsAdapter == nil {
return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
}
return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}
func (a *Adaptor) GetModelList() (models []string) {
for model := range adaptors {
models = append(models, model)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
@@ -72,6 +39,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
@@ -82,3 +60,23 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, a.awsClient)
} else {
err, usage = Handler(c, a.awsClient, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}

View File

@@ -1,37 +0,0 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@@ -1,37 +0,0 @@
package aws
import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ utils.AwsAdapter = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
llamaReq := ConvertRequest(*request)
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, llamaReq)
return llamaReq, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, awsCli)
} else {
err, usage = Handler(c, awsCli, meta.ActualModelName)
}
return
}

View File

@@ -1,231 +0,0 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"text/template"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/random"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// Only support llama-3-8b and llama-3-70b instruction models
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
"llama3-8b-8192": "meta.llama3-8b-instruct-v1:0",
"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0",
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
// promptTemplate with range
const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|>
`
var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate))
func RenderPrompt(messages []relaymodel.Message) string {
var buf bytes.Buffer
err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages})
if err != nil {
logger.SysError("error rendering prompt messages: " + err.Error())
}
return buf.String()
}
func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
llamaRequest := Request{
MaxGenLen: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
}
if llamaRequest.MaxGenLen == 0 {
llamaRequest.MaxGenLen = 2048
}
prompt := RenderPrompt(textRequest.Messages)
llamaRequest.Prompt = prompt
return &llamaRequest
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}
var llamaResponse Response
err = json.Unmarshal(awsResp.Body, &llamaResponse)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := ResponseLlama2OpenAI(&llamaResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: llamaResponse.PromptTokenCount,
CompletionTokens: llamaResponse.GenerationTokenCount,
TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
var responseText string
if len(llamaResponse.Generation) > 0 {
responseText = llamaResponse.Generation
}
choice := openai.TextResponseChoice{
Index: 0,
Message: relaymodel.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: llamaResponse.StopReason,
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
}
awsReq.Body, err = json.Marshal(llamaReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
var llamaResp StreamResponse
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
if llamaResp.PromptTokenCount > 0 {
usage.PromptTokens = llamaResp.PromptTokenCount
}
if llamaResp.StopReason == "stop" {
usage.CompletionTokens = llamaResp.GenerationTokenCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
response := StreamResponseLlama2OpenAI(&llamaResp)
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}
func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = llamaResponse.Generation
choice.Delta.Role = "assistant"
finishReason := llamaResponse.StopReason
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse
}

View File

@@ -1,45 +0,0 @@
package aws_test
import (
"testing"
aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/stretchr/testify/assert"
)
func TestRenderPrompt(t *testing.T) {
messages := []relaymodel.Message{
{
Role: "user",
Content: "What's your name?",
},
}
prompt := aws.RenderPrompt(messages)
expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
messages = []relaymodel.Message{
{
Role: "system",
Content: "Your name is Kat. You are a detective.",
},
{
Role: "user",
Content: "What's your name?",
},
{
Role: "assistant",
Content: "Kat",
},
{
Role: "user",
Content: "What's your job?",
},
}
prompt = aws.RenderPrompt(messages)
expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
assert.Equal(t, expected, prompt)
}

View File

@@ -1,29 +0,0 @@
package aws
// Request is the request to AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
}
// Response is the response from AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Response struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}
// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
type StreamResponse struct {
Generation string `json:"generation"`
PromptTokenCount int `json:"prompt_token_count"`
GenerationTokenCount int `json:"generation_token_count"`
StopReason string `json:"stop_reason"`
}

View File

@@ -5,6 +5,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"io"
"net/http"
@@ -15,17 +17,23 @@ import (
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
@@ -36,7 +44,7 @@ var AwsModelIDMap = map[string]string{
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
@@ -46,7 +54,7 @@ func awsModelID(requestModel string) (string, error) {
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
@@ -57,30 +65,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
@@ -100,7 +108,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
@@ -111,7 +119,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return utils.WrapErr(errors.New("request not found")), nil
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
@@ -119,16 +127,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()

View File

@@ -1,39 +0,0 @@
package aws
import (
claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)
type AwsModelType int
const (
AwsClaude AwsModelType = iota + 1
AwsLlama3
)
var (
adaptors = map[string]AwsModelType{}
)
func init() {
for model := range claude.AwsModelIDMap {
adaptors[model] = AwsClaude
}
for model := range llama3.AwsModelIDMap {
adaptors[model] = AwsLlama3
}
}
func GetAdaptor(model string) utils.AwsAdapter {
adaptorType := adaptors[model]
switch adaptorType {
case AwsClaude:
return &claude.Adaptor{}
case AwsLlama3:
return &llama3.Adaptor{}
default:
return nil
}
}

View File

@@ -1,51 +0,0 @@
package utils
import (
"errors"
"io"
"net/http"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
type AwsAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
type Adaptor struct {
Meta *meta.Meta
AwsClient *bedrockruntime.Client
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}

View File

@@ -1,16 +0,0 @@
package utils
import (
"net/http"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: err.Error(),
},
}
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
@@ -29,14 +28,7 @@ func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
}
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@@ -49,14 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case relaymode.Completions:
return ConvertCompletionsRequest(*request), nil
case relaymode.ChatCompletions, relaymode.Embeddings:
return request, nil
default:
return nil, errors.New("not implemented")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {

View File

@@ -3,13 +3,11 @@ package cloudflare
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
@@ -18,23 +16,57 @@ import (
"github.com/songquanpeng/one-api/relay/model"
)
func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
p, _ := textRequest.Prompt.(string)
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
var promptBuilder strings.Builder
for _, message := range textRequest.Messages {
promptBuilder.WriteString(message.StringContent())
promptBuilder.WriteString("\n") // add a newline to separate each message
}
return &Request{
Prompt: p,
MaxTokens: textRequest.MaxTokens,
Prompt: promptBuilder.String(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: cloudflareResponse.Result.Response,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = cloudflareResponse.Response
choice.Delta.Role = "assistant"
openaiResponse := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
Created: helper.GetTimestamp(),
}
return &openaiResponse
}
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
responseModel := c.GetString(ctxkey.OriginalModel)
responseModel := c.GetString("original_model")
var responseText string
for scanner.Scan() {
@@ -45,22 +77,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response openai.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
var cloudflareResponse StreamResponse
err := json.Unmarshal([]byte(data), &cloudflareResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, v := range response.Choices {
v.Delta.Role = "assistant"
responseText += v.Delta.StringContent()
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
if response == nil {
continue
}
responseText += cloudflareResponse.Response
response.Id = id
response.Model = modelName
response.Model = responseModel
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
@@ -91,25 +123,22 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response openai.TextResponse
err = json.Unmarshal(responseBody, &response)
var cloudflareResponse Response
err = json.Unmarshal(responseBody, &cloudflareResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = modelName
var responseText string
for _, v := range response.Choices {
responseText += v.Message.Content.(string)
}
usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse)
fullTextResponse.Model = modelName
usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens)
fullTextResponse.Usage = *usage
fullTextResponse.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -1,13 +1,25 @@
package cloudflare
import "github.com/songquanpeng/one-api/relay/model"
type Request struct {
Messages []model.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
type Result struct {
Response string `json:"response"`
}
type Response struct {
Result Result `json:"result"`
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
}
type StreamResponse struct {
Response string `json:"response"`
}

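The `Response` and `StreamResponse` types in this hunk model Cloudflare Workers AI's JSON envelope, which wraps the generated text in a `result` object. A small decoding sketch against a payload shaped like that envelope (the payload itself is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Result struct {
	Response string `json:"response"`
}

type Response struct {
	Result   Result   `json:"result"`
	Success  bool     `json:"success"`
	Errors   []string `json:"errors"`
	Messages []string `json:"messages"`
}

func main() {
	// Illustrative payload shaped like a Workers AI /ai/run response.
	body := []byte(`{"result":{"response":"Hello!"},"success":true,"errors":[],"messages":[]}`)
	var resp Response
	if err := json.Unmarshal(body, &resp); err != nil {
		panic(err)
	}
	fmt.Println(resp.Result.Response) // Hello!
}
```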
View File

@@ -1,19 +0,0 @@
package novita
// https://novita.ai/llm-api
var ModelList = []string{
"meta-llama/llama-3-8b-instruct",
"meta-llama/llama-3-70b-instruct",
"nousresearch/hermes-2-pro-llama-3-8b",
"nousresearch/nous-hermes-llama2-13b",
"mistralai/mistral-7b-instruct",
"cognitivecomputations/dolphin-mixtral-8x22b",
"sao10k/l3-70b-euryale-v2.1",
"sophosympatheia/midnight-rose-70b",
"gryphe/mythomax-l2-13b",
"Nous-Hermes-2-Mixtral-8x7B-DPO",
"lzlv_70b",
"teknium/openhermes-2.5-mistral-7b",
"microsoft/wizardlm-2-8x22b",
}

View File

@@ -1,15 +0,0 @@
package novita
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
}

View File

@@ -3,19 +3,17 @@ package openai
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
type Adaptor struct {
@@ -50,8 +48,6 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return minimax.GetRequestURL(meta)
case channeltype.Doubao:
return doubao.GetRequestURL(meta)
case channeltype.Novita:
return novita.GetRequestURL(meta)
default:
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/channeltype"
@@ -29,7 +28,6 @@ var CompatibleChannels = []int{
channeltype.StepFun,
channeltype.DeepSeek,
channeltype.TogetherAI,
channeltype.Novita,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -58,8 +56,6 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "together.ai", togetherai.ModelList
case channeltype.Doubao:
return "doubao", doubao.ModelList
case channeltype.Novita:
return "novita", novita.ModelList
default:
return "openai", ModelList
}

View File

@@ -4,12 +4,11 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
@@ -32,7 +31,6 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() {
data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format
@@ -43,7 +41,6 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data)
doneRendered = true
continue
}
switch relayMode {
@@ -84,9 +81,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
logger.SysError("error reading stream: " + err.Error())
}
if !doneRendered {
render.Done(c)
}
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@@ -2,7 +2,6 @@ package ratio
import (
"encoding/json"
"fmt"
"strings"
"github.com/songquanpeng/one-api/common/logger"
@@ -170,9 +169,6 @@ var ModelRatio = map[string]float64{
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
// https://cohere.com/pricing
"command": 0.5,
"command-nightly": 0.5,
@@ -189,11 +185,7 @@ var ModelRatio = map[string]float64{
"deepl-ja": 25.0 / 1000 * USD,
}
var CompletionRatio = map[string]float64{
// aws llama3
"llama3-8b-8192(33)": 0.0006 / 0.0003,
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var CompletionRatio = map[string]float64{}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
@@ -242,28 +234,22 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}
func GetModelRatio(name string, channelType int) float64 {
func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := ModelRatio[model]; ok {
return ratio
ratio, ok := ModelRatio[name]
if !ok {
ratio, ok = DefaultModelRatio[name]
}
if ratio, ok := DefaultModelRatio[model]; ok {
return ratio
if !ok {
logger.SysError("model ratio not found: " + name)
return 30
}
if ratio, ok := ModelRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[name]; ok {
return ratio
}
logger.SysError("model ratio not found: " + name)
return 30
return ratio
}
func CompletionRatio2JSONString() string {
@@ -279,17 +265,7 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := CompletionRatio[model]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[model]; ok {
return ratio
}
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}

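The removed side of this hunk keys channel-specific prices as `model(channelType)` — e.g. `llama3-8b-8192(33)`, where `33` is the AWS channel type — and falls back to the bare model name, then to a hard default of `30`. A condensed sketch of that lookup order (the generic ratio below is illustrative):

```go
package main

import "fmt"

var modelRatio = map[string]float64{
	"llama3-8b-8192(33)": 0.0003 / 0.002, // AWS-specific price, channel type 33
	"llama3-8b-8192":     0.5,            // illustrative generic fallback
}

// getModelRatio mirrors the removed lookup order: channel-specific key
// first, then the bare model name, then a hard default of 30.
func getModelRatio(name string, channelType int) float64 {
	if ratio, ok := modelRatio[fmt.Sprintf("%s(%d)", name, channelType)]; ok {
		return ratio
	}
	if ratio, ok := modelRatio[name]; ok {
		return ratio
	}
	return 30
}

func main() {
	fmt.Println(getModelRatio("llama3-8b-8192", 33)) // 0.15: channel-specific key
	fmt.Println(getModelRatio("llama3-8b-8192", 1))  // 0.5: generic fallback
}
```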
View File

@@ -42,6 +42,5 @@ const (
DeepL
TogetherAI
Doubao
Novita
Dummy
)

View File

@@ -42,7 +42,6 @@ var ChannelBaseURLs = []string{
"https://api-free.deepl.com", // 38
"https://api.together.xyz", // 39
"https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
}
func init() {

View File

@@ -7,10 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
@@ -25,6 +21,9 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -54,7 +53,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := billingratio.GetModelRatio(audioModel, channelType)
modelRatio := billingratio.GetModelRatio(audioModel)
groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio
var quota int64

View File

@@ -4,10 +4,6 @@ import (
"context"
"errors"
"fmt"
"math"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
@@ -20,6 +16,9 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
"strings"
)
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
@@ -96,7 +95,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
return
}
var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))

View File

@@ -6,9 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
@@ -20,6 +17,8 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
@@ -167,7 +166,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 		requestBody = bytes.NewBuffer(jsonStr)
 	}

-	modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
+	modelRatio := billingratio.GetModelRatio(imageModel)
 	groupRatio := billingratio.GetGroupRatio(meta.Group)
 	ratio := modelRatio * groupRatio
 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)


@@ -4,9 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
@@ -17,6 +14,8 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
@@ -36,7 +35,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 	textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
 	meta.ActualModelName = textRequest.Model
 	// get model ratio & group ratio
-	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
+	modelRatio := billingratio.GetModelRatio(textRequest.Model)
 	groupRatio := billingratio.GetGroupRatio(meta.Group)
 	ratio := modelRatio * groupRatio
 	// pre-consume quota
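The trailing "// pre-consume quota" comment marks the other half of this flow: an estimate is reserved before the upstream call and the difference is settled once real usage is known, so a failing request cannot overdraw the account. A toy model of that sequence, with all names and numbers assumed:

package main

import (
	"errors"
	"fmt"
)

var userQuota int64 = 100000

// preConsume reserves an estimated amount up front and rejects the request
// outright if the balance cannot cover it.
func preConsume(n int64) error {
	if userQuota < n {
		return errors.New("insufficient quota")
	}
	userQuota -= n
	return nil
}

func main() {
	estimate := int64(6000) // roughly promptTokens * ratio
	if err := preConsume(estimate); err != nil {
		fmt.Println(err)
		return
	}
	actual := int64(9000)       // known only after the upstream responds
	userQuota -= actual - estimate // settle the difference (refund if negative)
	fmt.Println("remaining quota:", userQuota) // 91000
}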


@@ -47,7 +47,7 @@ const PersonalSetting = () => {
   const [countdown, setCountdown] = useState(30);
   const [affLink, setAffLink] = useState('');
   const [systemToken, setSystemToken] = useState('');
-  const [models, setModels] = useState([]);
+  // const [models, setModels] = useState([]);
   const [openTransfer, setOpenTransfer] = useState(false);
   const [transferAmount, setTransferAmount] = useState(0);
@@ -72,7 +72,7 @@ const PersonalSetting = () => {
         console.log(userState);
       }
     );
-    loadModels().then();
+    // loadModels().then();
     getAffLink().then();
     setTransferAmount(getQuotaPerUnit());
   }, []);
@@ -127,16 +127,16 @@ const PersonalSetting = () => {
     }
   };

-  const loadModels = async () => {
-    let res = await API.get(`/api/user/available_models`);
-    const { success, message, data } = res.data;
-    if (success) {
-      setModels(data);
-      console.log(data);
-    } else {
-      showError(message);
-    }
-  };
+  // const loadModels = async () => {
+  //   let res = await API.get(`/api/user/models`);
+  //   const { success, message, data } = res.data;
+  //   if (success) {
+  //     setModels(data);
+  //     console.log(data);
+  //   } else {
+  //     showError(message);
+  //   }
+  // };

   const handleAffLinkClick = async (e) => {
     e.target.select();
@@ -344,7 +344,7 @@ const PersonalSetting = () => {
           }
         >
           <Typography.Title heading={6}>调用信息</Typography.Title>
-          <p>可用模型,可点击复制</p>
+          {/* <Typography.Title heading={6}>可用模型</Typography.Title>
           <div style={{ marginTop: 10 }}>
             <Space wrap>
               {models.map((model) => (
@@ -355,7 +355,7 @@ const PersonalSetting = () => {
                 </Tag>
               ))}
             </Space>
-          </div>
+          </div> */}
         </Card>
         {/* <Card
           footer={


@@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = {
   },
   33: {
     key: 33,
-    text: 'AWS',
+    text: 'AWS Claude',
     value: 33,
     color: 'primary'
   },
@@ -161,12 +161,6 @@ export const CHANNEL_OPTIONS = {
     value: 39,
     color: 'primary'
   },
-  41: {
-    key: 41,
-    text: 'Novita',
-    value: 41,
-    color: 'purple'
-  },
   8: {
     key: 8,
     text: '自定义渠道',


@@ -1,5 +1,5 @@
 import React, { useEffect, useState } from 'react';
-import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
+import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
 import {
   API,
@@ -70,33 +70,13 @@ const ChannelsTable = () => {
     const res = await API.get(`/api/channel/?p=${startIdx}`);
     const { success, message, data } = res.data;
     if (success) {
-      let localChannels = data.map((channel) => {
-        if (channel.models === '') {
-          channel.models = [];
-          channel.test_model = "";
-        } else {
-          channel.models = channel.models.split(',');
-          if (channel.models.length > 0) {
-            channel.test_model = channel.models[0];
-          }
-          channel.model_options = channel.models.map((model) => {
-            return {
-              key: model,
-              text: model,
-              value: model,
-            }
-          })
-          console.log('channel', channel)
-        }
-        return channel;
-      });
-      if (startIdx === 0) {
-        setChannels(localChannels);
-      } else {
-        let newChannels = [...channels];
-        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels);
-        setChannels(newChannels);
-      }
+      if (startIdx === 0) {
+        setChannels(data);
+      } else {
+        let newChannels = [...channels];
+        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
+        setChannels(newChannels);
+      }
     } else {
       showError(message);
     }
@@ -245,31 +225,19 @@ const ChannelsTable = () => {
     setSearching(false);
   };

-  const switchTestModel = async (idx, model) => {
-    let newChannels = [...channels];
-    let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
-    newChannels[realIdx].test_model = model;
-    setChannels(newChannels);
-  };
-
-  const testChannel = async (id, name, idx, m) => {
-    const res = await API.get(`/api/channel/test/${id}?model=${m}`);
-    const { success, message, time, model } = res.data;
+  const testChannel = async (id, name, idx) => {
+    const res = await API.get(`/api/channel/test/${id}/`);
+    const { success, message, time } = res.data;
     if (success) {
-      let newChannels = [...channels];
-      let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
-      newChannels[realIdx].response_time = time * 1000;
-      newChannels[realIdx].test_time = Date.now() / 1000;
-      setChannels(newChannels);
-      showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`);
+      showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
     } else {
       showError(message);
     }
+    let newChannels = [...channels];
+    let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
+    newChannels[realIdx].response_time = time * 1000;
+    newChannels[realIdx].test_time = Date.now() / 1000;
+    setChannels(newChannels);
   };

   const testChannels = async (scope) => {
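Behind this hunk sits an API difference: upstream's endpoint accepts ?model=... so a channel can be probed with any of its models, while the older endpoint always tests a channel-chosen default. A gin handler sketch of the upstream shape; the route registration, fallback model, and response fields here are assumptions beyond what the URLs above show:

package main

import "github.com/gin-gonic/gin"

func main() {
	r := gin.Default()
	r.GET("/api/channel/test/:id", func(c *gin.Context) {
		model := c.Query("model") // empty on the older frontend, which sends none
		if model == "" {
			model = "gpt-3.5-turbo" // assumed default when no model is requested
		}
		// A real handler would relay a probe request through channel :id here.
		c.JSON(200, gin.H{"success": true, "model": model, "time": 0.42})
	})
	_ = r.Run(":3000")
}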
@@ -437,7 +405,6 @@ const ChannelsTable = () => {
           >
             优先级
           </Table.HeaderCell>
-          <Table.HeaderCell>测试模型</Table.HeaderCell>
           <Table.HeaderCell>操作</Table.HeaderCell>
         </Table.Row>
       </Table.Header>
@@ -492,24 +459,13 @@ const ChannelsTable = () => {
                     basic
                   />
                 </Table.Cell>
-                <Table.Cell>
-                  <Dropdown
-                    placeholder='请选择测试模型'
-                    selection
-                    options={channel.model_options}
-                    defaultValue={channel.test_model}
-                    onChange={(event, data) => {
-                      switchTestModel(idx, data.value);
-                    }}
-                  />
-                </Table.Cell>
                 <Table.Cell>
                   <div>
                     <Button
                       size={'small'}
                       positive
                       onClick={() => {
-                        testChannel(channel.id, channel.name, idx, channel.test_model);
+                        testChannel(channel.id, channel.name, idx);
                       }}
                     >
                       测试


@@ -18,6 +18,7 @@ const SystemSetting = () => {
     SMTPAccount: '',
     SMTPFrom: '',
     SMTPToken: '',
+    SMTPAuthLoginEnabled: '',
     ServerAddress: '',
     Footer: '',
     WeChatAuthEnabled: '',
@@ -76,6 +77,7 @@ const SystemSetting = () => {
       case 'TurnstileCheckEnabled':
       case 'EmailDomainRestrictionEnabled':
       case 'RegisterEnabled':
+      case 'SMTPAuthLoginEnabled':
         value = inputs[key] === 'true' ? 'false' : 'true';
         break;
       default:
@@ -107,7 +109,7 @@ const SystemSetting = () => {
     }
     if (
       name === 'Notice' ||
-      name.startsWith('SMTP') ||
+      (name.startsWith('SMTP') && !name.endsWith('Enabled')) ||
       name === 'ServerAddress' ||
       name === 'GitHubClientId' ||
       name === 'GitHubClientSecret' ||
@@ -444,6 +446,12 @@ const SystemSetting = () => {
           checked={inputs.RegisterEnabled === 'true'}
           placeholder='敏感信息不会发送到前端显示'
         />
+        <Form.Checkbox
+          checked={inputs.SMTPAuthLoginEnabled === 'true'}
+          label='使用 SMTP LOGIN 认证方式'
+          name='SMTPAuthLoginEnabled'
+          onChange={handleInputChange}
+        />
       </Form.Group>
       <Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
       <Divider />
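These settings hunks are the frontend half of the PR's headline feature. The backend half (not shown in this compare view) needs an smtp.Auth implementation that speaks AUTH LOGIN, since net/smtp ships only PLAIN and CRAM-MD5 and some providers accept only LOGIN. A common minimal implementation, offered as a sketch rather than the PR's exact code:

package main

import (
	"errors"
	"net/smtp"
)

// loginAuth implements smtp.Auth for the LOGIN mechanism.
type loginAuth struct {
	username, password string
}

func LoginAuth(username, password string) smtp.Auth {
	return &loginAuth{username, password}
}

// Start announces the LOGIN mechanism; credentials follow in Next.
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
	return "LOGIN", nil, nil
}

// Next answers the server's base64-decoded prompts one at a time.
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
	if !more {
		return nil, nil
	}
	switch string(fromServer) {
	case "Username:":
		return []byte(a.username), nil
	case "Password:":
		return []byte(a.password), nil
	default:
		return nil, errors.New("unexpected server challenge: " + string(fromServer))
	}
}

func main() {
	// Presumably selected when the new SMTPAuthLoginEnabled option is on
	// (assumed wiring); otherwise smtp.PlainAuth would remain the default.
	auth := LoginAuth("user@example.com", "app-password")
	_ = auth // pass to smtp.SendMail(addr, auth, from, to, msg)
}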


@@ -1,12 +1,11 @@
 export const CHANNEL_OPTIONS = [
   {key: 1, text: 'OpenAI', value: 1, color: 'green'},
   {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'},
-  {key: 33, text: 'AWS', value: 33, color: 'black'},
+  {key: 33, text: 'AWS Claude', value: 33, color: 'black'},
   {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'},
   {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'},
   {key: 24, text: 'Google Gemini', value: 24, color: 'orange'},
   {key: 28, text: 'Mistral AI', value: 28, color: 'orange'},
-  {key: 41, text: 'Novita', value: 41, color: 'purple'},
   {key: 40, text: '字节跳动豆包', value: 40, color: 'blue'},
   {key: 15, text: '百度文心千帆', value: 15, color: 'blue'},
   {key: 17, text: '阿里通义千问', value: 17, color: 'orange'},