mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-23 09:53:42 +08:00
Compare commits
35 Commits
v0.5.8-alp
...
v0.5.10
Author | SHA1 | Date | |
---|---|---|---|
|
67c64e71c8 | ||
|
97030e27f8 | ||
|
461f5dab56 | ||
|
af378c59af | ||
|
bc6769826b | ||
|
0fe26cc4bd | ||
|
7d6a169669 | ||
|
66f06e5d6f | ||
|
6acb9537a9 | ||
|
7069c49bdf | ||
|
58dee76bf7 | ||
|
5cf23d8698 | ||
|
366b82128f | ||
|
2a70744dbf | ||
|
4c5feee0b6 | ||
|
9ba5388367 | ||
|
379074f7d0 | ||
|
01f7b0186f | ||
|
a3f80a3392 | ||
|
8f5b83562b | ||
|
b7570d5c77 | ||
|
0e73418cdf | ||
|
9889377f0e | ||
|
b273464e77 | ||
|
b4e43d97fd | ||
|
3347a44023 | ||
|
923e24534b | ||
|
b4d67ca614 | ||
|
d85e356b6e | ||
|
495fc628e4 | ||
|
76f9288c34 | ||
|
915d13fdd4 | ||
|
969f539777 | ||
|
54e5f8ecd2 | ||
|
34d517cfa2 |
@@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
|
||||
1. Support for multiple large models:
|
||||
+ [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||
+ [x] [Anthropic Claude Series Models](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
|
||||
+ [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google)
|
||||
+ [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
|
||||
|
@@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に
|
||||
1. 複数の大型モデルをサポート:
|
||||
+ [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
|
||||
+ [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google)
|
||||
+ [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
|
||||
+ [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
|
||||
|
21
README.md
21
README.md
@@ -51,35 +51,29 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
<a href="https://iamazing.cn/page/reward">赞赏支持</a>
|
||||
</p>
|
||||
|
||||
> **Note**
|
||||
> [!NOTE]
|
||||
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||
>
|
||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
> **Warning**
|
||||
> [!WARNING]
|
||||
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
|
||||
|
||||
> **Warning**
|
||||
> [!WARNING]
|
||||
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`!
|
||||
|
||||
## 功能
|
||||
1. 支持多种大模型:
|
||||
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
|
||||
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
|
||||
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||
+ [x] [360 智脑](https://ai.360.cn)
|
||||
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
|
||||
2. 支持配置镜像以及众多第三方代理服务:
|
||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||
+ [x] [CloseAI](https://referer.shadowai.xyz/r/2412)
|
||||
+ [x] [API2D](https://api2d.com/r/197971)
|
||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
||||
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
|
||||
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||
5. 支持**多机部署**,[详见此处](#多机部署)。
|
||||
@@ -92,14 +86,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
12. 支持**用户邀请奖励**。
|
||||
13. 支持以美元为单位显示额度。
|
||||
14. 支持发布公告,设置充值链接,设置新用户初始额度。
|
||||
15. 支持模型映射,重定向用户的请求模型。
|
||||
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。
|
||||
16. 支持失败自动重试。
|
||||
17. 支持绘图接口。
|
||||
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
|
||||
19. 支持丰富的**自定义**设置,
|
||||
1. 支持自定义系统名称,logo 以及页脚。
|
||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||
20. 支持通过系统访问令牌访问管理 API。
|
||||
20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
|
||||
21. 支持 Cloudflare Turnstile 用户校验。
|
||||
22. 支持用户管理,支持**多种用户登录注册方式**:
|
||||
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
||||
@@ -371,6 +365,7 @@ graph LR
|
||||
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
|
||||
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
|
||||
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
|
||||
16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
|
||||
|
||||
### 命令行参数
|
||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||
|
@@ -78,6 +78,7 @@ var QuotaForInviter = 0
|
||||
var QuotaForInvitee = 0
|
||||
var ChannelDisableThreshold = 5.0
|
||||
var AutomaticDisableChannelEnabled = false
|
||||
var AutomaticEnableChannelEnabled = false
|
||||
var QuotaRemindThreshold = 1000
|
||||
var PreConsumedQuota = 500
|
||||
var ApproximateTokenEnabled = false
|
||||
@@ -186,6 +187,7 @@ const (
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@@ -213,4 +215,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.cloud.tencent.com", //23
|
||||
"", //24
|
||||
}
|
||||
|
@@ -4,3 +4,4 @@ var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
|
||||
var SQLitePath = "one-api.db"
|
||||
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
|
||||
|
@@ -1,11 +1,13 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SendEmail(subject string, receiver string, content string) error {
|
||||
@@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
SMTPFrom = SMTPAccount
|
||||
}
|
||||
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
|
||||
|
||||
// Extract domain from SMTPFrom
|
||||
parts := strings.Split(SMTPFrom, "@")
|
||||
var domain string
|
||||
if len(parts) > 1 {
|
||||
domain = parts[1]
|
||||
}
|
||||
// Generate a unique Message-ID
|
||||
buf := make([]byte, 16)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
|
||||
|
||||
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
||||
"From: %s<%s>\r\n"+
|
||||
"Subject: %s\r\n"+
|
||||
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
|
||||
"Date: %s\r\n"+
|
||||
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, content))
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
|
||||
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
|
||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
||||
to := strings.Split(receiver, ";")
|
||||
var err error
|
||||
|
||||
if SMTPPort == 465 {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
@@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
47
common/image/image.go
Normal file
47
common/image/image.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
img, _, err := image.DecodeConfig(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return img.Width, img.Height, nil
|
||||
}
|
||||
|
||||
var (
|
||||
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
|
||||
)
|
||||
|
||||
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
|
||||
encoded = strings.TrimPrefix(encoded, "data:image/png;base64,")
|
||||
base64 := strings.NewReader(reg.ReplaceAllString(encoded, ""))
|
||||
img, _, err := image.DecodeConfig(base64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return img.Width, img.Height, nil
|
||||
}
|
||||
|
||||
func GetImageSize(image string) (width int, height int, err error) {
|
||||
if strings.HasPrefix(image, "data:image/") {
|
||||
return GetImageSizeFromBase64(image)
|
||||
}
|
||||
return GetImageSizeFromUrl(image)
|
||||
}
|
154
common/image/image_test.go
Normal file
154
common/image/image_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package image_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
img "one-api/common/image"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
type CountingReader struct {
|
||||
reader io.Reader
|
||||
BytesRead int
|
||||
}
|
||||
|
||||
func (r *CountingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
r.BytesRead += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
var (
|
||||
cases = []struct {
|
||||
url string
|
||||
format string
|
||||
width int
|
||||
height int
|
||||
}{
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
|
||||
}
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
// Bytes read: varies sometimes
|
||||
// jpeg: 1063892
|
||||
// png: 294462
|
||||
// webp: 99529
|
||||
// gif: 956153
|
||||
// jpeg#01: 32805
|
||||
for _, c := range cases {
|
||||
t.Run("Decode:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
reader := &CountingReader{reader: resp.Body}
|
||||
img, format, err := image.Decode(reader)
|
||||
assert.NoError(t, err)
|
||||
size := img.Bounds().Size()
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, size.X)
|
||||
assert.Equal(t, c.height, size.Y)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
|
||||
// Bytes read:
|
||||
// jpeg: 4096
|
||||
// png: 4096
|
||||
// webp: 4096
|
||||
// gif: 4096
|
||||
// jpeg#01: 4096
|
||||
for _, c := range cases {
|
||||
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
reader := &CountingReader{reader: resp.Body}
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, config.Width)
|
||||
assert.Equal(t, c.height, config.Height)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64(t *testing.T) {
|
||||
// Bytes read:
|
||||
// jpeg: 1063892
|
||||
// png: 294462
|
||||
// webp: 99072
|
||||
// gif: 953856
|
||||
// jpeg#01: 32805
|
||||
for _, c := range cases {
|
||||
t.Run("Decode:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
|
||||
reader := &CountingReader{reader: body}
|
||||
img, format, err := image.Decode(reader)
|
||||
assert.NoError(t, err)
|
||||
size := img.Bounds().Size()
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, size.X)
|
||||
assert.Equal(t, c.height, size.Y)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
|
||||
// Bytes read:
|
||||
// jpeg: 1536
|
||||
// png: 768
|
||||
// webp: 768
|
||||
// gif: 1536
|
||||
// jpeg#01: 3840
|
||||
for _, c := range cases {
|
||||
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
|
||||
reader := &CountingReader{reader: body}
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, config.Width)
|
||||
assert.Equal(t, c.height, config.Height)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetImageSize(t *testing.T) {
|
||||
for i, c := range cases {
|
||||
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||
width, height, err := img.GetImageSize(c.url)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.width, width)
|
||||
assert.Equal(t, c.height, height)
|
||||
})
|
||||
}
|
||||
}
|
@@ -76,17 +76,22 @@ var ModelRatio = map[string]float64{
|
||||
"dall-e-3": 20, // $0.040 - $0.120 / image
|
||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||
"claude-2.0": 5.51, // $11.02 / 1M tokens
|
||||
"claude-2.1": 5.51, // $11.02 / 1M tokens
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
||||
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
||||
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
|
||||
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
|
||||
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
|
||||
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
|
||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||
|
@@ -5,20 +5,23 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
case common.ChannelTypeGemini:
|
||||
fallthrough
|
||||
case common.ChannelTypeAnthropic:
|
||||
fallthrough
|
||||
case common.ChannelTypeBaidu:
|
||||
@@ -43,16 +46,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
||||
}
|
||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
|
||||
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
||||
} else {
|
||||
if channel.GetBaseURL() != "" {
|
||||
requestURL = channel.GetBaseURL()
|
||||
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
||||
requestURL = baseURL
|
||||
}
|
||||
requestURL += "/v1/chat/completions"
|
||||
}
|
||||
// for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639
|
||||
requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1)
|
||||
|
||||
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
|
||||
}
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
@@ -73,11 +74,18 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response TextResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
if response.Error.Message == "" {
|
||||
response.Error.Message = "补全 tokens 非预期返回 0"
|
||||
}
|
||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||
}
|
||||
return nil, nil
|
||||
@@ -139,20 +147,32 @@ func TestChannel(c *gin.Context) {
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
func notifyRootUser(subject string, content string) {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
}
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
// enable & notify
|
||||
func enableChannel(channelId int, channelName string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
@@ -175,20 +195,21 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
go func() {
|
||||
for _, channel := range channels {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
if milliseconds > disableThreshold {
|
||||
if isChannelEnabled && milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if shouldDisableChannel(openaiErr, -1) {
|
||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
||||
enableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
@@ -360,6 +360,24 @@ func init() {
|
||||
Root: "claude-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2.1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2.1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2.0",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2.0",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot",
|
||||
Object: "model",
|
||||
@@ -405,6 +423,15 @@ func init() {
|
||||
Root: "PaLM-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_turbo",
|
||||
Object: "model",
|
||||
@@ -459,6 +486,24 @@ func init() {
|
||||
Root: "qwen-plus",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-max",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-max",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-max-longcontext",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-max-longcontext",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-v1",
|
||||
Object: "model",
|
||||
|
@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
|
||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||
query := ""
|
||||
if len(request.Messages) != 0 {
|
||||
query = request.Messages[len(request.Messages)-1].Content
|
||||
query = request.Messages[len(request.Messages)-1].StringContent()
|
||||
}
|
||||
return &AIProxyLibraryRequest{
|
||||
Model: request.Model,
|
||||
|
@@ -13,13 +13,13 @@ import (
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
//Prompt string `json:"prompt"`
|
||||
Messages []AliMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
@@ -83,32 +83,17 @@ type AliChatResponse struct {
|
||||
|
||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.Content
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: request.Messages[i+1].Content,
|
||||
})
|
||||
i++
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
Content: message.StringContent(),
|
||||
Role: strings.ToLower(message.Role),
|
||||
})
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
Messages: messages,
|
||||
},
|
||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
||||
// TopP: request.TopP,
|
||||
|
@@ -1,15 +1,18 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
@@ -37,41 +40,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
}
|
||||
}
|
||||
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
modelRatio := common.GetModelRatio(audioModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
var quota int
|
||||
var preConsumedQuota int
|
||||
switch relayMode {
|
||||
case RelayModeAudioSpeech:
|
||||
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
|
||||
quota = preConsumedQuota
|
||||
default:
|
||||
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
|
||||
}
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
quota := 0
|
||||
// Check if user quota is enough
|
||||
if relayMode == RelayModeAudioSpeech {
|
||||
quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
|
||||
if quota > userQuota {
|
||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
} else {
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,13 +97,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||
requestBody := c.Request.Body
|
||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiVersion := GetAPIVersion(c)
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
|
||||
}
|
||||
|
||||
requestBody := &bytes.Buffer{}
|
||||
_, err = io.Copy(requestBody, c.Request.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
||||
responseFormat := c.DefaultPostForm("response_format", "json")
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
|
||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
req.Header.Set("api-key", apiKey)
|
||||
req.ContentLength = c.Request.ContentLength
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
@@ -119,11 +142,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if relayMode == RelayModeAudioSpeech {
|
||||
defer func(ctx context.Context) {
|
||||
go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
}(c.Request.Context())
|
||||
} else {
|
||||
if relayMode != RelayModeAudioSpeech {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
@@ -132,18 +151,55 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var whisperResponse WhisperResponse
|
||||
err = json.Unmarshal(responseBody, &whisperResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
|
||||
var openAIErr TextResponse
|
||||
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
||||
if openAIErr.Error.Message != "" {
|
||||
return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
quota := countTokenText(whisperResponse.Text, audioModel)
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
}(c.Request.Context())
|
||||
|
||||
var text string
|
||||
switch responseFormat {
|
||||
case "json":
|
||||
text, err = getTextFromJSON(responseBody)
|
||||
case "text":
|
||||
text, err = getTextFromText(responseBody)
|
||||
case "srt":
|
||||
text, err = getTextFromSRT(responseBody)
|
||||
case "verbose_json":
|
||||
text, err = getTextFromVerboseJSON(responseBody)
|
||||
case "vtt":
|
||||
text, err = getTextFromVTT(responseBody)
|
||||
default:
|
||||
return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
||||
}
|
||||
quota = countTokenText(text, audioModel)
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if preConsumedQuota > 0 {
|
||||
// we need to roll back the pre-consumed quota
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
// negative means add quota back for token & user
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return relayErrorHandler(resp)
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
defer func(ctx context.Context) {
|
||||
go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
}(c.Request.Context())
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
@@ -159,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTextFromVTT(body []byte) (string, error) {
|
||||
return getTextFromSRT(body)
|
||||
}
|
||||
|
||||
func getTextFromVerboseJSON(body []byte) (string, error) {
|
||||
var whisperResponse WhisperVerboseJSONResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
||||
func getTextFromSRT(body []byte) (string, error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
||||
var builder strings.Builder
|
||||
var textLine bool
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if textLine {
|
||||
builder.WriteString(line)
|
||||
textLine = false
|
||||
continue
|
||||
} else if strings.Contains(line, "-->") {
|
||||
textLine = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func getTextFromText(body []byte) (string, error) {
|
||||
return strings.TrimSuffix(string(body), "\n"), nil
|
||||
}
|
||||
|
||||
func getTextFromJSON(body []byte) (string, error) {
|
||||
var whisperResponse WhisperJSONResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
if prompt == "" {
|
||||
prompt = message.StringContent()
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
|
305
controller/relay-gemini.go
Normal file
305
controller/relay-gemini.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
//SafetySettings: []GeminiChatSafetySettings{
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_HARASSMENT",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
//},
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
},
|
||||
}
|
||||
if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTools{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
content := GeminiChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
}
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
if content.Role == "system" {
|
||||
content.Role = "user"
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
|
||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
if shouldAddDummyModelMessage {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
}
|
||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
return g.Candidates[0].Content.Parts[0].Text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.FinishReason = &stopFinishReason
|
||||
var response ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -18,7 +19,6 @@ func isWithinRange(element string, value int) bool {
|
||||
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
min := common.DalleGenerationImageAmounts[element][0]
|
||||
max := common.DalleGenerationImageAmounts[element][1]
|
||||
|
||||
@@ -33,15 +33,16 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
|
||||
var imageRequest ImageRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
// Size validation
|
||||
@@ -81,7 +82,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
|
||||
// Number of generated images validation
|
||||
if isWithinRange(imageModel, imageRequest.N) == false {
|
||||
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||
// channel not azure
|
||||
if channelType != common.ChannelTypeAzure {
|
||||
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
@@ -104,8 +108,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||
apiVersion := GetAPIVersion(c)
|
||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
|
||||
jsonStr, err := json.Marshal(imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
@@ -122,7 +133,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
|
||||
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
if userQuota-quota < 0 {
|
||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
@@ -130,7 +141,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
token := c.Request.Header.Get("Authorization")
|
||||
if channelType == common.ChannelTypeAzure { // Azure authentication
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
req.Header.Set("api-key", token)
|
||||
} else {
|
||||
req.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
@@ -151,43 +168,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
var textResponse ImageResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
|
@@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var textResponse TextResponse
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
@@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += countTokenText(choice.Message.Content, model)
|
||||
completionTokens += countTokenText(choice.Message.StringContent(), model)
|
||||
}
|
||||
textResponse.Usage = Usage{
|
||||
PromptTokens: promptTokens,
|
||||
|
@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
|
@@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "assistant",
|
||||
@@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, TencentMessage{
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
|
@@ -27,6 +27,7 @@ const (
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
@@ -51,14 +52,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
var textRequest GeneralOpenAIRequest
|
||||
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
|
||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
|
||||
}
|
||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
@@ -121,6 +122,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
@@ -132,11 +135,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
apiVersion := GetAPIVersion(c)
|
||||
requestURL := strings.Split(requestURL, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
baseURL = c.GetString("base_url")
|
||||
@@ -147,7 +146,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
model_ = strings.TrimSuffix(model_, "-0301")
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
|
||||
}
|
||||
case APITypeClaude:
|
||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
||||
@@ -182,6 +183,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
case APITypeGemini:
|
||||
requestBaseURL := "https://generativelanguage.googleapis.com"
|
||||
if baseURL != "" {
|
||||
requestBaseURL = baseURL
|
||||
}
|
||||
version := "v1"
|
||||
if c.GetString("api_version") != "" {
|
||||
version = c.GetString("api_version")
|
||||
}
|
||||
action := "generateContent"
|
||||
if textRequest.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
case APITypeZhipu:
|
||||
method := "invoke"
|
||||
if textRequest.Stream {
|
||||
@@ -233,7 +251,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
if consumeQuota && preConsumedQuota > 0 {
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
@@ -279,6 +297,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeGemini:
|
||||
geminiChatRequest := requestOpenAI2Gemini(textRequest)
|
||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeZhipu:
|
||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||
jsonStr, err := json.Marshal(zhipuRequest)
|
||||
@@ -365,10 +390,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if textRequest.Stream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
case APITypeTencent:
|
||||
req.Header.Set("Authorization", apiKey)
|
||||
case APITypePaLM:
|
||||
// do not set Authorization header
|
||||
case APITypeGemini:
|
||||
// do not set Authorization header
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
@@ -412,37 +442,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
go func() {
|
||||
if consumeQuota {
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
switch apiType {
|
||||
@@ -456,7 +485,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
|
||||
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -530,6 +559,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeGemini:
|
||||
if textRequest.Stream {
|
||||
err, responseText := geminiChatStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeZhipu:
|
||||
if isStream {
|
||||
err, usage := zhipuStreamHandler(c, resp)
|
||||
|
@@ -3,15 +3,19 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/image"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
var stopFinishReason = "stop"
|
||||
@@ -86,7 +90,33 @@ func countTokenMessages(messages []Message, model string) int {
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Content)
|
||||
switch v := message.Content.(type) {
|
||||
case string:
|
||||
tokenNum += getTokenNum(tokenEncoder, v)
|
||||
case []any:
|
||||
for _, it := range v {
|
||||
m := it.(map[string]any)
|
||||
switch m["type"] {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
|
||||
case "image_url":
|
||||
imageUrl, ok := m["image_url"].(map[string]any)
|
||||
if ok {
|
||||
url := imageUrl["url"].(string)
|
||||
detail := ""
|
||||
if imageUrl["detail"] != nil {
|
||||
detail = imageUrl["detail"].(string)
|
||||
}
|
||||
imageTokens, err := countImageTokens(url, detail)
|
||||
if err != nil {
|
||||
common.SysError("error counting image tokens: " + err.Error())
|
||||
} else {
|
||||
tokenNum += imageTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
@@ -97,13 +127,81 @@ func countTokenMessages(messages []Message, model string) int {
|
||||
return tokenNum
|
||||
}
|
||||
|
||||
const (
|
||||
lowDetailCost = 85
|
||||
highDetailCostPerTile = 170
|
||||
additionalCost = 85
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
|
||||
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
func countImageTokens(url string, detail string) (_ int, err error) {
|
||||
var fetchSize = true
|
||||
var width, height int
|
||||
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
|
||||
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
|
||||
// According to the official guide, "low" disable the high-res model,
|
||||
// and only receive low-res 512px x 512px version of the image, indicating
|
||||
// that image is treated as low-res when size is smaller than 512px x 512px,
|
||||
// then we can assume that image size larger than 512px x 512px is treated
|
||||
// as high-res. Then we have the following logic:
|
||||
// if detail == "" || detail == "auto" {
|
||||
// width, height, err = image.GetImageSize(url)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// fetchSize = false
|
||||
// // not sure if this is correct
|
||||
// if width > 512 || height > 512 {
|
||||
// detail = "high"
|
||||
// } else {
|
||||
// detail = "low"
|
||||
// }
|
||||
// }
|
||||
|
||||
// However, in my test, it seems to be always the same as "high".
|
||||
// The following image, which is 125x50, is still treated as high-res, taken
|
||||
// 255 tokens in the response of non-stream chat completion api.
|
||||
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
|
||||
if detail == "" || detail == "auto" {
|
||||
// assume by test, not sure if this is correct
|
||||
detail = "high"
|
||||
}
|
||||
switch detail {
|
||||
case "low":
|
||||
return lowDetailCost, nil
|
||||
case "high":
|
||||
if fetchSize {
|
||||
width, height, err = image.GetImageSize(url)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if width > 2048 || height > 2048 { // max(width, height) > 2048
|
||||
ratio := float64(2048) / math.Max(float64(width), float64(height))
|
||||
width = int(float64(width) * ratio)
|
||||
height = int(float64(height) * ratio)
|
||||
}
|
||||
if width > 768 && height > 768 { // min(width, height) > 768
|
||||
ratio := float64(768) / math.Min(float64(width), float64(height))
|
||||
width = int(float64(width) * ratio)
|
||||
height = int(float64(height) * ratio)
|
||||
}
|
||||
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
|
||||
result := numSquares*highDetailCostPerTile + additionalCost
|
||||
return result, nil
|
||||
default:
|
||||
return 0, errors.New("invalid detail option")
|
||||
}
|
||||
}
|
||||
|
||||
func countTokenInput(input any, model string) int {
|
||||
switch input.(type) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return countTokenText(input.(string), model)
|
||||
return countTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range input.([]string) {
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return countTokenText(text, model)
|
||||
@@ -144,6 +242,19 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
|
||||
if !common.AutomaticEnableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if openAIErr != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func setEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
@@ -152,11 +263,52 @@ func setEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
Response struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
func (e GeneralErrorResponse) ToMessage() string {
|
||||
if e.Error.Message != "" {
|
||||
return e.Error.Message
|
||||
}
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
if e.Err != "" {
|
||||
return e.Err
|
||||
}
|
||||
if e.ErrorMsg != "" {
|
||||
return e.ErrorMsg
|
||||
}
|
||||
if e.Header.Message != "" {
|
||||
return e.Header.Message
|
||||
}
|
||||
if e.Response.Error.Message != "" {
|
||||
return e.Response.Error.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
@@ -170,27 +322,40 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var textResponse TextResponse
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
var errResponse GeneralErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
||||
if errResponse.Error.Message != "" {
|
||||
// OpenAI format error, so we override the default one
|
||||
openAIErrorWithStatusCode.OpenAIError = errResponse.Error
|
||||
} else {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage()
|
||||
}
|
||||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
if channelType == common.ChannelTypeOpenAI {
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||
case common.ChannelTypeAzure:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||
}
|
||||
}
|
||||
return fullRequestURL
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
// quotaDelta is remaining quota to be consumed
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
@@ -198,10 +363,23 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
// totalQuota is total quota consumed
|
||||
if totalQuota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||
}
|
||||
if totalQuota <= 0 {
|
||||
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||
}
|
||||
}
|
||||
|
||||
func GetAPIVersion(c *gin.Context) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
|
@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "assistant",
|
||||
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -230,7 +230,13 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
|
@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "system",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "user",
|
||||
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
} else {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -12,10 +12,49 @@ import (
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Content any `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
type TextContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type ImageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
content, ok := m.Content.(string)
|
||||
if ok {
|
||||
return content
|
||||
}
|
||||
contentList, ok := m.Content.([]any)
|
||||
if ok {
|
||||
var contentStr string
|
||||
for _, contentItem := range contentList {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if contentMap["type"] == "text" {
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentStr += subStr
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentStr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
@@ -31,19 +70,30 @@ const (
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
@@ -83,18 +133,39 @@ type TextRequest struct {
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
Quality string `json:"quality"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
Style string `json:"style"`
|
||||
User string `json:"user"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type WhisperResponse struct {
|
||||
type WhisperJSONResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type WhisperVerboseJSONResponse struct {
|
||||
Task string `json:"task,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Segments []Segment `json:"segments,omitempty"`
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
Id int `json:"id"`
|
||||
Seek int `json:"seek"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
Tokens []int `json:"tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
AvgLogprob float64 `json:"avg_logprob"`
|
||||
CompressionRatio float64 `json:"compression_ratio"`
|
||||
NoSpeechProb float64 `json:"no_speech_prob"`
|
||||
}
|
||||
|
||||
type TextToSpeechRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Input string `json:"input" binding:"required"`
|
||||
@@ -165,7 +236,7 @@ type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
@@ -201,9 +272,9 @@ func Relay(c *gin.Context) {
|
||||
relayMode = RelayModeEdits
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
relayMode = RelayModeAudioSpeech
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") {
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
relayMode = RelayModeAudioTranscription
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") {
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
relayMode = RelayModeAudioTranslation
|
||||
}
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
|
6
go.mod
6
go.mod
@@ -15,7 +15,9 @@ require (
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.5
|
||||
github.com/stretchr/testify v1.8.3
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/image v0.14.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -26,6 +28,7 @@ require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
@@ -50,12 +53,13 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
6
go.sum
6
go.sum
@@ -152,6 +152,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
@@ -168,8 +170,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
@@ -119,6 +119,7 @@
|
||||
" 年 ": " y ",
|
||||
"未测试": "Not tested",
|
||||
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
|
||||
"已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
|
||||
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
|
||||
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
|
||||
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
|
||||
@@ -139,6 +140,7 @@
|
||||
"启用": "Enable",
|
||||
"编辑": "Edit",
|
||||
"添加新的渠道": "Add a new channel",
|
||||
"测试所有通道": "Test all channels",
|
||||
"测试所有已启用通道": "Test all enabled channels",
|
||||
"更新所有已启用通道余额": "Update the balance of all enabled channels",
|
||||
"刷新": "Refresh",
|
||||
|
@@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
requestURL := c.Request.URL.String()
|
||||
consumeQuota := true
|
||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
||||
consumeQuota = false
|
||||
}
|
||||
c.Set("consume_quota", consumeQuota)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("channelId", parts[1])
|
||||
|
@@ -87,8 +87,12 @@ func Distribute() func(c *gin.Context) {
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
26
middleware/recover.go
Normal file
26
middleware/recover.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func RelayPanicRecover() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
|
||||
"type": "one_api_panic",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -59,7 +60,8 @@ func chooseDB() (*gorm.DB, error) {
|
||||
// Use SQLite
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
|
@@ -34,6 +34,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||
@@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
common.AutomaticDisableChannelEnabled = boolValue
|
||||
case "AutomaticEnableChannelEnabled":
|
||||
common.AutomaticEnableChannelEnabled = boolValue
|
||||
case "ApproximateTokenEnabled":
|
||||
common.ApproximateTokenEnabled = boolValue
|
||||
case "LogConsumeEnabled":
|
||||
|
3
pull_request_template.md
Normal file
3
pull_request_template.md
Normal file
@@ -0,0 +1,3 @@
|
||||
close #issue_number
|
||||
|
||||
我已确认该 PR 已自测通过,相关截图如下:
|
@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relayV1Router.POST("/completions", controller.Relay)
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
@@ -35,12 +35,38 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
|
||||
}
|
||||
}
|
||||
|
@@ -234,7 +234,7 @@ const ChannelsTable = () => {
|
||||
const res = await API.get(`/api/channel/test`);
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
|
||||
showInfo('已成功开始测试所有通道,请刷新页面查看结果。');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
|
@@ -16,6 +16,7 @@ const OperationSetting = () => {
|
||||
ChatLink: '',
|
||||
QuotaPerUnit: 0,
|
||||
AutomaticDisableChannelEnabled: '',
|
||||
AutomaticEnableChannelEnabled: '',
|
||||
ChannelDisableThreshold: 0,
|
||||
LogConsumeEnabled: '',
|
||||
DisplayInCurrencyEnabled: '',
|
||||
@@ -269,6 +270,12 @@ const OperationSetting = () => {
|
||||
name='AutomaticDisableChannelEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
|
||||
label='成功时自动启用通道'
|
||||
name='AutomaticEnableChannelEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('monitor').then();
|
||||
|
@@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
|
||||
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
|
||||
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
|
||||
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
|
||||
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
|
||||
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
|
||||
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
||||
|
@@ -60,7 +60,7 @@ const EditChannel = () => {
|
||||
let localModels = [];
|
||||
switch (value) {
|
||||
case 14:
|
||||
localModels = ['claude-instant-1', 'claude-2'];
|
||||
localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
|
||||
break;
|
||||
case 11:
|
||||
localModels = ['PaLM-2'];
|
||||
@@ -69,7 +69,7 @@ const EditChannel = () => {
|
||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
|
||||
break;
|
||||
case 17:
|
||||
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
|
||||
localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'];
|
||||
break;
|
||||
case 16:
|
||||
localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||
@@ -83,6 +83,9 @@ const EditChannel = () => {
|
||||
case 23:
|
||||
localModels = ['hunyuan'];
|
||||
break;
|
||||
case 24:
|
||||
localModels = ['gemini-pro'];
|
||||
break;
|
||||
}
|
||||
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
||||
}
|
||||
@@ -343,6 +346,20 @@ const EditChannel = () => {
|
||||
</Form.Field>
|
||||
)
|
||||
}
|
||||
{
|
||||
inputs.type === 17 && (
|
||||
<Form.Field>
|
||||
<Form.Input
|
||||
label='插件参数'
|
||||
name='other'
|
||||
placeholder={'请输入插件参数,即 X-DashScope-Plugin 请求头的取值'}
|
||||
onChange={handleInputChange}
|
||||
value={inputs.other}
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</Form.Field>
|
||||
)
|
||||
}
|
||||
<Form.Field>
|
||||
<Form.Dropdown
|
||||
label='模型'
|
||||
|
Reference in New Issue
Block a user