Mirror of https://github.com/songquanpeng/one-api.git
Synced 2025-10-27 11:53:42 +08:00

Compare commits: v0.6.8-alp ... v0.6.10-al (50 commits)
| Commit SHA1 |
|---|
| 49ffb1c60d |
| 2f16649896 |
| af3aa57bd6 |
| e9f117ff72 |
| 6bb5247bd6 |
| 305ce14fe3 |
| 36c8f4f15c |
| 45b51ea0ee |
| 7c8628bd95 |
| 6ab87f8a08 |
| 833fa7ad6f |
| 6eb0770a89 |
| 92cd46d64f |
| 2b2dc2c733 |
| a3d7df7f89 |
| c368232f50 |
| cbfc983dc3 |
| 8ec092ba44 |
| b0b88a79ff |
| 7e51b04221 |
| f75a17f8eb |
| 6f13a3bb3c |
| f092eed1db |
| 629378691b |
| 3716e1b0e6 |
| a4d6e7a886 |
| cb772e5d06 |
| e32cb0b844 |
| fdd7bf41c0 |
| 29389ed44f |
| 88acc5a614 |
| a21681096a |
| 32f90a79a8 |
| 99c8c77504 |
| 649ecbf29c |
| 3a27c90910 |
| cba82404ae |
| c9ac670ba1 |
| 15f815c23c |
| 89b63ca96f |
| 8cc54489b9 |
| 58bf60805e |
| 6714cf96d6 |
| f9774698e9 |
| 2af6f6a166 |
| 04bb3ef392 |
| b4bfa418a8 |
| e7e99e558a |
| 402fcf7f79 |
| 36039e329e |
.github/workflows/ci.yml (vendored, 8 changes)

```diff
@@ -1,13 +1,13 @@
 name: CI
 
 # This setup assumes that you run the unit tests with code coverage in the same
 # workflow that will also print the coverage report as comment to the pull request.
 # Therefore, you need to trigger this workflow when a pull request is (re)opened or
 # when new code is pushed to the branch of the pull request. In addition, you also
 # need to trigger this workflow when new code is pushed to the main branch because
 # we need to upload the code coverage results as artifact for the main branch as
 # well since it will be the baseline code coverage.
 #
 # We do not want to trigger the workflow for pushes to *any* branch because this
 # would trigger our jobs twice on pull requests (once from "push" event and once
 # from "pull_request->synchronize")
@@ -31,7 +31,7 @@ jobs:
       with:
         go-version: ^1.22
 
     # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
     # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
     # in the next step as well as the next job.
     - name: Test
```
.gitignore (vendored, 3 changes)

```diff
@@ -9,4 +9,5 @@ logs
 data
 /web/node_modules
 cmd.md
 .env
+/one-api
```
README.md (11 changes)

```diff
@@ -89,6 +89,8 @@ _✨ Access all major large models through the standard OpenAI API format, ready to use out of the box ...
 + [x] [DeepL](https://www.deepl.com/)
 + [x] [together.ai](https://www.together.ai/)
 + [x] [novita.ai](https://www.novita.ai/)
++ [x] [SiliconCloud](https://siliconflow.cn/siliconcloud)
++ [x] [xAI](https://x.ai/)
 2. Supports configuring mirrors and many [third-party proxy services](https://iamazing.cn/page/openai-api-third-party-services).
 3. Supports accessing multiple channels via **load balancing**.
 4. Supports **stream mode**, enabling a typewriter effect through streamed responses.
@@ -113,8 +115,8 @@
 21. Supports Cloudflare Turnstile user verification.
 22. Supports user management and **multiple login/registration methods**:
     + Email login/registration (with a registration email whitelist) and password reset via email.
-    + Supports authorized login via Lark (Feishu).
+    + Supports [Lark (Feishu) OAuth login](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get) ([notes on One API's implementation details for reference](https://iamazing.cn/page/feishu-oauth-login)).
-    + [GitHub OAuth](https://github.com/settings/applications/new).
+    + Supports [GitHub OAuth login](https://github.com/settings/applications/new).
     + WeChat Official Account authorization (requires additionally deploying [WeChat Server](https://github.com/songquanpeng/wechat-server)).
 23. Supports theme switching via the `THEME` environment variable (default `default`); PRs for more themes are welcome, see [here](./web/README.md).
 24. Together with [Message Pusher](https://github.com/songquanpeng/message-pusher), alert messages can be pushed to a variety of apps.
@@ -251,9 +253,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope...
 #### QChatGPT - QQ bot
 Project home: https://github.com/RockChinQ/QChatGPT
 
-After completing deployment per the documentation, set the `reverse_proxy` field of the `openai_config` section in `config.py` to the One API backend address, set `api_key` to a key generated by One API, and set the `model` parameter of `completion_api_params` to a model name supported by One API.
+After completing deployment per the [documentation](https://qchatgpt.rockchin.top), set `requester.openai-chat-completions.base-url` in `data/provider.json` to the One API instance address, fill the API key into the `keys.openai` group, and set `model` to the model name to use.
 
-The [Switcher plugin](https://github.com/RockChinQ/Switcher) can be installed to switch models at runtime.
+At runtime, the `!model` command can be used to view and switch the available models.
 
 ### Deploy to third-party platforms
 <details>
@@ -398,6 +400,7 @@ graph LR
 26. `METRIC_SUCCESS_RATE_THRESHOLD`: request success-rate threshold, default `0.8`.
 27. `INITIAL_ROOT_TOKEN`: if set, a root user token with this value is created automatically on the system's first startup.
 28. `INITIAL_ROOT_ACCESS_TOKEN`: if set, a system management access token for the root user with this value is created automatically on the system's first startup.
+29. `ENFORCE_INCLUDE_USAGE`: whether to force returning usage in stream mode; disabled by default; valid values are `true` and `false`.
 
 ### Command-line arguments
 1. `--port <port_number>`: the port the server listens on, default `3000`.
```
```diff
@@ -35,6 +35,7 @@ var PasswordLoginEnabled = true
 var PasswordRegisterEnabled = true
 var EmailVerificationEnabled = false
 var GitHubOAuthEnabled = false
+var OidcEnabled = false
 var WeChatAuthEnabled = false
 var TurnstileCheckEnabled = false
 var RegisterEnabled = true
@@ -70,6 +71,13 @@ var GitHubClientSecret = ""
 var LarkClientId = ""
 var LarkClientSecret = ""
 
+var OidcClientId = ""
+var OidcClientSecret = ""
+var OidcWellKnown = ""
+var OidcAuthorizationEndpoint = ""
+var OidcTokenEndpoint = ""
+var OidcUserinfoEndpoint = ""
+
 var WeChatServerAddress = ""
 var WeChatServerToken = ""
 var WeChatAccountQRCodeImageURL = ""
@@ -152,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
 var RelayProxy = env.String("RELAY_PROXY", "")
 var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
 var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
+
+var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
```
```diff
@@ -20,4 +20,5 @@ const (
 	BaseURL         = "base_url"
 	AvailableModels = "available_models"
 	KeyRequestBody  = "key_request_body"
+	SystemPrompt    = "system_prompt"
 )
```
```diff
@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 		err = json.Unmarshal(requestBody, &v)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	} else {
-		// skip for now
-		// TODO: someday non json request have variant model, we will need to implementation this
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		err = c.ShouldBind(&v)
 	}
 	if err != nil {
 		return err
 	}
 	// Reset request body
-	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	return nil
 }
 
```
```diff
@@ -137,3 +137,23 @@ func String2Int(str string) int {
 	}
 	return num
 }
+
+func Float64PtrMax(p *float64, maxValue float64) *float64 {
+	if p == nil {
+		return nil
+	}
+	if *p > maxValue {
+		return &maxValue
+	}
+	return p
+}
+
+func Float64PtrMin(p *float64, minValue float64) *float64 {
+	if p == nil {
+		return nil
+	}
+	if *p < minValue {
+		return &minValue
+	}
+	return p
+}
```
```diff
@@ -3,9 +3,10 @@ package render
 import (
 	"encoding/json"
 	"fmt"
+	"strings"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
-	"strings"
 )
 
 func StringData(c *gin.Context, str string) {
```
```diff
@@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
 	if err != nil {
 		return nil, err
 	}
-	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
+	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
 	if err != nil {
 		return nil, err
 	}
```
controller/auth/oidc.go (new file, 225 lines)

```go
package auth

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/controller"
	"github.com/songquanpeng/one-api/model"
	"net/http"
	"strconv"
	"time"
)

type OidcResponse struct {
	AccessToken  string `json:"access_token"`
	IDToken      string `json:"id_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	Scope        string `json:"scope"`
}

type OidcUser struct {
	OpenID            string `json:"sub"`
	Email             string `json:"email"`
	Name              string `json:"name"`
	PreferredUsername string `json:"preferred_username"`
	Picture           string `json:"picture"`
}

func getOidcUserInfoByCode(code string) (*OidcUser, error) {
	if code == "" {
		return nil, errors.New("无效的参数")
	}
	values := map[string]string{
		"client_id":     config.OidcClientId,
		"client_secret": config.OidcClientSecret,
		"code":          code,
		"grant_type":    "authorization_code",
		"redirect_uri":  fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
	}
	jsonData, err := json.Marshal(values)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	client := http.Client{
		Timeout: 5 * time.Second,
	}
	res, err := client.Do(req)
	if err != nil {
		logger.SysLog(err.Error())
		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
	}
	defer res.Body.Close()
	var oidcResponse OidcResponse
	err = json.NewDecoder(res.Body).Decode(&oidcResponse)
	if err != nil {
		return nil, err
	}
	req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
	res2, err := client.Do(req)
	if err != nil {
		logger.SysLog(err.Error())
		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
	}
	var oidcUser OidcUser
	err = json.NewDecoder(res2.Body).Decode(&oidcUser)
	if err != nil {
		return nil, err
	}
	return &oidcUser, nil
}

func OidcAuth(c *gin.Context) {
	session := sessions.Default(c)
	state := c.Query("state")
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
		return
	}
	username := session.Get("username")
	if username != nil {
		OidcBind(c)
		return
	}
	if !config.OidcEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 OIDC 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	oidcUser, err := getOidcUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		OidcId: oidcUser.OpenID,
	}
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
		err := user.FillUserByOidcId()
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	} else {
		if config.RegisterEnabled {
			user.Email = oidcUser.Email
			if oidcUser.PreferredUsername != "" {
				user.Username = oidcUser.PreferredUsername
			} else {
				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
			}
			if oidcUser.Name != "" {
				user.DisplayName = oidcUser.Name
			} else {
				user.DisplayName = "OIDC User"
			}
			err := user.Insert(0)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "管理员关闭了新用户注册",
			})
			return
		}
	}

	if user.Status != model.UserStatusEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户已被封禁",
			"success": false,
		})
		return
	}
	controller.SetupLogin(&user, c)
}

func OidcBind(c *gin.Context) {
	if !config.OidcEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 OIDC 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	oidcUser, err := getOidcUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		OidcId: oidcUser.OpenID,
	}
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "该 OIDC 账户已被绑定",
		})
		return
	}
	session := sessions.Default(c)
	id := session.Get("id")
	// id := c.GetInt("id") // critical bug!
	user.Id = id.(int)
	err = user.FillUserById()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user.OidcId = oidcUser.OpenID
	err = user.Update(false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "bind",
	})
	return
}
```
```diff
@@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
 	if config.DisplayTokenStatEnabled {
 		tokenId := c.GetInt(ctxkey.TokenId)
 		token, err = model.GetTokenById(tokenId)
-		expiredTime = token.ExpiredTime
-		remainQuota = token.RemainQuota
-		usedQuota = token.UsedQuota
+		if err == nil {
+			expiredTime = token.ExpiredTime
+			remainQuota = token.RemainQuota
+			usedQuota = token.UsedQuota
+		}
 	} else {
 		userId := c.GetInt(ctxkey.Id)
 		remainQuota, err = model.GetUserQuota(userId)
```
```diff
@@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
 	TotalUsed float64 `json:"total_used"`
 }
 
+type SiliconFlowUsageResponse struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  bool   `json:"status"`
+	Data    struct {
+		ID            string `json:"id"`
+		Name          string `json:"name"`
+		Image         string `json:"image"`
+		Email         string `json:"email"`
+		IsAdmin       bool   `json:"isAdmin"`
+		Balance       string `json:"balance"`
+		Status        string `json:"status"`
+		Introduction  string `json:"introduction"`
+		Role          string `json:"role"`
+		ChargeBalance string `json:"chargeBalance"`
+		TotalBalance  string `json:"totalBalance"`
+		Category      string `json:"category"`
+	} `json:"data"`
+}
+
 // GetAuthHeader get auth header
 func GetAuthHeader(token string) http.Header {
 	h := http.Header{}
@@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 	return response.TotalAvailable, nil
 }
 
+func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.siliconflow.cn/v1/user/info"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := SiliconFlowUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if response.Code != 20000 {
+		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
+	}
+	balance, err := strconv.ParseFloat(response.Data.Balance, 64)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := channeltype.ChannelBaseURLs[channel.Type]
 	if channel.GetBaseURL() == "" {
@@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return updateChannelAPI2GPTBalance(channel)
 	case channeltype.AIGC2D:
 		return updateChannelAIGC2DBalance(channel)
+	case channeltype.SiliconFlow:
+		return updateChannelSiliconFlowBalance(channel)
 	default:
 		return 0, errors.New("尚未实现")
 	}
```
```diff
@@ -76,9 +76,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
 		if len(modelNames) > 0 {
 			modelName = modelNames[0]
 		}
-		if modelMap != nil && modelMap[modelName] != "" {
-			modelName = modelMap[modelName]
-		}
 	}
+	if modelMap != nil && modelMap[modelName] != "" {
+		modelName = modelMap[modelName]
+	}
 	meta.OriginModelName, meta.ActualModelName = request.Model, modelName
 	request.Model = modelName
```
```diff
@@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) {
 		"success": true,
 		"message": "",
 		"data": gin.H{
 			"version":             common.Version,
 			"start_time":          common.StartTime,
 			"email_verification":  config.EmailVerificationEnabled,
 			"github_oauth":        config.GitHubOAuthEnabled,
 			"github_client_id":    config.GitHubClientId,
 			"lark_client_id":      config.LarkClientId,
 			"system_name":         config.SystemName,
 			"logo":                config.Logo,
 			"footer_html":         config.Footer,
 			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL,
 			"wechat_login":        config.WeChatAuthEnabled,
 			"server_address":      config.ServerAddress,
 			"turnstile_check":     config.TurnstileCheckEnabled,
 			"turnstile_site_key":  config.TurnstileSiteKey,
 			"top_up_link":         config.TopUpLink,
 			"chat_link":           config.ChatLink,
 			"quota_per_unit":      config.QuotaPerUnit,
 			"display_in_currency": config.DisplayInCurrencyEnabled,
+			"oidc":                        config.OidcEnabled,
+			"oidc_client_id":              config.OidcClientId,
+			"oidc_well_known":             config.OidcWellKnown,
+			"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
+			"oidc_token_endpoint":         config.OidcTokenEndpoint,
+			"oidc_userinfo_endpoint":      config.OidcUserinfoEndpoint,
 		},
 	})
 	return
```
```diff
@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
 	channelName := c.GetString(ctxkey.ChannelName)
 	group := c.GetString(ctxkey.Group)
 	originalModel := c.GetString(ctxkey.OriginalModel)
-	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 	requestId := c.GetString(helper.RequestIdKey)
 	retryTimes := config.RetryTimes
 	if !shouldRetry(c, bizErr.StatusCode) {
@@ -87,8 +87,7 @@ func Relay(c *gin.Context) {
 		channelId := c.GetInt(ctxkey.ChannelId)
 		lastFailedChannelId = channelId
 		channelName := c.GetString(ctxkey.ChannelName)
-		// BUG: bizErr is in race condition
-		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 	}
 	if bizErr != nil {
 		if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
 	return true
 }
 
-func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
 	// https://platform.openai.com/docs/guides/error-codes/api-errors
 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
```
go.mod (8 changes)

```diff
@@ -25,7 +25,7 @@ require (
 	github.com/pkoukk/tiktoken-go v0.1.7
 	github.com/smartystreets/goconvey v1.8.1
 	github.com/stretchr/testify v1.9.0
-	golang.org/x/crypto v0.24.0
+	golang.org/x/crypto v0.31.0
 	golang.org/x/image v0.18.0
 	google.golang.org/api v0.187.0
 	gorm.io/driver/mysql v1.5.6
@@ -99,9 +99,9 @@ require (
 	golang.org/x/arch v0.8.0 // indirect
 	golang.org/x/net v0.26.0 // indirect
 	golang.org/x/oauth2 v0.21.0 // indirect
-	golang.org/x/sync v0.7.0 // indirect
-	golang.org/x/sys v0.21.0 // indirect
-	golang.org/x/text v0.16.0 // indirect
+	golang.org/x/sync v0.10.0 // indirect
+	golang.org/x/sys v0.28.0 // indirect
+	golang.org/x/text v0.21.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
```
16
go.sum
16
go.sum
@@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
-golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
+golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
 golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
@@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
+golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
-golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
+golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
-golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
+golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -12,7 +12,7 @@ import (
 )
 
 type ModelRequest struct {
-	Model string `json:"model"`
+	Model string `json:"model" form:"model"`
 }
 
 func Distribute() func(c *gin.Context) {
@@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	c.Set(ctxkey.Channel, channel.Type)
 	c.Set(ctxkey.ChannelId, channel.Id)
 	c.Set(ctxkey.ChannelName, channel.Name)
+	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
+		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
+	}
 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
 	c.Set(ctxkey.OriginalModel, modelName) // for retry
 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
27
middleware/gzip.go
Normal file
@@ -0,0 +1,27 @@
+package middleware
+
+import (
+	"compress/gzip"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+)
+
+func GzipDecodeMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		if c.GetHeader("Content-Encoding") == "gzip" {
+			gzipReader, err := gzip.NewReader(c.Request.Body)
+			if err != nil {
+				c.AbortWithStatus(http.StatusBadRequest)
+				return
+			}
+			defer gzipReader.Close()
+
+			// Replace the request body with the decompressed data
+			c.Request.Body = io.NopCloser(gzipReader)
+		}
+
+		// Continue processing the request
+		c.Next()
+	}
+}
@@ -37,6 +37,7 @@ type Channel struct {
 	ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 	Priority *int64 `json:"priority" gorm:"bigint;default:0"`
 	Config string `json:"config"`
+	SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
 }
 
 type ChannelConfig struct {
13
model/log.go
@@ -3,6 +3,7 @@ package model
 import (
 	"context"
 	"fmt"
+
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 }
 
 func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
-	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
+	ifnull := "ifnull"
+	if common.UsingPostgreSQL {
+		ifnull = "COALESCE"
+	}
+	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
 	if username != "" {
 		tx = tx.Where("username = ?", username)
 	}
@@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 }
 
 func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
-	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
+	ifnull := "ifnull"
+	if common.UsingPostgreSQL {
+		ifnull = "COALESCE"
+	}
+	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
 	if username != "" {
 		tx = tx.Where("username = ?", username)
 	}
@@ -28,6 +28,7 @@ func InitOptionMap() {
 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
+	config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) {
 		config.EmailVerificationEnabled = boolValue
 	case "GitHubOAuthEnabled":
 		config.GitHubOAuthEnabled = boolValue
+	case "OidcEnabled":
+		config.OidcEnabled = boolValue
 	case "WeChatAuthEnabled":
 		config.WeChatAuthEnabled = boolValue
 	case "TurnstileCheckEnabled":
@@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) {
 		config.LarkClientId = value
 	case "LarkClientSecret":
 		config.LarkClientSecret = value
+	case "OidcClientId":
+		config.OidcClientId = value
+	case "OidcClientSecret":
+		config.OidcClientSecret = value
+	case "OidcWellKnown":
+		config.OidcWellKnown = value
+	case "OidcAuthorizationEndpoint":
+		config.OidcAuthorizationEndpoint = value
+	case "OidcTokenEndpoint":
+		config.OidcTokenEndpoint = value
+	case "OidcUserinfoEndpoint":
+		config.OidcUserinfoEndpoint = value
 	case "Footer":
 		config.Footer = value
 	case "SystemName":
@@ -30,7 +30,7 @@ type Token struct {
 	RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
 	UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
 	UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
-	Models *string `json:"models" gorm:"default:''"` // allowed models
+	Models *string `json:"models" gorm:"type:text"` // allowed models
 	Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
 }
 
@@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
 	return &token, err
 }
 
-func (token *Token) Insert() error {
+func (t *Token) Insert() error {
 	var err error
-	err = DB.Create(token).Error
+	err = DB.Create(t).Error
 	return err
 }
 
 // Update Make sure your token's fields is completed, because this will update non-zero values
-func (token *Token) Update() error {
+func (t *Token) Update() error {
 	var err error
-	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
+	err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
 	return err
 }
 
-func (token *Token) SelectUpdate() error {
+func (t *Token) SelectUpdate() error {
 	// This can update zero values
-	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
+	return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
 }
 
-func (token *Token) Delete() error {
+func (t *Token) Delete() error {
 	var err error
-	err = DB.Delete(token).Error
+	err = DB.Delete(t).Error
 	return err
 }
 
+func (t *Token) GetModels() string {
+	if t == nil {
+		return ""
+	}
+	if t.Models == nil {
+		return ""
+	}
+	return *t.Models
+}
+
 func DeleteTokenById(id int, userId int) (err error) {
 	// Why we need userId here? In case user want to delete other's token.
 	if id == 0 || userId == 0 {
@@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
 
 func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
 	token, err := GetTokenById(tokenId)
+	if err != nil {
+		return err
+	}
 	if quota > 0 {
 		err = DecreaseUserQuota(token.UserId, quota)
 	} else {
 		err = IncreaseUserQuota(token.UserId, -quota)
 	}
-	if err != nil {
-		return err
-	}
 	if !token.UnlimitedQuota {
 		if quota > 0 {
 			err = DecreaseTokenQuota(tokenId, quota)
@@ -39,6 +39,7 @@ type User struct {
|
|||||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||||
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
|
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
|
||||||
|
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||||
Quota int64 `json:"quota" gorm:"bigint;default:0"`
|
Quota int64 `json:"quota" gorm:"bigint;default:0"`
|
||||||
@@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) FillUserByOidcId() error {
|
||||||
|
if user.OidcId == "" {
|
||||||
|
return errors.New("oidc id 为空!")
|
||||||
|
}
|
||||||
|
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) FillUserByWeChatId() error {
|
func (user *User) FillUserByWeChatId() error {
|
||||||
if user.WeChatId == "" {
|
if user.WeChatId == "" {
|
||||||
return errors.New("WeChat id 为空!")
|
return errors.New("WeChat id 为空!")
|
||||||
@@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
|
|||||||
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||||
|
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||||
|
}
|
||||||
|
|
||||||
func IsUsernameAlreadyTaken(username string) bool {
|
func IsUsernameAlreadyTaken(username string) bool {
|
||||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
 package monitor
 
 import (
-	"github.com/songquanpeng/one-api/common/config"
-	"github.com/songquanpeng/one-api/relay/model"
 	"net/http"
 	"strings"
+
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/relay/model"
 )
 
 func ShouldDisableChannel(err *model.Error, statusCode int) bool {
@@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
 		return true
 	}
 	switch err.Type {
-	case "insufficient_quota":
-		return true
-	// https://docs.anthropic.com/claude/reference/errors
-	case "authentication_error":
-		return true
-	case "permission_error":
-		return true
-	case "forbidden":
+	case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
 		return true
 	}
 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 		return true
 	}
-	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
-		return true
-	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
-		return true
-	}
-	//if strings.Contains(err.Message, "quota") {
-	//	return true
-	//}
-	if strings.Contains(err.Message, "credit") {
-		return true
-	}
-	if strings.Contains(err.Message, "balance") {
+
+	lowerMessage := strings.ToLower(err.Message)
+	if strings.Contains(lowerMessage, "your access was terminated") ||
+		strings.Contains(lowerMessage, "violation of our policies") ||
+		strings.Contains(lowerMessage, "your credit balance is too low") ||
+		strings.Contains(lowerMessage, "organization has been disabled") ||
+		strings.Contains(lowerMessage, "credit") ||
+		strings.Contains(lowerMessage, "balance") ||
+		strings.Contains(lowerMessage, "permission denied") ||
+		strings.Contains(lowerMessage, "organization has been restricted") || // groq
+		strings.Contains(lowerMessage, "已欠费") {
 		return true
 	}
 	return false
@@ -16,6 +16,7 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/adaptor/palm"
 	"github.com/songquanpeng/one-api/relay/adaptor/proxy"
+	"github.com/songquanpeng/one-api/relay/adaptor/replicate"
 	"github.com/songquanpeng/one-api/relay/adaptor/tencent"
 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
 		return &vertexai.Adaptor{}
 	case apitype.Proxy:
 		return &proxy.Adaptor{}
+	case apitype.Replicate:
+		return &replicate.Adaptor{}
 	}
 	return nil
 }
@@ -1,7 +1,23 @@
 package ali
 
 var ModelList = []string{
-	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
-	"text-embedding-v1",
+	"qwen-turbo", "qwen-turbo-latest",
+	"qwen-plus", "qwen-plus-latest",
+	"qwen-max", "qwen-max-latest",
+	"qwen-max-longcontext",
+	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
+	"qwen-vl-ocr", "qwen-vl-ocr-latest",
+	"qwen-audio-turbo",
+	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
+	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
+	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
+	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
+	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
+	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
+	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
+	"qwen2-audio-instruct", "qwen-audio-chat",
+	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
+	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
+	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
 }
@@ -3,6 +3,7 @@ package ali
 import (
 	"bufio"
 	"encoding/json"
+	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/render"
 	"io"
 	"net/http"
@@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		enableSearch = true
 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 	}
-	if request.TopP >= 1 {
-		request.TopP = 0.9999
-	}
+	request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
 	return &ChatRequest{
 		Model: aliModel,
 		Input: Input{
@@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 
 func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 	return &EmbeddingRequest{
-		Model: "text-embedding-v1",
+		Model: request.Model,
 		Input: struct {
 			Texts []string `json:"texts"`
 		}{
@@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
+	requestModel := c.GetString(ctxkey.RequestModel)
 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	fullTextResponse.Model = requestModel
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -16,13 +16,13 @@ type Input struct {
 }
 
 type Parameters struct {
-	TopP float64 `json:"top_p,omitempty"`
+	TopP *float64 `json:"top_p,omitempty"`
 	TopK int `json:"top_k,omitempty"`
 	Seed uint64 `json:"seed,omitempty"`
 	EnableSearch bool `json:"enable_search,omitempty"`
 	IncrementalOutput bool `json:"incremental_output,omitempty"`
 	MaxTokens int `json:"max_tokens,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
+	Temperature *float64 `json:"temperature,omitempty"`
 	ResultFormat string `json:"result_format,omitempty"`
 	Tools []model.Tool `json:"tools,omitempty"`
 }
@@ -3,7 +3,11 @@ package anthropic
 var ModelList = []string{
 	"claude-instant-1.2", "claude-2.0", "claude-2.1",
 	"claude-3-haiku-20240307",
+	"claude-3-5-haiku-20241022",
 	"claude-3-sonnet-20240229",
 	"claude-3-opus-20240229",
 	"claude-3-5-sonnet-20240620",
+	"claude-3-5-sonnet-20241022",
+	"claude-3-5-sonnet-latest",
+	"claude-3-5-haiku-20241022",
 }
@@ -48,8 +48,8 @@ type Request struct {
 	MaxTokens int `json:"max_tokens,omitempty"`
 	StopSequences []string `json:"stop_sequences,omitempty"`
 	Stream bool `json:"stream,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-	TopP float64 `json:"top_p,omitempty"`
+	Temperature *float64 `json:"temperature,omitempty"`
+	TopP *float64 `json:"top_p,omitempty"`
 	TopK int `json:"top_k,omitempty"`
 	Tools []Tool `json:"tools,omitempty"`
 	ToolChoice any `json:"tool_choice,omitempty"`
@@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{
 	"claude-instant-1.2": "anthropic.claude-instant-v1",
 	"claude-2.0": "anthropic.claude-v2",
 	"claude-2.1": "anthropic.claude-v2:1",
-	"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
-	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
-	"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
 	"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
+	"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
+	"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
+	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
 }
 
 func awsModelID(requestModel string) (string, error) {
@@ -11,8 +11,8 @@ type Request struct {
 	Messages []anthropic.Message `json:"messages"`
 	System string `json:"system,omitempty"`
 	MaxTokens int `json:"max_tokens,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-	TopP float64 `json:"top_p,omitempty"`
+	Temperature *float64 `json:"temperature,omitempty"`
+	TopP *float64 `json:"top_p,omitempty"`
 	TopK int `json:"top_k,omitempty"`
 	StopSequences []string `json:"stop_sequences,omitempty"`
 	Tools []anthropic.Tool `json:"tools,omitempty"`
@@ -4,10 +4,10 @@ package aws
 //
 // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
 type Request struct {
 	Prompt string `json:"prompt"`
 	MaxGenLen int `json:"max_gen_len,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-	TopP float64 `json:"top_p,omitempty"`
+	Temperature *float64 `json:"temperature,omitempty"`
+	TopP *float64 `json:"top_p,omitempty"`
 }
 
 // Response is the response from AWS Llama3
@@ -35,9 +35,9 @@ type Message struct {
 
 type ChatRequest struct {
 	Messages      []Message `json:"messages"`
-	Temperature   float64   `json:"temperature,omitempty"`
-	TopP          float64   `json:"top_p,omitempty"`
-	PenaltyScore  float64   `json:"penalty_score,omitempty"`
+	Temperature   *float64  `json:"temperature,omitempty"`
+	TopP          *float64  `json:"top_p,omitempty"`
+	PenaltyScore  *float64  `json:"penalty_score,omitempty"`
 	Stream        bool      `json:"stream,omitempty"`
 	System        string    `json:"system,omitempty"`
 	DisableSearch bool      `json:"disable_search,omitempty"`
@@ -1,6 +1,7 @@
 package cloudflare
 
 var ModelList = []string{
+	"@cf/meta/llama-3.1-8b-instruct",
 	"@cf/meta/llama-2-7b-chat-fp16",
 	"@cf/meta/llama-2-7b-chat-int8",
 	"@cf/mistral/mistral-7b-instruct-v0.1",
@@ -9,5 +9,5 @@ type Request struct {
 	Prompt      string   `json:"prompt,omitempty"`
 	Raw         bool     `json:"raw,omitempty"`
 	Stream      bool     `json:"stream,omitempty"`
-	Temperature float64  `json:"temperature,omitempty"`
+	Temperature *float64 `json:"temperature,omitempty"`
 }
@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		K:                textRequest.TopK,
 		Stream:           textRequest.Stream,
 		FrequencyPenalty: textRequest.FrequencyPenalty,
-		PresencePenalty:  textRequest.FrequencyPenalty,
+		PresencePenalty:  textRequest.PresencePenalty,
 		Seed:             int(textRequest.Seed),
 	}
 	if cohereRequest.Model == "" {
@@ -10,15 +10,15 @@ type Request struct {
 	PromptTruncation string       `json:"prompt_truncation,omitempty"` // defaults to "AUTO"
 	Connectors       []Connector  `json:"connectors,omitempty"`
 	Documents        []Document   `json:"documents,omitempty"`
-	Temperature      float64      `json:"temperature,omitempty"` // defaults to 0.3
+	Temperature      *float64     `json:"temperature,omitempty"` // defaults to 0.3
 	MaxTokens        int          `json:"max_tokens,omitempty"`
 	MaxInputTokens   int          `json:"max_input_tokens,omitempty"`
 	K                int          `json:"k,omitempty"` // defaults to 0
-	P                float64      `json:"p,omitempty"` // defaults to 0.75
+	P                *float64     `json:"p,omitempty"` // defaults to 0.75
 	Seed             int          `json:"seed,omitempty"`
 	StopSequences    []string     `json:"stop_sequences,omitempty"`
-	FrequencyPenalty float64      `json:"frequency_penalty,omitempty"` // defaults to 0.0
-	PresencePenalty  float64      `json:"presence_penalty,omitempty"` // defaults to 0.0
+	FrequencyPenalty *float64     `json:"frequency_penalty,omitempty"` // defaults to 0.0
+	PresencePenalty  *float64     `json:"presence_penalty,omitempty"` // defaults to 0.0
 	Tools            []Tool       `json:"tools,omitempty"`
 	ToolResults      []ToolResult `json:"tool_results,omitempty"`
 }
@@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 }
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
+	defaultVersion := config.GeminiVersion
+	if meta.ActualModelName == "gemini-2.0-flash-exp" {
+		defaultVersion = "v1beta"
+	}
+
+	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
 	action := ""
 	switch meta.Mode {
 	case relaymode.Embeddings:
@@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	if meta.IsStream {
 		action = "streamGenerateContent?alt=sse"
 	}
 
 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
 }
 
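The hunk above picks a per-model default API version before applying any configured override. A standalone sketch of that selection logic (`assignOrDefault` is a local stand-in for the repo's `helper.AssignOrDefault`, and `"v1"` stands in for `config.GeminiVersion`):

```go
package main

import "fmt"

// assignOrDefault mirrors helper.AssignOrDefault from the diff:
// use the configured value when set, otherwise the default.
func assignOrDefault(configured, def string) string {
	if configured != "" {
		return configured
	}
	return def
}

// apiVersion picks the Gemini API version for a model, letting an
// explicit channel config override the per-model default.
func apiVersion(configAPIVersion, model string) string {
	defaultVersion := "v1" // stand-in for config.GeminiVersion
	if model == "gemini-2.0-flash-exp" {
		defaultVersion = "v1beta"
	}
	return assignOrDefault(configAPIVersion, defaultVersion)
}

func main() {
	fmt.Println(apiVersion("", "gemini-2.0-flash-exp"))  // v1beta
	fmt.Println(apiVersion("", "gemini-1.5-pro"))        // v1
	fmt.Println(apiVersion("v1alpha", "gemini-1.5-pro")) // v1alpha
}
```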
@@ -3,6 +3,9 @@ package gemini
 // https://ai.google.dev/models/gemini
 
 var ModelList = []string{
-	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
-	"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
+	"gemini-pro", "gemini-1.0-pro",
+	"gemini-1.5-flash", "gemini-1.5-pro",
+	"text-embedding-004", "aqa",
+	"gemini-2.0-flash-exp",
+	"gemini-2.0-flash-thinking-exp",
 }
@@ -4,11 +4,12 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
 	"net/http"
 	"strings"
 
+	"github.com/songquanpeng/one-api/common/render"
+
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -28,6 +29,11 @@ const (
 	VisionMaxImageNum = 16
 )
 
+var mimeTypeMap = map[string]string{
+	"json_object": "application/json",
+	"text":        "text/plain",
+}
+
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 	geminiRequest := ChatRequest{
@@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			{
 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
 				Threshold: config.GeminiSafetySetting,
 			},
+			{
+				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
+				Threshold: config.GeminiSafetySetting,
+			},
 		},
 		GenerationConfig: ChatGenerationConfig{
 			Temperature: textRequest.Temperature,
@@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
 	}
+	if textRequest.ResponseFormat != nil {
+		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
+			geminiRequest.GenerationConfig.ResponseMimeType = mimeType
+		}
+		if textRequest.ResponseFormat.JsonSchema != nil {
+			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
+			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
+		}
+	}
 	if textRequest.Tools != nil {
 		functions := make([]model.Function, 0, len(textRequest.Tools))
 		for _, tool := range textRequest.Tools {
@@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 			if candidate.Content.Parts[0].FunctionCall != nil {
 				choice.Message.ToolCalls = getToolCalls(&candidate)
 			} else {
-				choice.Message.Content = candidate.Content.Parts[0].Text
+				var builder strings.Builder
+				for i, part := range candidate.Content.Parts {
+					if i > 0 {
+						builder.WriteString("\n")
+					}
+					builder.WriteString(part.Text)
+				}
+				choice.Message.Content = builder.String()
 			}
 		} else {
 			choice.Message.Content = ""
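The ConvertRequest hunks above map OpenAI's `response_format` onto Gemini's generation config: a plain `json_object`/`text` type sets `responseMimeType`, and an attached JSON schema additionally sets `responseSchema` and forces the JSON MIME type. A standalone sketch of that mapping (struct names are simplified stand-ins for the diff's types):

```go
package main

import "fmt"

var mimeTypeMap = map[string]string{
	"json_object": "application/json",
	"text":        "text/plain",
}

// responseFormat is a simplified stand-in for the relay model's ResponseFormat.
type responseFormat struct {
	Type   string
	Schema any // stand-in for JsonSchema.Schema
}

// generationConfig is a simplified stand-in for ChatGenerationConfig.
type generationConfig struct {
	ResponseMimeType string
	ResponseSchema   any
}

// applyResponseFormat mirrors the new branch in ConvertRequest: map the
// declared type, then let an explicit JSON schema force application/json.
func applyResponseFormat(cfg *generationConfig, rf *responseFormat) {
	if rf == nil {
		return
	}
	if mimeType, ok := mimeTypeMap[rf.Type]; ok {
		cfg.ResponseMimeType = mimeType
	}
	if rf.Schema != nil {
		cfg.ResponseSchema = rf.Schema
		cfg.ResponseMimeType = mimeTypeMap["json_object"]
	}
}

func main() {
	var cfg generationConfig
	applyResponseFormat(&cfg, &responseFormat{Type: "json_object"})
	fmt.Println(cfg.ResponseMimeType) // application/json
}
```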
@@ -65,10 +65,12 @@ type ChatTools struct {
 }
 
 type ChatGenerationConfig struct {
-	Temperature     float64  `json:"temperature,omitempty"`
-	TopP            float64  `json:"topP,omitempty"`
-	TopK            float64  `json:"topK,omitempty"`
-	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
-	CandidateCount  int      `json:"candidateCount,omitempty"`
-	StopSequences   []string `json:"stopSequences,omitempty"`
+	ResponseMimeType string   `json:"responseMimeType,omitempty"`
+	ResponseSchema   any      `json:"responseSchema,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
+	TopP             *float64 `json:"topP,omitempty"`
+	TopK             float64  `json:"topK,omitempty"`
+	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"`
+	CandidateCount   int      `json:"candidateCount,omitempty"`
+	StopSequences    []string `json:"stopSequences,omitempty"`
 }
@@ -4,9 +4,24 @@ package groq
 
 var ModelList = []string{
 	"gemma-7b-it",
-	"llama2-7b-2048",
-	"llama2-70b-4096",
-	"mixtral-8x7b-32768",
-	"llama3-8b-8192",
+	"gemma2-9b-it",
+	"llama-3.1-70b-versatile",
+	"llama-3.1-8b-instant",
+	"llama-3.2-11b-text-preview",
+	"llama-3.2-11b-vision-preview",
+	"llama-3.2-1b-preview",
+	"llama-3.2-3b-preview",
+	"llama-3.2-11b-vision-preview",
+	"llama-3.2-90b-text-preview",
+	"llama-3.2-90b-vision-preview",
+	"llama-guard-3-8b",
 	"llama3-70b-8192",
+	"llama3-8b-8192",
+	"llama3-groq-70b-8192-tool-use-preview",
+	"llama3-groq-8b-8192-tool-use-preview",
+	"llava-v1.5-7b-4096-preview",
+	"mixtral-8x7b-32768",
+	"distil-whisper-large-v3-en",
+	"whisper-large-v3",
+	"whisper-large-v3-turbo",
 }
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	// https://github.com/ollama/ollama/blob/main/docs/api.md
 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
 	if meta.Mode == relaymode.Embeddings {
-		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
+		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
 	}
 	return fullRequestURL, nil
 }
@@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			TopP:             request.TopP,
 			FrequencyPenalty: request.FrequencyPenalty,
 			PresencePenalty:  request.PresencePenalty,
+			NumPredict:       request.MaxTokens,
+			NumCtx:           request.NumCtx,
 		},
 		Stream: request.Stream,
 	}
@@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	common.SetEventStreamHeaders(c)
 
 	for scanner.Scan() {
-		data := strings.TrimPrefix(scanner.Text(), "}")
-		data = data + "}"
+		data := scanner.Text()
+		if strings.HasPrefix(data, "}") {
+			data = strings.TrimPrefix(data, "}") + "}"
+		}
 
 		var ollamaResponse ChatResponse
 		err := json.Unmarshal([]byte(data), &ollamaResponse)
@@ -157,8 +161,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 
 func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 	return &EmbeddingRequest{
 		Model: request.Model,
-		Prompt: strings.Join(request.ParseInput(), " "),
+		Input: request.ParseInput(),
+		Options: &Options{
+			Seed:             int(request.Seed),
+			Temperature:      request.Temperature,
+			TopP:             request.TopP,
+			FrequencyPenalty: request.FrequencyPenalty,
+			PresencePenalty:  request.PresencePenalty,
+		},
 	}
 }
@@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
 	openAIEmbeddingResponse := openai.EmbeddingResponse{
 		Object: "list",
 		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
-		Model:  "text-embedding-v1",
+		Model:  response.Model,
 		Usage:  model.Usage{TotalTokens: 0},
 	}
 
-	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
-		Object:    `embedding`,
-		Index:     0,
-		Embedding: response.Embedding,
-	})
+	for i, embedding := range response.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     i,
+			Embedding: embedding,
+		})
+	}
 	return &openAIEmbeddingResponse
 }
 
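In the ollama StreamHandler hunk above, the old code unconditionally trimmed a leading "}" and re-appended one, corrupting lines that never started with "}"; the new code only relocates the brace when it is actually present. A standalone sketch of the guarded normalization (assuming, as the old code did, that the scanner's split function can leave a stray closing brace at the start of a chunk):

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeChunk moves a stray leading "}" to the end of the chunk,
// and leaves well-formed JSON lines untouched.
func normalizeChunk(line string) string {
	if strings.HasPrefix(line, "}") {
		line = strings.TrimPrefix(line, "}") + "}"
	}
	return line
}

func main() {
	fmt.Println(normalizeChunk(`}{"done":true`))  // {"done":true}
	fmt.Println(normalizeChunk(`{"done":false}`)) // {"done":false}
}
```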
@@ -1,12 +1,14 @@
 package ollama
 
 type Options struct {
 	Seed             int      `json:"seed,omitempty"`
-	Temperature      float64  `json:"temperature,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
 	TopK             int      `json:"top_k,omitempty"`
-	TopP             float64  `json:"top_p,omitempty"`
-	FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
+	TopP             *float64 `json:"top_p,omitempty"`
+	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int      `json:"num_predict,omitempty"`
+	NumCtx           int      `json:"num_ctx,omitempty"`
 }
 
 type Message struct {
@@ -37,11 +39,15 @@ type ChatResponse struct {
 }
 
 type EmbeddingRequest struct {
 	Model string   `json:"model"`
-	Prompt string  `json:"prompt"`
+	Input []string `json:"input"`
+	// Truncate bool `json:"truncate,omitempty"`
+	Options *Options `json:"options,omitempty"`
+	// KeepAlive string `json:"keep_alive,omitempty"`
 }
 
 type EmbeddingResponse struct {
 	Error      string      `json:"error,omitempty"`
-	Embedding  []float64   `json:"embedding,omitempty"`
+	Model      string      `json:"model"`
+	Embeddings [][]float64 `json:"embeddings"`
 }
@@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if request.Stream {
+		// always return usage in stream mode
+		if request.StreamOptions == nil {
+			request.StreamOptions = &model.StreamOptions{}
+		}
+		request.StreamOptions.IncludeUsage = true
+	}
 	return request, nil
 }
 
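The ConvertRequest hunk above force-enables `stream_options.include_usage` so OpenAI-compatible upstreams report token usage on the final stream chunk. A standalone sketch of that normalization (struct names simplified from the diff's `model` package):

```go
package main

import "fmt"

// streamOptions is a simplified stand-in for model.StreamOptions.
type streamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// request is a simplified stand-in for model.GeneralOpenAIRequest.
type request struct {
	Stream        bool
	StreamOptions *streamOptions
}

// ensureUsage makes every streaming request ask for usage accounting,
// creating StreamOptions when the client did not send one.
func ensureUsage(r *request) {
	if r.Stream {
		if r.StreamOptions == nil {
			r.StreamOptions = &streamOptions{}
		}
		r.StreamOptions.IncludeUsage = true
	}
}

func main() {
	r := &request{Stream: true}
	ensureUsage(r)
	fmt.Println(r.StreamOptions.IncludeUsage) // true
}
```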
@@ -11,8 +11,10 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/mistral"
 	"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
 	"github.com/songquanpeng/one-api/relay/adaptor/novita"
+	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
 	"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
 	"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
+	"github.com/songquanpeng/one-api/relay/adaptor/xai"
 	"github.com/songquanpeng/one-api/relay/channeltype"
 )
@@ -30,6 +32,8 @@ var CompatibleChannels = []int{
 	channeltype.DeepSeek,
 	channeltype.TogetherAI,
 	channeltype.Novita,
+	channeltype.SiliconFlow,
+	channeltype.XAI,
 }
 
 func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -60,6 +64,10 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
 		return "doubao", doubao.ModelList
 	case channeltype.Novita:
 		return "novita", novita.ModelList
+	case channeltype.SiliconFlow:
+		return "siliconflow", siliconflow.ModelList
+	case channeltype.XAI:
+		return "xai", xai.ModelList
 	default:
 		return "openai", ModelList
 	}
@@ -8,6 +8,8 @@ var ModelList = []string{
 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
 	"gpt-4o", "gpt-4o-2024-05-13",
+	"gpt-4o-2024-08-06",
+	"chatgpt-4o-latest",
 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
 	"gpt-4-vision-preview",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
@@ -18,4 +20,7 @@ var ModelList = []string{
 	"dall-e-2", "dall-e-3",
 	"whisper-1",
 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+	"o1", "o1-2024-12-17",
+	"o1-preview", "o1-preview-2024-09-12",
+	"o1-mini", "o1-mini-2024-09-12",
 }
@@ -2,15 +2,16 @@ package openai
 
 import (
 	"fmt"
+	"strings"
+
 	"github.com/songquanpeng/one-api/relay/channeltype"
 	"github.com/songquanpeng/one-api/relay/model"
-	"strings"
 )
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
 	usage := &model.Usage{}
 	usage.PromptTokens = promptTokens
-	usage.CompletionTokens = CountTokenText(responseText, modeName)
+	usage.CompletionTokens = CountTokenText(responseText, modelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage
 }
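`ResponseText2Usage` above (the hunk also fixes the `modeName` typo) just derives a usage record from the streamed text and a token counter. A minimal sketch of the same arithmetic, with a hypothetical whitespace-split counter standing in for the real `CountTokenText`:

```go
package main

import (
	"fmt"
	"strings"
)

type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// countTokenText is a hypothetical stand-in for the tokenizer-based counter;
// the real one counts model-specific tokens, not words.
func countTokenText(text, modelName string) int {
	return len(strings.Fields(text))
}

func responseText2Usage(responseText, modelName string, promptTokens int) *usage {
	u := &usage{PromptTokens: promptTokens}
	u.CompletionTokens = countTokenText(responseText, modelName)
	u.TotalTokens = u.PromptTokens + u.CompletionTokens
	return u
}

func main() {
	u := responseText2Usage("hello streaming world", "gpt-4o", 10)
	fmt.Println(u.CompletionTokens, u.TotalTokens) // 3 13
}
```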
@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 				render.StringData(c, data) // if error happened, pass the data to client
 				continue // just ignore the error
 			}
-			if len(streamResponse.Choices) == 0 {
-				// but for empty choice, we should not pass it to client, this is for azure
+			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
+				// but for empty choice and no usage, we should not pass it to client, this is for azure
 				continue // just ignore empty choice
 			}
 			render.StringData(c, data)
@@ -1,8 +1,16 @@
 package openai
 
-import "github.com/songquanpeng/one-api/relay/model"
+import (
+	"context"
+	"fmt"
+
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/model"
+)
 
 func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
+
 	Error := model.Error{
 		Message: err.Error(),
 		Type:    "one_api_error",
@@ -19,11 +19,11 @@ type Prompt struct {
 }
 
 type ChatRequest struct {
 	Prompt         Prompt   `json:"prompt"`
-	Temperature    float64  `json:"temperature,omitempty"`
+	Temperature    *float64 `json:"temperature,omitempty"`
 	CandidateCount int      `json:"candidateCount,omitempty"`
-	TopP           float64  `json:"topP,omitempty"`
+	TopP           *float64 `json:"topP,omitempty"`
 	TopK           int      `json:"topK,omitempty"`
 }
 
 type Error struct {
136
relay/adaptor/replicate/adaptor.go
Normal file
@@ -0,0 +1,136 @@
+package replicate
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+type Adaptor struct {
+	meta *meta.Meta
+}
+
+// ConvertImageRequest implements adaptor.Adaptor.
+func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	return DrawImageRequest{
+		Input: ImageInput{
+			Steps:           25,
+			Prompt:          request.Prompt,
+			Guidance:        3,
+			Seed:            int(time.Now().UnixNano()),
+			SafetyTolerance: 5,
+			NImages:         1, // replicate will always return 1 image
+			Width:           1440,
+			Height:          1440,
+			AspectRatio:     "1:1",
+		},
+	}, nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if !request.Stream {
+		// TODO: support non-stream mode
+		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
+	}
+
+	// Build the prompt from OpenAI messages
+	var promptBuilder strings.Builder
+	for _, message := range request.Messages {
+		switch msgCnt := message.Content.(type) {
+		case string:
+			promptBuilder.WriteString(message.Role)
+			promptBuilder.WriteString(": ")
+			promptBuilder.WriteString(msgCnt)
+			promptBuilder.WriteString("\n")
+		default:
+		}
+	}
+
+	replicateRequest := ReplicateChatRequest{
+		Input: ChatInput{
+			Prompt:           promptBuilder.String(),
+			MaxTokens:        request.MaxTokens,
+			Temperature:      1.0,
+			TopP:             1.0,
+			PresencePenalty:  0.0,
+			FrequencyPenalty: 0.0,
+		},
+	}
+
+	// Map optional fields
+	if request.Temperature != nil {
+		replicateRequest.Input.Temperature = *request.Temperature
+	}
+	if request.TopP != nil {
+		replicateRequest.Input.TopP = *request.TopP
+	}
+	if request.PresencePenalty != nil {
+		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
+	}
+	if request.FrequencyPenalty != nil {
+		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
+	}
+	if request.MaxTokens > 0 {
+		replicateRequest.Input.MaxTokens = request.MaxTokens
+	} else if request.MaxTokens == 0 {
+		replicateRequest.Input.MaxTokens = 500
+	}
+
+	return replicateRequest, nil
+}
+
+func (a *Adaptor) Init(meta *meta.Meta) {
+	a.meta = meta
+}
+
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+	if !slices.Contains(ModelList, meta.OriginModelName) {
+		return "", errors.Errorf("model %s not supported", meta.OriginModelName)
+	}
+
+	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	adaptor.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	return nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	logger.Info(c, "send request to replicate")
+	return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	switch meta.Mode {
+	case relaymode.ImagesGenerations:
+		err, usage = ImageHandler(c, resp)
+	case relaymode.ChatCompletions:
+		err, usage = ChatHandler(c, resp)
+	default:
+		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
+	}
+
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "replicate"
+}
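`ConvertRequest` in the new adaptor above flattens the OpenAI chat history into a single role-prefixed prompt string, since Replicate's prediction input takes one `prompt` field. A standalone sketch of that flattening (message type simplified from the diff's `model.Message`):

```go
package main

import (
	"fmt"
	"strings"
)

// message is a simplified stand-in for model.Message; Content is `any`
// because OpenAI messages may carry strings or structured parts.
type message struct {
	Role    string
	Content any
}

// flattenPrompt joins string-typed message contents as "role: content"
// lines, skipping structured (non-string) contents as the diff does.
func flattenPrompt(messages []message) string {
	var b strings.Builder
	for _, m := range messages {
		if s, ok := m.Content.(string); ok {
			b.WriteString(m.Role)
			b.WriteString(": ")
			b.WriteString(s)
			b.WriteString("\n")
		}
	}
	return b.String()
}

func main() {
	out := flattenPrompt([]message{
		{Role: "system", Content: "be brief"},
		{Role: "user", Content: "hi"},
	})
	fmt.Print(out)
	// system: be brief
	// user: hi
}
```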
191
relay/adaptor/replicate/chat.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/render"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChatHandler(c *gin.Context, resp *http.Response) (
	srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ChatResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ChatResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			if taskData.URLs.Stream == "" {
				return errors.New("stream url is empty")
			}

			// request stream url
			responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
			if err != nil {
				return errors.Wrap(err, "chat stream handler")
			}

			ctxMeta := meta.GetByContext(c)
			usage = openai.ResponseText2Usage(responseText,
				ctxMeta.ActualModelName, ctxMeta.PromptTokens)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, usage
}
const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	done        = "[DONE]"
)
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// request stream endpoint
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// Handle comments starting with ':'
		if strings.HasPrefix(line, ":") {
			continue
		}

		// Parse SSE fields
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Read the following lines to get data and id
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
relay/adaptor/replicate/constant.go  (new file, 58 lines)
@@ -0,0 +1,58 @@
package replicate

// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
	// -------------------------------------
	// image model
	// -------------------------------------
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",
	// -------------------------------------
	// language model
	// -------------------------------------
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",
	// -------------------------------------
	// video model
	// -------------------------------------
	// "minimax/video-01", // TODO: implement the adaptor
}
relay/adaptor/replicate/image.go  (new file, 222 lines)
@@ -0,0 +1,222 @@
package replicate

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)

// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// 	c.Writer.WriteHeader(resp.StatusCode)
// 	for k, v := range resp.Header {
// 		c.Writer.Header().Set(k, v[0])
// 	}
//
// 	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	defer resp.Body.Close()
//
// 	return nil, nil
// }

var errNextLoop = errors.New("next_loop")

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ImageResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ImageResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed: %s", taskData.Status)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			output, err := taskData.GetOutput()
			if err != nil {
				return errors.Wrap(err, "get output")
			}
			if len(output) == 0 {
				return errors.New("response output is empty")
			}

			var mu sync.Mutex
			var pool errgroup.Group
			respBody := &openai.ImageResponse{
				Created: taskData.CompletedAt.Unix(),
				Data:    []openai.ImageData{},
			}

			for _, imgOut := range output {
				imgOut := imgOut
				pool.Go(func() error {
					// download image
					downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
						http.MethodGet, imgOut, nil)
					if err != nil {
						return errors.Wrap(err, "new request")
					}

					imgResp, err := http.DefaultClient.Do(downloadReq)
					if err != nil {
						return errors.Wrap(err, "download image")
					}
					defer imgResp.Body.Close()

					if imgResp.StatusCode != http.StatusOK {
						payload, _ := io.ReadAll(imgResp.Body)
						return errors.Errorf("bad status code [%d]%s",
							imgResp.StatusCode, string(payload))
					}

					imgData, err := io.ReadAll(imgResp.Body)
					if err != nil {
						return errors.Wrap(err, "read image")
					}

					imgData, err = ConvertImageToPNG(imgData)
					if err != nil {
						return errors.Wrap(err, "convert image")
					}

					mu.Lock()
					respBody.Data = append(respBody.Data, openai.ImageData{
						B64Json: fmt.Sprintf("data:image/png;base64,%s",
							base64.StdEncoding.EncodeToString(imgData)),
					})
					mu.Unlock()

					return nil
				})
			}

			if err := pool.Wait(); err != nil {
				if len(respBody.Data) == 0 {
					return errors.WithStack(err)
				}

				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
			}

			c.JSON(http.StatusOK, respBody)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, nil
}

// ConvertImageToPNG converts a WebP image to PNG format
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
	// bypass if it's already a PNG image
	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
		return webpData, nil
	}

	// check if is jpeg, convert to png
	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
		img, _, err := image.Decode(bytes.NewReader(webpData))
		if err != nil {
			return nil, errors.Wrap(err, "decode jpeg")
		}

		var pngBuffer bytes.Buffer
		if err := png.Encode(&pngBuffer, img); err != nil {
			return nil, errors.Wrap(err, "encode png")
		}

		return pngBuffer.Bytes(), nil
	}

	// Decode the WebP image
	img, err := webp.Decode(bytes.NewReader(webpData))
	if err != nil {
		return nil, errors.Wrap(err, "decode webp")
	}

	// Encode the image as PNG
	var pngBuffer bytes.Buffer
	if err := png.Encode(&pngBuffer, img); err != nil {
		return nil, errors.Wrap(err, "encode png")
	}

	return pngBuffer.Bytes(), nil
}
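ConvertImageToPNG dispatches on leading magic bytes: `\x89PNG` passes through untouched, `\xff\xd8\xff` is decoded as JPEG and re-encoded, and anything else goes to the WebP decoder. A standalone sketch of that signature sniffing; the helper name and the explicit `RIFF…WEBP` check are illustrative additions, not part of the adaptor:

```go
package main

import (
	"bytes"
	"fmt"
)

// sniffImageFormat reports the container format by its leading magic
// bytes — the same prefixes ConvertImageToPNG branches on, plus an
// explicit check for the WebP RIFF container.
func sniffImageFormat(data []byte) string {
	switch {
	case bytes.HasPrefix(data, []byte("\x89PNG")):
		return "png"
	case bytes.HasPrefix(data, []byte("\xff\xd8\xff")):
		return "jpeg"
	case bytes.HasPrefix(data, []byte("RIFF")) && len(data) >= 12 &&
		bytes.Equal(data[8:12], []byte("WEBP")):
		return "webp"
	default:
		return "unknown"
	}
}

func main() {
	fmt.Println(sniffImageFormat([]byte("\x89PNG\r\n\x1a\n"))) // png
	fmt.Println(sniffImageFormat([]byte("\xff\xd8\xff\xe0")))  // jpeg
}
```

Sniffing content instead of trusting the URL extension is what lets the handler accept whatever format Replicate happens to serve.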
relay/adaptor/replicate/model.go  (new file, 159 lines)
@@ -0,0 +1,159 @@
package replicate

import (
	"time"

	"github.com/pkg/errors"
)

// DrawImageRequest draw image by fluxpro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
	Input ImageInput `json:"input"`
}

// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"`
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}

// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
	Input FluxInpaintingInput `json:"input"`
}

// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
	Mask             string `json:"mask" binding:"required"`
	Image            string `json:"image" binding:"required"`
	Seed             int    `json:"seed"`
	Steps            int    `json:"steps" binding:"required,min=1"`
	Prompt           string `json:"prompt" binding:"required,min=5"`
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
	OutputFormat     string `json:"output_format"`
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	PromptUnsampling bool   `json:"prompt_unsampling"`
}

// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
	CompletedAt time.Time        `json:"completed_at"`
	CreatedAt   time.Time        `json:"created_at"`
	DataRemoved bool             `json:"data_removed"`
	Error       string           `json:"error"`
	ID          string           `json:"id"`
	Input       DrawImageRequest `json:"input"`
	Logs        string           `json:"logs"`
	Metrics     FluxMetrics      `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    any       `json:"output"`
	StartedAt time.Time `json:"started_at"`
	Status    string    `json:"status"`
	URLs      FluxURLs  `json:"urls"`
	Version   string    `json:"version"`
}

func (r *ImageResponse) GetOutput() ([]string, error) {
	switch v := r.Output.(type) {
	case string:
		return []string{v}, nil
	case []string:
		return v, nil
	case nil:
		return nil, nil
	case []interface{}:
		// convert []interface{} to []string
		ret := make([]string, len(v))
		for idx, vv := range v {
			if vvv, ok := vv.(string); ok {
				ret[idx] = vvv
			} else {
				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
			}
		}

		return ret, nil
	default:
		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
	}
}

// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
	ImageCount  int     `json:"image_count"`
	PredictTime float64 `json:"predict_time"`
	TotalTime   float64 `json:"total_time"`
}

// FluxURLs is urls of ImageResponse
type FluxURLs struct {
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}

type ReplicateChatRequest struct {
	Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"`
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}

// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
	CompletedAt time.Time       `json:"completed_at"`
	CreatedAt   time.Time       `json:"created_at"`
	DataRemoved bool            `json:"data_removed"`
	Error       string          `json:"error"`
	ID          string          `json:"id"`
	Input       ChatInput       `json:"input"`
	Logs        string          `json:"logs"`
	Metrics     FluxMetrics     `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    []string        `json:"output"`
	StartedAt time.Time       `json:"started_at"`
	Status    string          `json:"status"`
	URLs      ChatResponseUrl `json:"urls"`
	Version   string          `json:"version"`
}

// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
relay/adaptor/siliconflow/constants.go  (new file, 36 lines)
@@ -0,0 +1,36 @@
package siliconflow

// https://docs.siliconflow.cn/docs/getting-started

var ModelList = []string{
	"deepseek-ai/deepseek-llm-67b-chat",
	"Qwen/Qwen1.5-14B-Chat",
	"Qwen/Qwen1.5-7B-Chat",
	"Qwen/Qwen1.5-110B-Chat",
	"Qwen/Qwen1.5-32B-Chat",
	"01-ai/Yi-1.5-6B-Chat",
	"01-ai/Yi-1.5-9B-Chat-16K",
	"01-ai/Yi-1.5-34B-Chat-16K",
	"THUDM/chatglm3-6b",
	"deepseek-ai/DeepSeek-V2-Chat",
	"THUDM/glm-4-9b-chat",
	"Qwen/Qwen2-72B-Instruct",
	"Qwen/Qwen2-7B-Instruct",
	"Qwen/Qwen2-57B-A14B-Instruct",
	"deepseek-ai/DeepSeek-Coder-V2-Instruct",
	"Qwen/Qwen2-1.5B-Instruct",
	"internlm/internlm2_5-7b-chat",
	"BAAI/bge-large-en-v1.5",
	"BAAI/bge-large-zh-v1.5",
	"Pro/Qwen/Qwen2-7B-Instruct",
	"Pro/Qwen/Qwen2-1.5B-Instruct",
	"Pro/Qwen/Qwen1.5-7B-Chat",
	"Pro/THUDM/glm-4-9b-chat",
	"Pro/THUDM/chatglm3-6b",
	"Pro/01-ai/Yi-1.5-9B-Chat-16K",
	"Pro/01-ai/Yi-1.5-6B-Chat",
	"Pro/google/gemma-2-9b-it",
	"Pro/internlm/internlm2_5-7b-chat",
	"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
	"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}
@@ -1,7 +1,13 @@
 package stepfun
 
 var ModelList = []string{
+	"step-1-8k",
 	"step-1-32k",
+	"step-1-128k",
+	"step-1-256k",
+	"step-1-flash",
+	"step-2-16k",
+	"step-1v-8k",
 	"step-1v-32k",
-	"step-1-200k",
+	"step-1x-medium",
 }
@@ -5,4 +5,5 @@ var ModelList = []string{
 	"hunyuan-standard",
 	"hunyuan-standard-256K",
 	"hunyuan-pro",
+	"hunyuan-vision",
 }
@@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		Model:       &request.Model,
 		Stream:      &request.Stream,
 		Messages:    messages,
-		TopP:        &request.TopP,
-		Temperature: &request.Temperature,
+		TopP:        request.TopP,
+		Temperature: request.Temperature,
 	}
 }
 
@@ -13,7 +13,12 @@ import (
 )
 
 var ModelList = []string{
-	"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
+	"claude-3-haiku@20240307",
+	"claude-3-sonnet@20240229",
+	"claude-3-opus@20240229",
+	"claude-3-5-sonnet@20240620",
+	"claude-3-5-sonnet-v2@20241022",
+	"claude-3-5-haiku@20241022",
 }
 
 const anthropicVersion = "vertex-2023-10-16"
@@ -11,8 +11,8 @@ type Request struct {
 	MaxTokens     int              `json:"max_tokens,omitempty"`
 	StopSequences []string         `json:"stop_sequences,omitempty"`
 	Stream        bool             `json:"stream,omitempty"`
-	Temperature   float64          `json:"temperature,omitempty"`
-	TopP          float64          `json:"top_p,omitempty"`
+	Temperature   *float64         `json:"temperature,omitempty"`
+	TopP          *float64         `json:"top_p,omitempty"`
 	TopK          int              `json:"top_k,omitempty"`
 	Tools         []anthropic.Tool `json:"tools,omitempty"`
 	ToolChoice    any              `json:"tool_choice,omitempty"`
@@ -15,7 +15,10 @@ import (
 )
 
 var ModelList = []string{
-	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
+	"gemini-pro", "gemini-pro-vision",
+	"gemini-1.5-pro-001", "gemini-1.5-flash-001",
+	"gemini-1.5-pro-002", "gemini-1.5-flash-002",
+	"gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp",
 }
 
 type Adaptor struct {
relay/adaptor/xai/constants.go  (new file, 5 lines)
@@ -0,0 +1,5 @@
package xai

var ModelList = []string{
	"grok-beta",
}
@@ -5,6 +5,8 @@ var ModelList = []string{
 	"SparkDesk-v1.1",
 	"SparkDesk-v2.1",
 	"SparkDesk-v3.1",
+	"SparkDesk-v3.1-128K",
 	"SparkDesk-v3.5",
+	"SparkDesk-v3.5-32K",
 	"SparkDesk-v4.0",
 }
@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
 }
 
 func parseAPIVersionByModelName(modelName string) string {
-	parts := strings.Split(modelName, "-")
-	if len(parts) == 2 {
-		return parts[1]
+	index := strings.IndexAny(modelName, "-")
+	if index != -1 {
+		return modelName[index+1:]
 	}
 	return ""
 }
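The old `strings.Split` version only handled names with exactly one dash, so the new multi-dash models (e.g. `SparkDesk-v3.5-32K` splits into three parts) fell through to `""`; slicing after the first dash keeps the whole version suffix. A standalone sketch of both behaviors, with illustrative function names:

```go
package main

import (
	"fmt"
	"strings"
)

// oldParse mirrors the replaced implementation: it only worked for
// model names containing exactly one dash.
func oldParse(modelName string) string {
	parts := strings.Split(modelName, "-")
	if len(parts) == 2 {
		return parts[1]
	}
	return ""
}

// newParse mirrors the fix: everything after the first dash is the
// API version, however many dashes follow.
func newParse(modelName string) string {
	index := strings.IndexAny(modelName, "-")
	if index != -1 {
		return modelName[index+1:]
	}
	return ""
}

func main() {
	fmt.Println(oldParse("SparkDesk-v3.5-32K")) // "" — three parts, old code gave up
	fmt.Println(newParse("SparkDesk-v3.5-32K")) // "v3.5-32K"
}
```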
@@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string {
 func apiVersion2domain(apiVersion string) string {
 	switch apiVersion {
 	case "v1.1":
-		return "general"
+		return "lite"
 	case "v2.1":
 		return "generalv2"
 	case "v3.1":
 		return "generalv3"
+	case "v3.1-128K":
+		return "pro-128k"
 	case "v3.5":
 		return "generalv3.5"
+	case "v3.5-32K":
+		return "max-32k"
 	case "v4.0":
 		return "4.0Ultra"
 	}
@@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string {
 }
 
 func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
+	var authUrl string
 	domain := apiVersion2domain(apiVersion)
-	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	switch apiVersion {
+	case "v3.1-128K":
+		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
+		break
+	case "v3.5-32K":
+		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
+		break
+	default:
+		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	}
 	return domain, authUrl
 }
@@ -19,11 +19,11 @@ type ChatRequest struct {
 	} `json:"header"`
 	Parameter struct {
 		Chat struct {
 			Domain      string   `json:"domain,omitempty"`
-			Temperature float64  `json:"temperature,omitempty"`
+			Temperature *float64 `json:"temperature,omitempty"`
 			TopK        int      `json:"top_k,omitempty"`
 			MaxTokens   int      `json:"max_tokens,omitempty"`
 			Auditing    bool     `json:"auditing,omitempty"`
 		} `json:"chat"`
 	} `json:"parameter"`
 	Payload struct {
@@ -4,13 +4,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
 		return baiduEmbeddingRequest, err
 	default:
-		// TopP (0.0, 1.0)
-		request.TopP = math.Min(0.99, request.TopP)
-		request.TopP = math.Max(0.01, request.TopP)
+		// TopP [0.0, 1.0]
+		request.TopP = helper.Float64PtrMax(request.TopP, 1)
+		request.TopP = helper.Float64PtrMin(request.TopP, 0)
 
-		// Temperature (0.0, 1.0)
-		request.Temperature = math.Min(0.99, request.Temperature)
-		request.Temperature = math.Max(0.01, request.Temperature)
+		// Temperature [0.0, 1.0]
+		request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
+		request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
 		a.SetVersionByModeName(request.Model)
 		if a.APIVersion == "v4" {
 			return request, nil
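The hunk above replaces `math.Min`/`math.Max` clamping with nil-safe pointer helpers, which matters once `TopP` and `Temperature` become `*float64` (see the struct change below in this compare). A minimal sketch of what such helpers might look like — the actual `helper.Float64PtrMax`/`Float64PtrMin` signatures in one-api may differ:

```go
package main

import "fmt"

// Float64PtrMax caps *p at max; a nil pointer passes through untouched,
// so an unset field stays unset instead of being forced to a value.
func Float64PtrMax(p *float64, max float64) *float64 {
	if p == nil {
		return nil
	}
	if *p > max {
		v := max
		return &v
	}
	return p
}

// Float64PtrMin raises *p to at least min, again passing nil through.
func Float64PtrMin(p *float64, min float64) *float64 {
	if p == nil {
		return nil
	}
	if *p < min {
		v := min
		return &v
	}
	return p
}

func main() {
	v := 1.7
	p := Float64PtrMax(&v, 1) // clamp upper bound
	p = Float64PtrMin(p, 0)   // clamp lower bound
	fmt.Println(*p)           // 1
	fmt.Println(Float64PtrMax(nil, 1) == nil) // true: unset stays unset
}
```

The upstream `math.Min(0.99, …)`/`math.Max(0.01, …)` version could not distinguish "client sent 0" from "client sent nothing"; the pointer form keeps that distinction.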
@@ -12,8 +12,8 @@ type Message struct {
 
 type Request struct {
 	Prompt      []Message `json:"prompt"`
-	Temperature float64   `json:"temperature,omitempty"`
-	TopP        float64   `json:"top_p,omitempty"`
+	Temperature *float64  `json:"temperature,omitempty"`
+	TopP        *float64  `json:"top_p,omitempty"`
 	RequestId   string    `json:"request_id,omitempty"`
 	Incremental bool      `json:"incremental,omitempty"`
 }
@@ -19,6 +19,7 @@ const (
 	DeepL
 	VertexAI
 	Proxy
+	Replicate
 
 	Dummy // this one is only for count, do not add any channel after this
 )
@@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{
 		"720x1280": 1,
 		"1280x720": 1,
 	},
+	"step-1x-medium": {
+		"256x256":   1,
+		"512x512":   1,
+		"768x768":   1,
+		"1024x1024": 1,
+		"1280x800":  1,
+		"800x1280":  1,
+	},
 }
 
 var ImageGenerationAmounts = map[string][2]int{
@@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{
 	"ali-stable-diffusion-v1.5": {1, 4}, // Ali
 	"wanx-v1":                   {1, 4}, // Ali
 	"cogview-3":                 {1, 1},
+	"step-1x-medium":            {1, 1},
 }
 
 var ImagePromptLengthLimitations = map[string]int{
@@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{
 	"ali-stable-diffusion-v1.5": 4000,
 	"wanx-v1":                   4000,
 	"cogview-3":                 833,
+	"step-1x-medium":            4000,
 }
 
 var ImageOriginModelName = map[string]string{
@@ -34,7 +34,9 @@ var ModelRatio = map[string]float64{
 	"gpt-4-turbo":            5,     // $0.01 / 1K tokens
 	"gpt-4-turbo-2024-04-09": 5,     // $0.01 / 1K tokens
 	"gpt-4o":                 2.5,   // $0.005 / 1K tokens
+	"chatgpt-4o-latest":      2.5,   // $0.005 / 1K tokens
 	"gpt-4o-2024-05-13":      2.5,   // $0.005 / 1K tokens
+	"gpt-4o-2024-08-06":      1.25,  // $0.0025 / 1K tokens
 	"gpt-4o-mini":            0.075, // $0.00015 / 1K tokens
 	"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
 	"gpt-4-vision-preview":   5,     // $0.01 / 1K tokens
@@ -46,8 +48,14 @@ var ModelRatio = map[string]float64{
 	"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
 	"gpt-3.5-turbo-1106":     0.5,  // $0.001 / 1K tokens
 	"gpt-3.5-turbo-0125":     0.25, // $0.0005 / 1K tokens
-	"davinci-002":            1,    // $0.002 / 1K tokens
-	"babbage-002":            0.2,  // $0.0004 / 1K tokens
+	"o1":                     7.5,  // $15.00 / 1M input tokens
+	"o1-2024-12-17":          7.5,
+	"o1-preview":             7.5,  // $15.00 / 1M input tokens
+	"o1-preview-2024-09-12":  7.5,
+	"o1-mini":                1.5,  // $3.00 / 1M input tokens
+	"o1-mini-2024-09-12":     1.5,
+	"davinci-002":            1,    // $0.002 / 1K tokens
+	"babbage-002":            0.2,  // $0.0004 / 1K tokens
 	"text-ada-001":           0.2,
 	"text-babbage-001":       0.25,
 	"text-curie-001":         1,
@@ -77,8 +85,10 @@ var ModelRatio = map[string]float64{
 	"claude-2.0":                 8.0 / 1000 * USD,
 	"claude-2.1":                 8.0 / 1000 * USD,
 	"claude-3-haiku-20240307":    0.25 / 1000 * USD,
+	"claude-3-5-haiku-20241022":  1.0 / 1000 * USD,
 	"claude-3-sonnet-20240229":   3.0 / 1000 * USD,
 	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
+	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
 	"claude-3-opus-20240229":     15.0 / 1000 * USD,
 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
 	"ERNIE-4.0-8K": 0.120 * RMB,
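Entries of the form `X / 1000 * USD` read most naturally as "X dollars per 1M tokens converted into one-api's ratio units". A sketch of the arithmetic, assuming the repository's convention that ratio 1 equals $0.002 per 1K tokens (so `USD = 500`); check the ratio constants in the repo for the authoritative values:

```go
package main

import "fmt"

// Assumed conversion constants in the spirit of one-api's billing ratios:
// a ratio of 1 means $0.002 per 1K tokens, hence $1 per 1K tokens is 500.
const (
	USD2RMB = 7.0
	USD     = 500.0
	RMB     = USD / USD2RMB
)

// ratioFromUSDPerMillion converts a dollars-per-1M-token price into a
// ratio, mirroring the "X / 1000 * USD" pattern in the table above.
func ratioFromUSDPerMillion(price float64) float64 {
	return price / 1000 * USD
}

func main() {
	// "claude-2.0": 8.0 / 1000 * USD -> ratio 4, i.e. $0.008 per 1K tokens
	fmt.Println(ratioFromUSDPerMillion(8.0)) // 4
	// "claude-3-opus-20240229": 15.0 / 1000 * USD -> ratio 7.5
	fmt.Println(ratioFromUSDPerMillion(15.0)) // 7.5
}
```

Under the same convention the bare numbers in the table (e.g. `"o1": 7.5`) line up with their comments: 7.5 × $0.002/1K = $0.015/1K = $15 per 1M input tokens.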
@@ -98,12 +108,15 @@ var ModelRatio = map[string]float64{
 	"bge-large-en": 0.002 * RMB,
 	"tao-8k":       0.002 * RMB,
 	// https://ai.google.dev/pricing
-	"PaLM-2":                    1,
-	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
-	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
-	"gemini-1.0-pro-vision-001": 1,
-	"gemini-1.0-pro-001":        1,
-	"gemini-1.5-pro":            1,
+	"gemini-pro":                    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-1.0-pro":                1,
+	"gemini-1.5-pro":                1,
+	"gemini-1.5-pro-001":            1,
+	"gemini-1.5-flash":              1,
+	"gemini-1.5-flash-001":          1,
+	"gemini-2.0-flash-exp":          1,
+	"gemini-2.0-flash-thinking-exp": 1,
+	"aqa":                           1,
 	// https://open.bigmodel.cn/pricing
 	"glm-4":  0.1 * RMB,
 	"glm-4v": 0.1 * RMB,
@@ -115,27 +128,94 @@ var ModelRatio = map[string]float64{
 	"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
 	"cogview-3":    0.25 * RMB,
 	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
-	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens
-	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens
-	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens
-	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens
-	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
-	"ali-stable-diffusion-xl":   8,
-	"ali-stable-diffusion-v1.5": 8,
-	"wanx-v1":                   8,
-	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens
-	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens
-	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
-	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
-	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
-	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
-	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
-	"ChatStd":                   0.01 * RMB,
-	"ChatPro":                   0.1 * RMB,
+	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens
+	"qwen-turbo-latest":           1.4286,
+	"qwen-plus":                   1.4286,
+	"qwen-plus-latest":            1.4286,
+	"qwen-max":                    1.4286,
+	"qwen-max-latest":             1.4286,
+	"qwen-max-longcontext":        1.4286,
+	"qwen-vl-max":                 1.4286,
+	"qwen-vl-max-latest":          1.4286,
+	"qwen-vl-plus":                1.4286,
+	"qwen-vl-plus-latest":         1.4286,
+	"qwen-vl-ocr":                 1.4286,
+	"qwen-vl-ocr-latest":          1.4286,
+	"qwen-audio-turbo":            1.4286,
+	"qwen-math-plus":              1.4286,
+	"qwen-math-plus-latest":       1.4286,
+	"qwen-math-turbo":             1.4286,
+	"qwen-math-turbo-latest":      1.4286,
+	"qwen-coder-plus":             1.4286,
+	"qwen-coder-plus-latest":      1.4286,
+	"qwen-coder-turbo":            1.4286,
+	"qwen-coder-turbo-latest":     1.4286,
+	"qwq-32b-preview":             1.4286,
+	"qwen2.5-72b-instruct":        1.4286,
+	"qwen2.5-32b-instruct":        1.4286,
+	"qwen2.5-14b-instruct":        1.4286,
+	"qwen2.5-7b-instruct":         1.4286,
+	"qwen2.5-3b-instruct":         1.4286,
+	"qwen2.5-1.5b-instruct":       1.4286,
+	"qwen2.5-0.5b-instruct":       1.4286,
+	"qwen2-72b-instruct":          1.4286,
+	"qwen2-57b-a14b-instruct":     1.4286,
+	"qwen2-7b-instruct":           1.4286,
+	"qwen2-1.5b-instruct":         1.4286,
+	"qwen2-0.5b-instruct":         1.4286,
+	"qwen1.5-110b-chat":           1.4286,
+	"qwen1.5-72b-chat":            1.4286,
+	"qwen1.5-32b-chat":            1.4286,
+	"qwen1.5-14b-chat":            1.4286,
+	"qwen1.5-7b-chat":             1.4286,
+	"qwen1.5-1.8b-chat":           1.4286,
+	"qwen1.5-0.5b-chat":           1.4286,
+	"qwen-72b-chat":               1.4286,
+	"qwen-14b-chat":               1.4286,
+	"qwen-7b-chat":                1.4286,
+	"qwen-1.8b-chat":              1.4286,
+	"qwen-1.8b-longcontext-chat":  1.4286,
+	"qwen2-vl-7b-instruct":        1.4286,
+	"qwen2-vl-2b-instruct":        1.4286,
+	"qwen-vl-v1":                  1.4286,
+	"qwen-vl-chat-v1":             1.4286,
+	"qwen2-audio-instruct":        1.4286,
+	"qwen-audio-chat":             1.4286,
+	"qwen2.5-math-72b-instruct":   1.4286,
+	"qwen2.5-math-7b-instruct":    1.4286,
+	"qwen2.5-math-1.5b-instruct":  1.4286,
+	"qwen2-math-72b-instruct":     1.4286,
+	"qwen2-math-7b-instruct":      1.4286,
+	"qwen2-math-1.5b-instruct":    1.4286,
+	"qwen2.5-coder-32b-instruct":  1.4286,
+	"qwen2.5-coder-14b-instruct":  1.4286,
+	"qwen2.5-coder-7b-instruct":   1.4286,
+	"qwen2.5-coder-3b-instruct":   1.4286,
+	"qwen2.5-coder-1.5b-instruct": 1.4286,
+	"qwen2.5-coder-0.5b-instruct": 1.4286,
+	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens
+	"text-embedding-v3":           0.05,
+	"text-embedding-v2":           0.05,
+	"text-embedding-async-v2":     0.05,
+	"text-embedding-async-v1":     0.05,
+	"ali-stable-diffusion-xl":     8.00,
+	"ali-stable-diffusion-v1.5":   8.00,
+	"wanx-v1":                     8.00,
+	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens
+	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens
+	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens
+	"embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens
+	"semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens
+	"hunyuan":                     7.143,  // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+	"ChatStd":                     0.01 * RMB,
+	"ChatPro":                     0.1 * RMB,
 	// https://platform.moonshot.cn/pricing
 	"moonshot-v1-8k":  0.012 * RMB,
 	"moonshot-v1-32k": 0.024 * RMB,
|
|||||||
"mistral-large-latest": 8.0 / 1000 * USD,
|
"mistral-large-latest": 8.0 / 1000 * USD,
|
||||||
"mistral-embed": 0.1 / 1000 * USD,
|
"mistral-embed": 0.1 / 1000 * USD,
|
||||||
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
|
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
|
||||||
"llama3-70b-8192": 0.59 / 1000 * USD,
|
"gemma-7b-it": 0.07 / 1000000 * USD,
|
||||||
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
|
"gemma2-9b-it": 0.20 / 1000000 * USD,
|
||||||
"llama3-8b-8192": 0.05 / 1000 * USD,
|
"llama-3.1-70b-versatile": 0.59 / 1000000 * USD,
|
||||||
"gemma-7b-it": 0.1 / 1000 * USD,
|
"llama-3.1-8b-instant": 0.05 / 1000000 * USD,
|
||||||
"llama2-70b-4096": 0.64 / 1000 * USD,
|
"llama-3.2-11b-text-preview": 0.05 / 1000000 * USD,
|
||||||
"llama2-7b-2048": 0.1 / 1000 * USD,
|
"llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD,
|
||||||
|
"llama-3.2-1b-preview": 0.05 / 1000000 * USD,
|
||||||
|
"llama-3.2-3b-preview": 0.05 / 1000000 * USD,
|
||||||
|
"llama-3.2-90b-text-preview": 0.59 / 1000000 * USD,
|
||||||
|
"llama-guard-3-8b": 0.05 / 1000000 * USD,
|
||||||
|
"llama3-70b-8192": 0.59 / 1000000 * USD,
|
||||||
|
"llama3-8b-8192": 0.05 / 1000000 * USD,
|
||||||
|
"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
|
||||||
|
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD,
|
||||||
|
"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
|
||||||
|
|
||||||
// https://platform.lingyiwanwu.com/docs#-计费单元
|
// https://platform.lingyiwanwu.com/docs#-计费单元
|
||||||
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
|
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
|
||||||
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
|
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
|
||||||
"yi-vl-plus": 6.0 / 1000 * RMB,
|
"yi-vl-plus": 6.0 / 1000 * RMB,
|
||||||
// stepfun todo
|
// https://platform.stepfun.com/docs/pricing/details
|
||||||
"step-1v-32k": 0.024 * RMB,
|
"step-1-8k": 0.005 / 1000 * RMB,
|
||||||
"step-1-32k": 0.024 * RMB,
|
"step-1-32k": 0.015 / 1000 * RMB,
|
||||||
"step-1-200k": 0.15 * RMB,
|
"step-1-128k": 0.040 / 1000 * RMB,
|
||||||
|
"step-1-256k": 0.095 / 1000 * RMB,
|
||||||
|
"step-1-flash": 0.001 / 1000 * RMB,
|
||||||
|
"step-2-16k": 0.038 / 1000 * RMB,
|
||||||
|
"step-1v-8k": 0.005 / 1000 * RMB,
|
||||||
|
"step-1v-32k": 0.015 / 1000 * RMB,
|
||||||
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
|
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
|
||||||
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
|
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
|
||||||
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
|
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
|
||||||
@@ -189,6 +284,52 @@ var ModelRatio = map[string]float64{
 	"deepl-zh": 25.0 / 1000 * USD,
 	"deepl-en": 25.0 / 1000 * USD,
 	"deepl-ja": 25.0 / 1000 * USD,
+	// https://console.x.ai/
+	"grok-beta": 5.0 / 1000 * USD,
+	// replicate charges based on the number of generated images
+	// https://replicate.com/pricing
+	"black-forest-labs/flux-1.1-pro":                 0.04 * USD,
+	"black-forest-labs/flux-1.1-pro-ultra":           0.06 * USD,
+	"black-forest-labs/flux-canny-dev":               0.025 * USD,
+	"black-forest-labs/flux-canny-pro":               0.05 * USD,
+	"black-forest-labs/flux-depth-dev":               0.025 * USD,
+	"black-forest-labs/flux-depth-pro":               0.05 * USD,
+	"black-forest-labs/flux-dev":                     0.025 * USD,
+	"black-forest-labs/flux-dev-lora":                0.032 * USD,
+	"black-forest-labs/flux-fill-dev":                0.04 * USD,
+	"black-forest-labs/flux-fill-pro":                0.05 * USD,
+	"black-forest-labs/flux-pro":                     0.055 * USD,
+	"black-forest-labs/flux-redux-dev":               0.025 * USD,
+	"black-forest-labs/flux-redux-schnell":           0.003 * USD,
+	"black-forest-labs/flux-schnell":                 0.003 * USD,
+	"black-forest-labs/flux-schnell-lora":            0.02 * USD,
+	"ideogram-ai/ideogram-v2":                        0.08 * USD,
+	"ideogram-ai/ideogram-v2-turbo":                  0.05 * USD,
+	"recraft-ai/recraft-v3":                          0.04 * USD,
+	"recraft-ai/recraft-v3-svg":                      0.08 * USD,
+	"stability-ai/stable-diffusion-3":                0.035 * USD,
+	"stability-ai/stable-diffusion-3.5-large":        0.065 * USD,
+	"stability-ai/stable-diffusion-3.5-large-turbo":  0.04 * USD,
+	"stability-ai/stable-diffusion-3.5-medium":       0.035 * USD,
+	// replicate chat models
+	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
+	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
+	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
+	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
+	"meta/llama-2-13b":                          0.100 * USD,
+	"meta/llama-2-13b-chat":                     0.100 * USD,
+	"meta/llama-2-70b":                          0.650 * USD,
+	"meta/llama-2-70b-chat":                     0.650 * USD,
+	"meta/llama-2-7b":                           0.050 * USD,
+	"meta/llama-2-7b-chat":                      0.050 * USD,
+	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
+	"meta/meta-llama-3-70b":                     0.650 * USD,
+	"meta/meta-llama-3-70b-instruct":            0.650 * USD,
+	"meta/meta-llama-3-8b":                      0.050 * USD,
+	"meta/meta-llama-3-8b-instruct":             0.050 * USD,
+	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
+	"mistralai/mistral-7b-v0.1":                 0.050 * USD,
+	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
 }
 
 var CompletionRatio = map[string]float64{
@@ -197,8 +338,10 @@ var CompletionRatio = map[string]float64{
 	"llama3-70b-8192(33)": 0.0035 / 0.00265,
 }
 
-var DefaultModelRatio map[string]float64
-var DefaultCompletionRatio map[string]float64
+var (
+	DefaultModelRatio      map[string]float64
+	DefaultCompletionRatio map[string]float64
+)
 
 func init() {
 	DefaultModelRatio = make(map[string]float64)
@@ -310,7 +453,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		return 4.0 / 3.0
 	}
 	if strings.HasPrefix(name, "gpt-4") {
-		if strings.HasPrefix(name, "gpt-4o-mini") {
+		if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
 			return 4
 		}
 		if strings.HasPrefix(name, "gpt-4-turbo") ||
@@ -320,6 +463,13 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		}
 		return 2
 	}
+	// including o1, o1-preview, o1-mini
+	if strings.HasPrefix(name, "o1") {
+		return 4
+	}
+	if name == "chatgpt-4o-latest" {
+		return 3
+	}
 	if strings.HasPrefix(name, "claude-3") {
 		return 5
 	}
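The completion ratio multiplies the input ratio to price output tokens. The new `o1` branch can be cross-checked against the ModelRatio entries added earlier in this compare, assuming ratio 1 equals $0.002 per 1K tokens (i.e. $2 per 1M; an assumption about one-api's base unit, not stated in this hunk):

```go
package main

import "fmt"

// Assumed base unit: model ratio 1 == $2 per 1M tokens.
const dollarsPer1MAtRatio1 = 2.0

func main() {
	modelRatio := 7.5      // "o1": 7.5 in ModelRatio ($15.00 / 1M input tokens)
	completionRatio := 4.0 // GetCompletionRatio returns 4 for the "o1" prefix
	input := modelRatio * dollarsPer1MAtRatio1
	output := modelRatio * completionRatio * dollarsPer1MAtRatio1
	fmt.Println(input, output) // 15 60
}
```

So a completion ratio of 4 on top of the $15/1M input ratio yields $60 per 1M output tokens for `o1`, consistent with the comment on the ratio entry.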
@@ -335,6 +485,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 	if strings.HasPrefix(name, "deepseek-") {
 		return 2
 	}
+
 	switch name {
 	case "llama2-70b-4096":
 		return 0.8 / 0.64
@@ -348,6 +499,37 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		return 3
 	case "command-r-plus":
 		return 5
+	case "grok-beta":
+		return 3
+	// Replicate Models
+	// https://replicate.com/pricing
+	case "ibm-granite/granite-20b-code-instruct-8k":
+		return 5
+	case "ibm-granite/granite-3.0-2b-instruct":
+		return 8.333333333333334
+	case "ibm-granite/granite-3.0-8b-instruct",
+		"ibm-granite/granite-8b-code-instruct-128k":
+		return 5
+	case "meta/llama-2-13b",
+		"meta/llama-2-13b-chat",
+		"meta/llama-2-7b",
+		"meta/llama-2-7b-chat",
+		"meta/meta-llama-3-8b",
+		"meta/meta-llama-3-8b-instruct":
+		return 5
+	case "meta/llama-2-70b",
+		"meta/llama-2-70b-chat",
+		"meta/meta-llama-3-70b",
+		"meta/meta-llama-3-70b-instruct":
+		return 2.750 / 0.650 // ≈4.230769
+	case "meta/meta-llama-3.1-405b-instruct":
+		return 1
+	case "mistralai/mistral-7b-instruct-v0.2",
+		"mistralai/mistral-7b-v0.1":
+		return 5
+	case "mistralai/mixtral-8x7b-instruct-v0.1":
+		return 1.000 / 0.300 // ≈3.333333
 	}
 
 	return 1
 }
@@ -45,5 +45,8 @@ const (
 	Novita
 	VertextAI
 	Proxy
+	SiliconFlow
+	XAI
+	Replicate
 	Dummy
 )
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
 		apiType = apitype.DeepL
 	case VertextAI:
 		apiType = apitype.VertexAI
+	case Replicate:
+		apiType = apitype.Replicate
 	case Proxy:
 		apiType = apitype.Proxy
 	}
@@ -45,6 +45,9 @@ var ChannelBaseURLs = []string{
 	"https://api.novita.ai/v3/openai", // 41
 	"",                                // 42
 	"",                                // 43
+	"https://api.siliconflow.cn",           // 44
+	"https://api.x.ai",                     // 45
+	"https://api.replicate.com/v1/models/", // 46
 }
 
 func init() {
@@ -1,5 +1,6 @@
 package role
 
 const (
+	System    = "system"
 	Assistant = "assistant"
 )
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"github.com/songquanpeng/one-api/relay/constant/role"
 	"math"
 	"net/http"
 	"strings"
@@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
 	return preConsumedQuota, nil
 }
 
-func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
 	if usage == nil {
 		logger.Error(ctx, "usage is nil, which is unexpected")
 		return
@@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 	if err != nil {
 		logger.Error(ctx, "error update user quota cache: "+err.Error())
 	}
-	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
+	var extraLog string
+	if systemPromptReset {
+		extraLog = " (注意系统提示词已被重置)"
+	}
+	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog)
 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
 	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
@@ -142,15 +147,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
 		}
 		return true
 	}
-	if resp.StatusCode != http.StatusOK {
+	if resp.StatusCode != http.StatusOK &&
+		// replicate return 201 to create a task
+		resp.StatusCode != http.StatusCreated {
 		return true
 	}
 	if meta.ChannelType == channeltype.DeepL {
 		// skip stream check for deepl
 		return false
 	}
-	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+
+	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
+		// Even if stream mode is enabled, replicate will first return a task info in JSON format,
+		// requiring the client to request the stream endpoint in the task info
+		meta.ChannelType != channeltype.Replicate {
 		return true
 	}
 	return false
 }
 
+func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
+	if prompt == "" {
+		return false
+	}
+	if len(request.Messages) == 0 {
+		return false
+	}
+	if request.Messages[0].Role == role.System {
+		request.Messages[0].Content = prompt
+		logger.Infof(ctx, "rewrite system prompt")
+		return true
+	}
+	request.Messages = append([]relaymodel.Message{{
+		Role:    role.System,
+		Content: prompt,
+	}}, request.Messages...)
+	logger.Infof(ctx, "add system prompt")
+	return true
+}
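The new `setSystemPrompt` overwrites the first message's content when the conversation already starts with a system message, and otherwise prepends one; empty prompts and empty conversations are left untouched. A standalone sketch of the same logic with the relay types stubbed out (here the slice is returned instead of mutated through a request pointer):

```go
package main

import "fmt"

// Message is a stand-in for relaymodel.Message.
type Message struct {
	Role    string
	Content any
}

const roleSystem = "system" // stand-in for role.System

// setSystemPrompt mirrors the branch structure of the diff above:
// rewrite an existing system message, or prepend a new one.
func setSystemPrompt(messages []Message, prompt string) ([]Message, bool) {
	if prompt == "" || len(messages) == 0 {
		return messages, false
	}
	if messages[0].Role == roleSystem {
		messages[0].Content = prompt
		return messages, true
	}
	return append([]Message{{Role: roleSystem, Content: prompt}}, messages...), true
}

func main() {
	msgs := []Message{{Role: "user", Content: "hi"}}
	msgs, reset := setSystemPrompt(msgs, "be terse")
	fmt.Println(len(msgs), reset) // 2 true
}
```

The returned `reset` flag is what flows into `postConsumeQuota` as `systemPromptReset`, so the billing log can note that the channel's system prompt overrode the user's.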
@@ -22,7 +22,7 @@ import (
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 )
 
-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
 	imageRequest := &relaymodel.ImageRequest{}
 	err := common.UnmarshalBodyReusable(c, imageRequest)
 	if err != nil {
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
 	return 1
 }
 
-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
 	// check prompt length
 	if imageRequest.Prompt == "" {
 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}
 	adaptor.Init(meta)
 
+	// these adaptors need to convert the request
 	switch meta.ChannelType {
-	case channeltype.Ali:
-		fallthrough
-	case channeltype.Baidu:
-		fallthrough
-	case channeltype.Zhipu:
+	case channeltype.Zhipu,
+		channeltype.Ali,
+		channeltype.Replicate,
+		channeltype.Baidu:
 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
 		if err != nil {
 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	ratio := modelRatio * groupRatio
 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
 
-	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+	var quota int64
+	switch meta.ChannelType {
+	case channeltype.Replicate:
+		// replicate always return 1 image
+		quota = int64(ratio * imageCostRatio * 1000)
+	default:
+		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+	}
 
 	if userQuota-quota < 0 {
 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
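The quota branch above stops billing Replicate by the requested image count `N`, since Replicate returns a single image per prediction regardless of `N`. A self-contained sketch of that calculation, with the channel constant stubbed in (the real code switches on `channeltype.Replicate`):

```go
package main

import "fmt"

// channelReplicate is a stand-in for channeltype.Replicate.
const channelReplicate = 46

// imageQuota mirrors the switch above: most channels bill per requested
// image, while Replicate always produces one image, so N is ignored.
func imageQuota(channelType int, ratio, imageCostRatio float64, n int) int64 {
	switch channelType {
	case channelReplicate:
		// replicate always returns 1 image
		return int64(ratio * imageCostRatio * 1000)
	default:
		return int64(ratio*imageCostRatio*1000) * int64(n)
	}
}

func main() {
	fmt.Println(imageQuota(0, 2, 1, 4))                // 8000: 4 images billed
	fmt.Println(imageQuota(channelReplicate, 2, 1, 4)) // 2000: N ignored
}
```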
@@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}
 
 	defer func(ctx context.Context) {
-		if resp != nil && resp.StatusCode != http.StatusOK {
+		if resp != nil &&
+			resp.StatusCode != http.StatusCreated && // replicate returns 201
+			resp.StatusCode != http.StatusOK {
 			return
 		}
 
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"github.com/songquanpeng/one-api/common/config"
 	"io"
 	"net/http"
 
@@ -35,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 	meta.OriginModelName = textRequest.Model
 	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
 	meta.ActualModelName = textRequest.Model
+	// set system prompt if not empty
+	systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
 	// get model ratio & group ratio
 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
 	groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -79,12 +82,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
return respErr
|
return respErr
|
||||||
}
|
}
|
||||||
// post-consume quota
|
// post-consume quota
|
||||||
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
|
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
||||||
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
|
if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
|
||||||
// no need to convert request for openai
|
// no need to convert request for openai
|
||||||
return c.Request.Body, nil
|
return c.Request.Body, nil
|
||||||
}
|
}
|
||||||
|
|||||||
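The `setSystemPrompt` helper referenced in the hunk above is not included in this compare view. As a rough sketch of the intended behavior (types simplified, semantics assumed: a non-empty channel-level prompt either overrides an existing leading system message or is prepended, and the returned flag reports whether an override happened):

```go
package main

import "fmt"

type Message struct {
	Role    string
	Content string
}

// setSystemPrompt is an illustrative sketch, not the actual one-api code:
// if prompt is non-empty, it replaces an existing leading system message or
// prepends a new one, and reports whether an existing message was overridden.
func setSystemPrompt(messages []Message, prompt string) ([]Message, bool) {
	if prompt == "" {
		return messages, false
	}
	if len(messages) > 0 && messages[0].Role == "system" {
		messages[0].Content = prompt
		return messages, true
	}
	return append([]Message{{Role: "system", Content: prompt}}, messages...), false
}

func main() {
	msgs := []Message{{Role: "user", Content: "hi"}}
	msgs, reset := setSystemPrompt(msgs, "You are a helpful assistant.")
	fmt.Println(len(msgs), msgs[0].Role, reset) // prints: 2 system false
}
```

The reset flag matters for billing transparency: passing `systemPromptReset` into `postConsumeQuota` lets the log record that the user's own system message was replaced by the channel's forced prompt.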
@@ -30,6 +30,7 @@ type Meta struct {
     ActualModelName string
     RequestURLPath  string
     PromptTokens    int // only for DoResponse
+    SystemPrompt    string
 }

 func GetByContext(c *gin.Context) *Meta {
@@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta {
         BaseURL:        c.GetString(ctxkey.BaseURL),
         APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
         RequestURLPath: c.Request.URL.String(),
+        SystemPrompt:   c.GetString(ctxkey.SystemPrompt),
     }
     cfg, ok := c.Get(ctxkey.Config)
     if ok {
@@ -1,6 +1,7 @@
 package model

 const (
     ContentTypeText     = "text"
     ContentTypeImageURL = "image_url"
+    ContentTypeInputAudio = "input_audio"
 )
@@ -1,34 +1,70 @@
 package model

 type ResponseFormat struct {
     Type string `json:"type,omitempty"`
+    JsonSchema *JSONSchema `json:"json_schema,omitempty"`
+}
+
+type JSONSchema struct {
+    Description string                 `json:"description,omitempty"`
+    Name        string                 `json:"name"`
+    Schema      map[string]interface{} `json:"schema,omitempty"`
+    Strict      *bool                  `json:"strict,omitempty"`
+}
+
+type Audio struct {
+    Voice  string `json:"voice,omitempty"`
+    Format string `json:"format,omitempty"`
+}
+
+type StreamOptions struct {
+    IncludeUsage bool `json:"include_usage,omitempty"`
 }

 type GeneralOpenAIRequest struct {
-    Messages         []Message       `json:"messages,omitempty"`
-    Model            string          `json:"model,omitempty"`
-    FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
-    MaxTokens        int             `json:"max_tokens,omitempty"`
-    N                int             `json:"n,omitempty"`
-    PresencePenalty  float64         `json:"presence_penalty,omitempty"`
-    ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
-    Seed             float64         `json:"seed,omitempty"`
-    Stop             any             `json:"stop,omitempty"`
-    Stream           bool            `json:"stream,omitempty"`
-    Temperature      float64         `json:"temperature,omitempty"`
-    TopP             float64         `json:"top_p,omitempty"`
-    TopK             int             `json:"top_k,omitempty"`
-    Tools            []Tool          `json:"tools,omitempty"`
-    ToolChoice       any             `json:"tool_choice,omitempty"`
-    FunctionCall     any             `json:"function_call,omitempty"`
-    Functions        any             `json:"functions,omitempty"`
-    User             string          `json:"user,omitempty"`
-    Prompt           any             `json:"prompt,omitempty"`
-    Input            any             `json:"input,omitempty"`
-    EncodingFormat   string          `json:"encoding_format,omitempty"`
-    Dimensions       int             `json:"dimensions,omitempty"`
-    Instruction      string          `json:"instruction,omitempty"`
-    Size             string          `json:"size,omitempty"`
+    // https://platform.openai.com/docs/api-reference/chat/create
+    Messages            []Message       `json:"messages,omitempty"`
+    Model               string          `json:"model,omitempty"`
+    Store               *bool           `json:"store,omitempty"`
+    Metadata            any             `json:"metadata,omitempty"`
+    FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
+    LogitBias           any             `json:"logit_bias,omitempty"`
+    Logprobs            *bool           `json:"logprobs,omitempty"`
+    TopLogprobs         *int            `json:"top_logprobs,omitempty"`
+    MaxTokens           int             `json:"max_tokens,omitempty"`
+    MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
+    N                   int             `json:"n,omitempty"`
+    Modalities          []string        `json:"modalities,omitempty"`
+    Prediction          any             `json:"prediction,omitempty"`
+    Audio               *Audio          `json:"audio,omitempty"`
+    PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
+    ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
+    Seed                float64         `json:"seed,omitempty"`
+    ServiceTier         *string         `json:"service_tier,omitempty"`
+    Stop                any             `json:"stop,omitempty"`
+    Stream              bool            `json:"stream,omitempty"`
+    StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
+    Temperature         *float64        `json:"temperature,omitempty"`
+    TopP                *float64        `json:"top_p,omitempty"`
+    TopK                int             `json:"top_k,omitempty"`
+    Tools               []Tool          `json:"tools,omitempty"`
+    ToolChoice          any             `json:"tool_choice,omitempty"`
+    ParallelTooCalls    *bool           `json:"parallel_tool_calls,omitempty"`
+    User                string          `json:"user,omitempty"`
+    FunctionCall        any             `json:"function_call,omitempty"`
+    Functions           any             `json:"functions,omitempty"`
+    // https://platform.openai.com/docs/api-reference/embeddings/create
+    Input          any    `json:"input,omitempty"`
+    EncodingFormat string `json:"encoding_format,omitempty"`
+    Dimensions     int    `json:"dimensions,omitempty"`
+    // https://platform.openai.com/docs/api-reference/images/create
+    Prompt  any     `json:"prompt,omitempty"`
+    Quality *string `json:"quality,omitempty"`
+    Size    string  `json:"size,omitempty"`
+    Style   *string `json:"style,omitempty"`
+    // Others
+    Instruction string `json:"instruction,omitempty"`
+    NumCtx      int    `json:"num_ctx,omitempty"`
 }

 func (r GeneralOpenAIRequest) ParseInput() []string {
@@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
     apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
     apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
     apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
+    apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
     apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
     apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
     apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
@@ -9,6 +9,7 @@ import (

 func SetRelayRouter(router *gin.Engine) {
     router.Use(middleware.CORS())
+    router.Use(middleware.GzipDecodeMiddleware())
     // https://platform.openai.com/docs/api-reference/introduction
     modelsRouter := router.Group("/v1/models")
     modelsRouter.Use(middleware.TokenAuth())
@@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken';
 const COPY_OPTIONS = [
   { key: 'next', text: 'ChatGPT Next Web', value: 'next' },
   { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
-  { key: 'opencat', text: 'OpenCat', value: 'opencat' }
+  { key: 'opencat', text: 'OpenCat', value: 'opencat' },
+  { key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
 ];

 const OPEN_LINK_OPTIONS = [
   { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
-  { key: 'opencat', text: 'OpenCat', value: 'opencat' }
+  { key: 'opencat', text: 'OpenCat', value: 'opencat' },
+  { key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
 ];

 function renderTimestamp(timestamp) {
@@ -60,7 +62,12 @@ const TokensTable = () => {
        onOpenLink('next-mj');
      }
    },
-   { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }
+   { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
+   {
+     node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
+       onOpenLink('lobechat');
+     }
+   }
  ];

  const columns = [
@@ -177,6 +184,11 @@ const TokensTable = () => {
              node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
                onOpenLink('opencat', record.key);
              }
+           },
+           {
+             node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
+               onOpenLink('lobechat');
+             }
            }
          ]
        }
@@ -382,6 +394,9 @@ const TokensTable = () => {
      case 'next-mj':
        url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
        break;
+     case 'lobechat':
+       url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
+       break;
      default:
        if (!chatLink) {
          showError('管理员未设置聊天链接');
@@ -29,6 +29,9 @@ export const CHANNEL_OPTIONS = [
   { key: 39, text: 'together.ai', value: 39, color: 'blue' },
   { key: 42, text: 'VertexAI', value: 42, color: 'blue' },
   { key: 43, text: 'Proxy', value: 43, color: 'blue' },
+  { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
+  { key: 45, text: 'xAI', value: 45, color: 'blue' },
+  { key: 46, text: 'Replicate', value: 46, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
@@ -43,6 +43,7 @@ const EditChannel = (props) => {
     base_url: '',
     other: '',
     model_mapping: '',
+    system_prompt: '',
     models: [],
     auto_ban: 1,
     groups: ['default']
@@ -63,7 +64,7 @@ const EditChannel = (props) => {
     let localModels = [];
     switch (value) {
       case 14:
-        localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
+        localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"];
         break;
       case 11:
         localModels = ['PaLM-2'];
@@ -78,7 +79,7 @@ const EditChannel = (props) => {
         localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
         break;
       case 18:
-        localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
+        localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0'];
         break;
       case 19:
         localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -304,163 +305,163 @@ const EditChannel = (props) => {
       width={isMobile() ? '100%' : 600}
     >
       <Spin spinning={loading}>
-        <div style={{marginTop: 10}}>
+        <div style={{ marginTop: 10 }}>
           <Typography.Text strong>类型:</Typography.Text>
         </div>
         <Select
           name='type'
           required
           optionList={CHANNEL_OPTIONS}
           value={inputs.type}
           onChange={value => handleInputChange('type', value)}
-          style={{width: '50%'}}
+          style={{ width: '50%' }}
         />
         {
           inputs.type === 3 && (
             <>
-              <div style={{marginTop: 10}}>
+              <div style={{ marginTop: 10 }}>
                 <Banner type={"warning"} description={
                   <>
                     注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的
                     model
                     参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
                       href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
                   </>
                 }>
                 </Banner>
               </div>
-              <div style={{marginTop: 10}}>
+              <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text>
               </div>
               <Input
                 label='AZURE_OPENAI_ENDPOINT'
                 name='azure_base_url'
                 placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'}
                 onChange={value => {
                   handleInputChange('base_url', value)
                 }}
                 value={inputs.base_url}
                 autoComplete='new-password'
               />
-              <div style={{marginTop: 10}}>
+              <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>默认 API 版本:</Typography.Text>
               </div>
               <Input
                 label='默认 API 版本'
                 name='azure_other'
                 placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'}
                 onChange={value => {
                   handleInputChange('other', value)
                 }}
                 value={inputs.other}
                 autoComplete='new-password'
               />
             </>
           )
         }
         {
           inputs.type === 8 && (
             <>
-              <div style={{marginTop: 10}}>
+              <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>Base URL:</Typography.Text>
               </div>
               <Input
                 name='base_url'
                 placeholder={'请输入自定义渠道的 Base URL'}
                 onChange={value => {
                   handleInputChange('base_url', value)
                 }}
                 value={inputs.base_url}
                 autoComplete='new-password'
               />
             </>
           )
         }
-        <div style={{marginTop: 10}}>
+        <div style={{ marginTop: 10 }}>
           <Typography.Text strong>名称:</Typography.Text>
         </div>
         <Input
           required
           name='name'
           placeholder={'请为渠道命名'}
           onChange={value => {
             handleInputChange('name', value)
           }}
           value={inputs.name}
           autoComplete='new-password'
         />
-        <div style={{marginTop: 10}}>
+        <div style={{ marginTop: 10 }}>
           <Typography.Text strong>分组:</Typography.Text>
         </div>
         <Select
           placeholder={'请选择可以使用该渠道的分组'}
           name='groups'
           required
           multiple
           selection
           allowAdditions
           additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
           onChange={value => {
             handleInputChange('groups', value)
           }}
           value={inputs.groups}
           autoComplete='new-password'
           optionList={groupOptions}
         />
         {
           inputs.type === 18 && (
             <>
-              <div style={{marginTop: 10}}>
+              <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>模型版本:</Typography.Text>
               </div>
               <Input
                 name='other'
                 placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
                 onChange={value => {
                   handleInputChange('other', value)
                 }}
                 value={inputs.other}
                 autoComplete='new-password'
               />
             </>
           )
         }
         {
           inputs.type === 21 && (
             <>
-              <div style={{marginTop: 10}}>
+              <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>知识库 ID:</Typography.Text>
               </div>
               <Input
                 label='知识库 ID'
                 name='other'
                 placeholder={'请输入知识库 ID,例如:123456'}
                 onChange={value => {
                   handleInputChange('other', value)
                 }}
                 value={inputs.other}
                 autoComplete='new-password'
               />
             </>
           )
         }
-        <div style={{marginTop: 10}}>
+        <div style={{ marginTop: 10 }}>
           <Typography.Text strong>模型:</Typography.Text>
         </div>
         <Select
           placeholder={'请选择该渠道所支持的模型'}
           name='models'
           required
           multiple
           selection
           onChange={value => {
             handleInputChange('models', value)
           }}
           value={inputs.models}
           autoComplete='new-password'
           optionList={modelOptions}
         />
-        <div style={{lineHeight: '40px', marginBottom: '12px'}}>
+        <div style={{ lineHeight: '40px', marginBottom: '12px' }}>
           <Space>
             <Button type='primary' onClick={() => {
               handleInputChange('models', basicModels);
@@ -473,28 +474,41 @@ const EditChannel = (props) => {
           }}>清除所有模型</Button>
         </Space>
         <Input
           addonAfter={
             <Button type='primary' onClick={addCustomModel}>填入</Button>
           }
           placeholder='输入自定义模型名称'
           value={customModel}
           onChange={(value) => {
             setCustomModel(value.trim());
           }}
         />
       </div>
-        <div style={{marginTop: 10}}>
+        <div style={{ marginTop: 10 }}>
         <Typography.Text strong>模型重定向:</Typography.Text>
       </div>
       <TextArea
         placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
         name='model_mapping'
         onChange={value => {
           handleInputChange('model_mapping', value)
         }}
         autosize
         value={inputs.model_mapping}
         autoComplete='new-password'
+      />
+      <div style={{ marginTop: 10 }}>
+        <Typography.Text strong>系统提示词:</Typography.Text>
+      </div>
+      <TextArea
+        placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
+        name='system_prompt'
+        onChange={value => {
+          handleInputChange('system_prompt', value)
+        }}
+        autosize
+        value={inputs.system_prompt}
+        autoComplete='new-password'
       />
       <Typography.Text style={{
         color: 'rgba(var(--semi-blue-5), 1)',
@@ -507,116 +521,116 @@ const EditChannel = (props) => {
       }>
         填入模板
       </Typography.Text>
-      <div style={{marginTop: 10}}>
+      <div style={{ marginTop: 10 }}>
         <Typography.Text strong>密钥:</Typography.Text>
       </div>
       {
         batch ?
           <TextArea
             label='密钥'
             name='key'
             required
             placeholder={'请输入密钥,一行一个'}
             onChange={value => {
               handleInputChange('key', value)
             }}
             value={inputs.key}
-            style={{minHeight: 150, fontFamily: 'JetBrains Mono, Consolas'}}
+            style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
             autoComplete='new-password'
           />
           :
           <Input
             label='密钥'
             name='key'
             required
             placeholder={type2secretPrompt(inputs.type)}
             onChange={value => {
               handleInputChange('key', value)
             }}
             value={inputs.key}
             autoComplete='new-password'
           />
       }
-      <div style={{marginTop: 10}}>
+      <div style={{ marginTop: 10 }}>
         <Typography.Text strong>组织:</Typography.Text>
       </div>
       <Input
         label='组织,可选,不填则为默认组织'
         name='openai_organization'
         placeholder='请输入组织org-xxx'
         onChange={value => {
           handleInputChange('openai_organization', value)
         }}
         value={inputs.openai_organization}
       />
-      <div style={{marginTop: 10, display: 'flex'}}>
+      <div style={{ marginTop: 10, display: 'flex' }}>
         <Space>
           <Checkbox
             name='auto_ban'
             checked={autoBan}
             onChange={
               () => {
                 setAutoBan(!autoBan);
               }
             }
             // onChange={handleInputChange}
           />
           <Typography.Text
             strong>是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道:</Typography.Text>
         </Space>
       </div>

       {
         !isEdit && (
-          <div style={{marginTop: 10, display: 'flex'}}>
+          <div style={{ marginTop: 10, display: 'flex' }}>
             <Space>
               <Checkbox
                 checked={batch}
                 label='批量创建'
                 name='batch'
                 onChange={() => setBatch(!batch)}
               />
               <Typography.Text strong>批量创建</Typography.Text>
             </Space>
           </div>
         )
       }
       {
         inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
           <>
-            <div style={{marginTop: 10}}>
+            <div style={{ marginTop: 10 }}>
               <Typography.Text strong>代理:</Typography.Text>
             </div>
             <Input
               label='代理'
               name='base_url'
               placeholder={'此项可选,用于通过代理站来进行 API 调用'}
               onChange={value => {
                 handleInputChange('base_url', value)
               }}
               value={inputs.base_url}
               autoComplete='new-password'
             />
           </>
         )
       }
       {
         inputs.type === 22 && (
           <>
-            <div style={{marginTop: 10}}>
+            <div style={{ marginTop: 10 }}>
               <Typography.Text strong>私有部署地址:</Typography.Text>
             </div>
             <Input
               name='base_url'
               placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
               onChange={value => {
                 handleInputChange('base_url', value)
               }}
               value={inputs.base_url}
               autoComplete='new-password'
             />
           </>
         )
       }

     </Spin>
File diff suppressed because one or more lines are too long
(image changed: before 5.4 KiB, after 4.3 KiB)

web/berry/src/assets/images/icons/oidc.svg (new file, 7 lines, 1.2 KiB)
@@ -0,0 +1,7 @@
+<svg t="1723135116886" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"
+     p-id="10969" width="200" height="200">
+    <path d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z"
+          p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path>
+    <path d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z"
+          p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path>
+</svg>
@@ -22,7 +22,12 @@ const config = {
   turnstile_site_key: '',
   version: '',
   wechat_login: false,
-  wechat_qrcode: ''
+  wechat_qrcode: '',
+  oidc: false,
+  oidc_client_id: '',
+  oidc_authorization_endpoint: '',
+  oidc_token_endpoint: '',
+  oidc_userinfo_endpoint: '',
 }
};

@@ -173,6 +173,24 @@ export const CHANNEL_OPTIONS = {
     value: 43,
     color: 'primary'
   },
+  44: {
+    key: 44,
+    text: 'SiliconFlow',
+    value: 44,
+    color: 'primary'
+  },
+  45: {
+    key: 45,
+    text: 'xAI',
+    value: 45,
+    color: 'primary'
+  },
+  46: {
+    key: 46,
+    text: 'Replicate',
+    value: 46,
+    color: 'primary'
+  },
   41: {
     key: 41,
     text: 'Novita',

@@ -70,6 +70,28 @@ const useLogin = () => {
     }
   };

+  const oidcLogin = async (code, state) => {
+    try {
+      const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`);
+      const { success, message, data } = res.data;
+      if (success) {
+        if (message === 'bind') {
+          showSuccess('绑定成功!');
+          navigate('/panel');
+        } else {
+          dispatch({ type: LOGIN, payload: data });
+          localStorage.setItem('user', JSON.stringify(data));
+          showSuccess('登录成功!');
+          navigate('/panel');
+        }
+      }
+      return { success, message };
+    } catch (err) {
+      // 请求失败,设置错误信息
+      return { success: false, message: '' };
+    }
+  };
+
   const wechatLogin = async (code) => {
     try {
       const res = await API.get(`/api/oauth/wechat?code=${code}`);
@@ -94,7 +116,7 @@ const useLogin = () => {
     navigate('/');
   };

-  return { login, logout, githubLogin, wechatLogin, larkLogin };
+  return { login, logout, githubLogin, wechatLogin, larkLogin, oidcLogin };
 };

 export default useLogin;

@@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login')));
 const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register')));
 const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth')));
 const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth')));
+const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth')));
 const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword')));
 const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword')));
 const Home = Loadable(lazy(() => import('views/Home')));
@@ -53,6 +54,10 @@ const OtherRoutes = {
   path: '/oauth/lark',
   element: <LarkOAuth />
 },
+{
+  path: '/oauth/oidc',
+  element: <OidcOAuth />
+},
 {
   path: '/404',
   element: <NotFoundView />

@@ -95,7 +95,22 @@ export async function onLarkOAuthClicked(lark_client_id) {
   const state = await getOAuthState();
   if (!state) return;
   let redirect_uri = `${window.location.origin}/oauth/lark`;
-  window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
+  window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`);
+}
+
+export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
+  const state = await getOAuthState();
+  if (!state) return;
+  const redirect_uri = `${window.location.origin}/oauth/oidc`;
+  const response_type = "code";
+  const scope = "openid profile email";
+  const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
+  if (openInNewTab) {
+    window.open(url);
+  } else {
+    window.location.href = url;
+  }
 }

 export function isAdmin() {

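The `onOidcClicked` change above interpolates its query parameters into the authorize URL without encoding them. As a minimal sketch only, not part of the diff, the same URL could be built with `URLSearchParams`, which handles percent-encoding; `buildOidcAuthorizeUrl` and all argument values here are hypothetical names:

```javascript
// Hypothetical helper: build an OIDC authorization-code request URL.
// URLSearchParams serializes in application/x-www-form-urlencoded form,
// so the space-separated scope is encoded as "openid+profile+email".
function buildOidcAuthorizeUrl(authUrl, clientId, redirectUri, state) {
  const params = new URLSearchParams({
    client_id: clientId,
    redirect_uri: redirectUri,
    response_type: 'code',
    scope: 'openid profile email',
    state,
  });
  return `${authUrl}?${params.toString()}`;
}

// Example with made-up endpoint and client values:
console.log(buildOidcAuthorizeUrl(
  'https://idp.example.com/authorize',
  'my-client',
  'https://app.example.com/oauth/oidc',
  'abc123'
));
```

This keeps redirect URIs and state values containing `&`, `=`, or spaces from corrupting the query string.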
web/berry/src/views/Authentication/Auth/OidcOAuth.js (new file, 94 lines)
@@ -0,0 +1,94 @@
+import { Link, useNavigate, useSearchParams } from 'react-router-dom';
+import React, { useEffect, useState } from 'react';
+import { showError } from 'utils/common';
+import useLogin from 'hooks/useLogin';
+
+// material-ui
+import { useTheme } from '@mui/material/styles';
+import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material';
+
+// project imports
+import AuthWrapper from '../AuthWrapper';
+import AuthCardWrapper from '../AuthCardWrapper';
+import Logo from 'ui-component/Logo';
+
+// assets
+
+// ================================|| AUTH3 - LOGIN ||================================ //
+
+const OidcOAuth = () => {
+  const theme = useTheme();
+  const matchDownSM = useMediaQuery(theme.breakpoints.down('md'));
+
+  const [searchParams] = useSearchParams();
+  const [prompt, setPrompt] = useState('处理中...');
+  const { oidcLogin } = useLogin();
+
+  let navigate = useNavigate();
+
+  const sendCode = async (code, state, count) => {
+    const { success, message } = await oidcLogin(code, state);
+    if (!success) {
+      if (message) {
+        showError(message);
+      }
+      if (count === 0) {
+        setPrompt(`操作失败,重定向至登录界面中...`);
+        await new Promise((resolve) => setTimeout(resolve, 2000));
+        navigate('/login');
+        return;
+      }
+      count++;
+      setPrompt(`出现错误,第 ${count} 次重试中...`);
+      await new Promise((resolve) => setTimeout(resolve, 2000));
+      await sendCode(code, state, count);
+    }
+  };
+
+  useEffect(() => {
+    let code = searchParams.get('code');
+    let state = searchParams.get('state');
+    sendCode(code, state, 0).then();
+  }, []);
+
+  return (
+    <AuthWrapper>
+      <Grid container direction="column" justifyContent="flex-end">
+        <Grid item xs={12}>
+          <Grid container justifyContent="center" alignItems="center" sx={{ minHeight: 'calc(100vh - 136px)' }}>
+            <Grid item sx={{ m: { xs: 1, sm: 3 }, mb: 0 }}>
+              <AuthCardWrapper>
+                <Grid container spacing={2} alignItems="center" justifyContent="center">
+                  <Grid item sx={{ mb: 3 }}>
+                    <Link to="#">
+                      <Logo />
+                    </Link>
+                  </Grid>
+                  <Grid item xs={12}>
+                    <Grid container direction={matchDownSM ? 'column-reverse' : 'row'} alignItems="center" justifyContent="center">
+                      <Grid item>
+                        <Stack alignItems="center" justifyContent="center" spacing={1}>
+                          <Typography color={theme.palette.primary.main} gutterBottom variant={matchDownSM ? 'h3' : 'h2'}>
+                            OIDC 登录
+                          </Typography>
+                        </Stack>
+                      </Grid>
+                    </Grid>
+                  </Grid>
+                  <Grid item xs={12} container direction="column" justifyContent="center" alignItems="center" style={{ height: '200px' }}>
+                    <CircularProgress />
+                    <Typography variant="h3" paddingTop={'20px'}>
+                      {prompt}
+                    </Typography>
+                  </Grid>
+                </Grid>
+              </AuthCardWrapper>
+            </Grid>
+          </Grid>
+        </Grid>
+      </Grid>
+    </AuthWrapper>
+  );
+};
+
+export default OidcOAuth;
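`sendCode` in the new `OidcOAuth.js` above retries a failed login with a fixed 2-second pause by calling itself recursively. As a sketch only, the same retry-with-delay pattern can be written iteratively; `retryAsync` and its parameters are hypothetical names, not part of the diff:

```javascript
// Hypothetical helper: call an async step returning { success, message },
// pausing delayMs between attempts and giving up after maxAttempts.
async function retryAsync(fn, maxAttempts, delayMs) {
  let lastMessage = '';
  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    const { success, message } = await fn(); // one login attempt
    if (success) return { success: true, message };
    lastMessage = message;
    if (attempt < maxAttempts) {
      // fixed pause before the next retry, mirroring the 2s sleep in sendCode
      await new Promise((resolve) => setTimeout(resolve, delayMs));
    }
  }
  return { success: false, message: lastMessage }; // exhausted all attempts
}
```

An iterative loop avoids growing the call chain on long retry runs and makes the give-up condition explicit in one place.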
@@ -36,7 +36,8 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff';
 import Github from 'assets/images/icons/github.svg';
 import Wechat from 'assets/images/icons/wechat.svg';
 import Lark from 'assets/images/icons/lark.svg';
-import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common';
+import OIDC from 'assets/images/icons/oidc.svg';
+import { onGitHubOAuthClicked, onLarkOAuthClicked, onOidcClicked } from 'utils/common';

 // ============================|| FIREBASE - LOGIN ||============================ //

@@ -50,7 +51,7 @@ const LoginForm = ({ ...others }) => {
   // const [checked, setChecked] = useState(true);

   let tripartiteLogin = false;
-  if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) {
+  if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id || siteInfo.oidc) {
     tripartiteLogin = true;
   }

@@ -145,6 +146,29 @@ const LoginForm = ({ ...others }) => {
         </AnimateButton>
       </Grid>
     )}
+    {siteInfo.oidc && (
+      <Grid item xs={12}>
+        <AnimateButton>
+          <Button
+            disableElevation
+            fullWidth
+            onClick={() => onOidcClicked(siteInfo.oidc_authorization_endpoint, siteInfo.oidc_client_id)}
+            size="large"
+            variant="outlined"
+            sx={{
+              color: 'grey.700',
+              backgroundColor: theme.palette.grey[50],
+              borderColor: theme.palette.grey[100]
+            }}
+          >
+            <Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
+              <img src={OIDC} alt="OIDC" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
+            </Box>
+            使用 OIDC 登录
+          </Button>
+        </AnimateButton>
+      </Grid>
+    )}
     <Grid item xs={12}>
       <Box
         sx={{

@@ -595,6 +595,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
     <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
   )}
 </FormControl>
+<FormControl fullWidth error={Boolean(touched.system_prompt && errors.system_prompt)} sx={{ ...theme.typography.otherInput }}>
+  {/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
+  <TextField
+    multiline
+    id="channel-system_prompt-label"
+    label={inputLabel.system_prompt}
+    value={values.system_prompt}
+    name="system_prompt"
+    onBlur={handleBlur}
+    onChange={handleChange}
+    aria-describedby="helper-text-channel-system_prompt-label"
+    minRows={5}
+    placeholder={inputPrompt.system_prompt}
+  />
+  {touched.system_prompt && errors.system_prompt ? (
+    <FormHelperText error id="helper-tex-channel-system_prompt-label">
+      {errors.system_prompt}
+    </FormHelperText>
+  ) : (
+    <FormHelperText id="helper-tex-channel-system_prompt-label"> {inputPrompt.system_prompt} </FormHelperText>
+  )}
+</FormControl>
 <DialogActions>
   <Button onClick={onCancel}>取消</Button>
   <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
Some files were not shown because too many files have changed in this diff.