Compare commits

...

26 Commits

Author SHA1 Message Date
JustSong
366b82128f fix: remove incorrect logging 2023-12-10 20:44:37 +08:00
JustSong
2a70744dbf feat: add panic recover middleware 2023-12-10 19:53:33 +08:00
Qiying Wang
4c5feee0b6 feat: add image counter for gpt-4 vision (#795) 2023-12-10 19:39:46 +08:00
igophper
9ba5388367 feat: refactor response parsing logic to support multiple formats (#782)
* feat: Refactor response parsing logic to support multiple formats

The parsing logic for responses in relay.go and relay-audio.go was refactored to support multiple response formats - 'json', 'text', 'srt', 'verbose_json', and 'vtt'. The existing `WhisperResponse` struct was renamed to `WhisperJsonResponse` and a new struct `WhisperVerboseJsonResponse` was added to support the 'verbose_json' format. Additional parsing functions were added to extract text from these new response types. This change was necessary to make the parsing logic more flexible and extendable for different types of responses.

* chore: update name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-10 18:39:14 +08:00
JustSong
379074f7d0 feat: support plugin for ali channel (close #797) 2023-12-10 17:22:52 +08:00
JustSong
01f7b0186f chore: add routes 2023-12-03 20:45:11 +08:00
Tillman Bailee
a3f80a3392 feat: enable channel when test succeed (#771)
* 增加功能: 渠道 - 测试所有通道; 设置 - 运营设置 - 监控设置 - 成功时自动启用通道

* refactor: update implementation

---------

Co-authored-by: liyujie <29959257@qq.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-03 20:10:57 +08:00
Zhengyi Dong
8f5b83562b fix: fix "invalidPayload" error when request Azure dall-e-3 api without optional parameter (#764)
* fix: based on #754 add 'omitempty' in ImageRequest to fit official api reference for relay

* Revert "fix: based on #754 add 'omitempty' in ImageRequest to fit official api reference for relay"

This reverts commit b526006ce0.

* fix: add missing omitempty

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-03 17:43:30 +08:00
ShinChven ✨
b7570d5c77 feat: support dalle for Azure (#754)
* feat: Add Message-ID to email headers to comply with RFC 5322

- Extract domain from SMTPFrom
- Generate a unique Message-ID
- Add Message-ID to email headers

* chore: check slice length

* feat: Add Azure compatibility for relayImageHelper

- Handle Azure channel requestURL compatibility
- Set api-key header for Azure channel authentication
- Handle Azure channel request body

fixes: https://github.com/songquanpeng/one-api/issues/751

* refactor: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-03 17:34:59 +08:00
JustSong
0e73418cdf fix: fix log recording & error handling for relay-audio 2023-11-26 12:05:16 +08:00
JustSong
9889377f0e feat: support claude-2.x (close #736) 2023-11-24 21:39:44 +08:00
JustSong
b273464e77 docs: update readme 2023-11-24 21:23:16 +08:00
JustSong
b4e43d97fd docs: add pr template 2023-11-24 21:21:03 +08:00
Ian Li
3347a44023 feat: support Azure's Whisper model (#720) 2023-11-24 21:10:18 +08:00
Tillman Bailee
923e24534b fix: add Date header for email (#742)
* 修复自建邮箱发送错误: INVALID HEADER Missing required header field: "Date"

* chore: fix style

---------

Co-authored-by: liyujie <29959257@qq.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-24 20:56:53 +08:00
ShinChven ✨
b4d67ca614 fix: add Message-ID header for email (#732)
* feat: Add Message-ID to email headers to comply with RFC 5322

- Extract domain from SMTPFrom
- Generate a unique Message-ID
- Add Message-ID to email headers

* chore: check slice length

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-24 20:52:59 +08:00
igophper
d85e356b6e refactor: remove consumeQuota related logic (#738)
* feat: 删除relay-text中的consumeQuota变量

该变量始终为true,可以删除

* chore: remove useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-24 20:42:29 +08:00
JustSong
495fc628e4 feat: support gpt-4 with vision (#683, #714) 2023-11-19 18:38:54 +08:00
JustSong
76f9288c34 feat: update request struct (close #708) 2023-11-19 17:50:30 +08:00
JustSong
915d13fdd4 docs: update readme (#724) 2023-11-19 17:22:35 +08:00
Ian Li
969f539777 fix: skip JSON deserialization when accessing transcriptions and translations (#718)
* fix: Skip JSON deserialization when accessing transcriptions and translations.

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-19 16:11:39 +08:00
Buer
54e5f8ecd2 feat: support cloudflare gateway for azure (#666)
* 🐛 Fix cloudflare gateway request failure

* 🐛 fix channel test url error
2023-11-19 15:52:35 +08:00
Mikey
34d517cfa2 fix: cloudflare test & expose detailed info about test failures (#715)
* fix: cloudflare test & expose detailed info about test failures

* fix: cloudflare test & expose detailed info about test failures

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 21:45:55 +08:00
ckt1031
ddcaf95f5f feat: support tts model (#713)
* Added support for Text-to-Speech models and
endpoints

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 21:18:51 +08:00
ckt1031
1d15157f7d feat: keep sync with dall-e updates (#679)
* Updated ImageRequest struct and OpenAIModels,
added new Dall-E models and size ratios

* Fixed suspect `or`

* Refactored size ratio calculation in
relayImageHelper function

* Updated the format of resolution keys in
DalleSizeRatios map

* Added error handling for unsupported image size in
relayImageHelper function

* Added validation for number of generated images
and defined image generation ratios

* Refactored variable name from
DalleGenerationImageAmountRatios to
DalleGenerationImageAmounts

* Added validation for prompt length in
relayImageHelper function

* Updated model validation and removed size not
supported error in relayImageHelper function

* Refactored image size and model validation in
relayImageHelper function

* chore: discard binary file

* chore: update impl

---------

Co-authored-by: cktsun1031 <65409152+cktsun1031@users.noreply.github.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 20:03:16 +08:00
管宜尧
de7b9710a5 fix: fix PaLM not working issue (#667)
* bugfix for #515 最新版本谷歌PaLM模型无法使用

* update

* chore: remove unrelated file

* chore: add comment

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 19:40:59 +08:00
35 changed files with 1078 additions and 267 deletions

View File

@@ -51,15 +51,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
<a href="https://iamazing.cn/page/reward">赞赏支持</a>
</p>
> **Note**
> [!NOTE]
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
>
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> **Warning**
> [!WARNING]
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> **Warning**
> [!WARNING]
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`
## 功能
@@ -92,14 +92,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
12. 支持**用户邀请奖励**。
13. 支持以美元为单位显示额度。
14. 支持发布公告,设置充值链接,设置新用户初始额度。
15. 支持模型映射,重定向用户的请求模型。
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功
16. 支持失败自动重试。
17. 支持绘图接口。
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 API。
20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。

View File

@@ -78,6 +78,7 @@ var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false

View File

@@ -1,11 +1,13 @@
package common
import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
@@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error {
SMTPFrom = SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, content))
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"strings"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
@@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
err = json.Unmarshal(requestBody, &v)
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil {
return err
}

47
common/image/image.go Normal file
View File

@@ -0,0 +1,47 @@
package image
import (
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"net/http"
"regexp"
"strings"
_ "golang.org/x/image/webp"
)
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body)
if err != nil {
return
}
return img.Width, img.Height, nil
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
encoded = strings.TrimPrefix(encoded, "data:image/png;base64,")
base64 := strings.NewReader(reg.ReplaceAllString(encoded, ""))
img, _, err := image.DecodeConfig(base64)
if err != nil {
return
}
return img.Width, img.Height, nil
}
func GetImageSize(image string) (width int, height int, err error) {
if strings.HasPrefix(image, "data:image/") {
return GetImageSizeFromBase64(image)
}
return GetImageSizeFromUrl(image)
}

154
common/image/image_test.go Normal file
View File

@@ -0,0 +1,154 @@
package image_test
import (
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strconv"
"strings"
"testing"
img "one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
)
type CountingReader struct {
reader io.Reader
BytesRead int
}
func (r *CountingReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.BytesRead += n
return n, err
}
var (
cases = []struct {
url string
format string
width int
height int
}{
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
}
)
func TestDecode(t *testing.T) {
// Bytes read: varies sometimes
// jpeg: 1063892
// png: 294462
// webp: 99529
// gif: 956153
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 4096
// png: 4096
// webp: 4096
// gif: 4096
// jpeg#01: 4096
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestBase64(t *testing.T) {
// Bytes read:
// jpeg: 1063892
// png: 294462
// webp: 99072
// gif: 953856
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 1536
// png: 768
// webp: 768
// gif: 1536
// jpeg#01: 3840
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestGetImageSize(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@@ -6,6 +6,29 @@ import (
"time"
)
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
@@ -36,7 +59,11 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
@@ -45,9 +72,12 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e": 8,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens

View File

@@ -5,14 +5,15 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
@@ -43,16 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else {
if channel.GetBaseURL() != "" {
requestURL = channel.GetBaseURL()
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL
}
requestURL += "/v1/chat/completions"
}
// for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639
requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1)
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
@@ -73,11 +72,18 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
body, err := io.ReadAll(resp.Body)
if err != nil {
return err, nil
}
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
if response.Usage.CompletionTokens == 0 {
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"
}
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
return nil, nil
@@ -139,20 +145,32 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
// enable & notify
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
@@ -175,20 +193,21 @@ func testAllChannels(notify bool) error {
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if milliseconds > disableThreshold {
if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error())
}
if shouldDisableChannel(openaiErr, -1) {
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
disableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}

View File

@@ -55,12 +55,21 @@ func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "dall-e",
Id: "dall-e-2",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e",
Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil,
},
{
@@ -72,6 +81,42 @@ func init() {
Root: "whisper-1",
Parent: nil,
},
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
@@ -315,6 +360,24 @@ func init() {
Root: "claude-2",
Parent: nil,
},
{
Id: "claude-2.1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.1",
Parent: nil,
},
{
Id: "claude-2.0",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.0",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",

View File

@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content
query = request.Messages[len(request.Messages)-1].StringContent()
}
return &AIProxyLibraryRequest{
Model: request.Model,

View File

@@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
User: message.StringContent(),
Bot: request.Messages[i+1].StringContent(),
})
i++
}

View File

@@ -1,6 +1,7 @@
package controller
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -11,6 +12,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -21,16 +23,41 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
tokenName := c.GetString("token_name")
var ttsRequest TextToSpeechRequest
if relayMode == RelayModeAudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
if err != nil {
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
}
audioModel = ttsRequest.Model
// Check if text is too long 4096
if len(ttsRequest.Input) > 4096 {
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
var quota int
var preConsumedQuota int
switch relayMode {
case RelayModeAudioSpeech:
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
// Check if user quota is enough
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
@@ -70,13 +97,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := GetAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
}
requestBody := &bytes.Buffer{}
_, err = io.Copy(requestBody, c.Request.Body)
if err != nil {
return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
responseFormat := c.DefaultPostForm("response_format", "json")
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
req.Header.Set("api-key", apiKey)
req.ContentLength = c.Request.ContentLength
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -93,47 +141,65 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var audioResponse AudioResponse
if relayMode != RelayModeAudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
var openAIErr TextResponse
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
if openAIErr.Error.Message != "" {
return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
}
}
var text string
switch responseFormat {
case "json":
text, err = getTextFromJSON(responseBody)
case "text":
text, err = getTextFromText(responseBody)
case "srt":
text, err = getTextFromSRT(responseBody)
case "verbose_json":
text, err = getTextFromVerboseJSON(responseBody)
case "vtt":
text, err = getTextFromVTT(responseBody)
default:
return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
}
if err != nil {
return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
}
quota = countTokenText(text, audioModel)
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func(ctx context.Context) {
go func() {
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
@@ -149,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
return nil
}
func getTextFromVTT(body []byte) (string, error) {
return getTextFromSRT(body)
}
func getTextFromVerboseJSON(body []byte) (string, error) {
var whisperResponse WhisperVerboseJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
return whisperResponse.Text, nil
}
func getTextFromSRT(body []byte) (string, error) {
scanner := bufio.NewScanner(strings.NewReader(string(body)))
var builder strings.Builder
var textLine bool
for scanner.Scan() {
line := scanner.Text()
if textLine {
builder.WriteString(line)
textLine = false
continue
} else if strings.Contains(line, "-->") {
textLine = true
continue
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return builder.String(), nil
}
func getTextFromText(body []byte) (string, error) {
return strings.TrimSuffix(string(body), "\n"), nil
}
func getTextFromJSON(body []byte) (string, error) {
var whisperResponse WhisperJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
return whisperResponse.Text, nil
}

View File

@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.Content,
Content: message.StringContent(),
})
}
}

View File

@@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"

View File

@@ -6,44 +6,80 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}
min := common.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1]
return value >= min && value <= max
}
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e"
imageModel := "dall-e-2"
imageSize := "1024x1024"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var imageRequest ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation
if imageRequest.Model != "" {
imageModel = imageRequest.Model
}
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
// Check if model is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
// Check prompt length
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
// map model name
@@ -66,8 +102,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
}
var requestBody io.Reader
if isModelMapped {
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -82,18 +125,9 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
@@ -101,7 +135,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -122,43 +162,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])

View File

@@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText
}
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
@@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(c.Writer, resp.Body)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
@@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model)
completionTokens += countTokenText(choice.Message.StringContent(), model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,

View File

@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.Content,
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"

View File

@@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
@@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue
}
messages = append(messages, TencentMessage{
Content: message.Content,
Content: message.StringContent(),
Role: message.Role,
})
}

View File

@@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
@@ -132,11 +129,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
apiVersion := GetAPIVersion(c)
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
@@ -147,7 +140,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
@@ -233,7 +228,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if consumeQuota && preConsumedQuota > 0 {
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@@ -365,8 +360,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
@@ -410,37 +410,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
if consumeQuota {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
switch apiType {
@@ -454,7 +453,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}

View File

@@ -1,15 +1,21 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io"
"math"
"net/http"
"one-api/common"
"one-api/common/image"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
)
var stopFinishReason = "stop"
@@ -84,7 +90,33 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
switch v := message.Content.(type) {
case string:
tokenNum += getTokenNum(tokenEncoder, v)
case []any:
for _, it := range v {
m := it.(map[string]any)
switch m["type"] {
case "text":
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
case "image_url":
imageUrl, ok := m["image_url"].(map[string]any)
if ok {
url := imageUrl["url"].(string)
detail := ""
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail)
if err != nil {
common.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
}
}
}
}
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
@@ -95,13 +127,81 @@ func countTokenMessages(messages []Message, model string) int {
return tokenNum
}
const (
lowDetailCost = 85
highDetailCostPerTile = 170
additionalCost = 85
)
// https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string) (_ int, err error) {
var fetchSize = true
var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
// According to the official guide, "low" disable the high-res model,
// and only receive low-res 512px x 512px version of the image, indicating
// that image is treated as low-res when size is smaller than 512px x 512px,
// then we can assume that image size larger than 512px x 512px is treated
// as high-res. Then we have the following logic:
// if detail == "" || detail == "auto" {
// width, height, err = image.GetImageSize(url)
// if err != nil {
// return 0, err
// }
// fetchSize = false
// // not sure if this is correct
// if width > 512 || height > 512 {
// detail = "high"
// } else {
// detail = "low"
// }
// }
// However, in my test, it seems to be always the same as "high".
// The following image, which is 125x50, is still treated as high-res, taken
// 255 tokens in the response of non-stream chat completion api.
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
if detail == "" || detail == "auto" {
// assume by test, not sure if this is correct
detail = "high"
}
switch detail {
case "low":
return lowDetailCost, nil
case "high":
if fetchSize {
width, height, err = image.GetImageSize(url)
if err != nil {
return 0, err
}
}
if width > 2048 || height > 2048 { // max(width, height) > 2048
ratio := float64(2048) / math.Max(float64(width), float64(height))
width = int(float64(width) * ratio)
height = int(float64(height) * ratio)
}
if width > 768 && height > 768 { // min(width, height) > 768
ratio := float64(768) / math.Min(float64(width), float64(height))
width = int(float64(width) * ratio)
height = int(float64(height) * ratio)
}
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
result := numSquares*highDetailCostPerTile + additionalCost
return result, nil
default:
return 0, errors.New("invalid detail option")
}
}
func countTokenInput(input any, model string) int {
switch input.(type) {
switch v := input.(type) {
case string:
return countTokenText(input.(string), model)
return countTokenText(v, model)
case []string:
text := ""
for _, s := range input.([]string) {
for _, s := range v {
text += s
}
return countTokenText(text, model)
@@ -142,6 +242,19 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
return false
}
func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
@@ -179,10 +292,45 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}

View File

@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.Content,
Content: message.StringContent(),
})
}
}

View File

@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, ZhipuMessage{
Role: "user",
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
Content: message.Content,
Content: message.StringContent(),
})
}
}

View File

@@ -12,10 +12,49 @@ import (
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == "text" {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
@@ -24,24 +63,37 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudio
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
@@ -77,16 +129,51 @@ type TextRequest struct {
//Stream bool `json:"stream"`
}
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type AudioResponse struct {
type WhisperJSONResponse struct {
Text string `json:"text,omitempty"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -183,14 +270,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudio:
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)

6
go.mod
View File

@@ -15,7 +15,9 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.14.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@@ -26,6 +28,7 @@ require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
@@ -50,12 +53,13 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

6
go.sum
View File

@@ -152,6 +152,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
@@ -168,8 +170,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -119,6 +119,7 @@
" 年 ": " y ",
"未测试": "Not tested",
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
@@ -139,6 +140,7 @@
"启用": "Enable",
"编辑": "Edit",
"添加新的渠道": "Add a new channel",
"测试所有通道": "Test all channels",
"测试所有已启用通道": "Test all enabled channels",
"更新所有已启用通道余额": "Update the balance of all enabled channels",
"刷新": "Refresh",

View File

@@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])

View File

@@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
var modelRequest ModelRequest
var err error
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
@@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) {
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
@@ -92,6 +89,8 @@ func Distribute() func(c *gin.Context) {
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
c.Next()
}

26
middleware/recover.go Normal file
View File

@@ -0,0 +1,26 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
"type": "one_api_panic",
},
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -34,6 +34,7 @@ func InitOptionMap() {
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
@@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":

3
pull_request_template.md Normal file
View File

@@ -0,0 +1,3 @@
close #issue_number
我已确认该 PR 已自测通过,相关截图如下:

View File

@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
@@ -29,17 +29,44 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented)
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.GET("/assistants", controller.RelayNotImplemented)
relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented)
relayV1Router.POST("/threads", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id", controller.RelayNotImplemented)
relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
}
}

View File

@@ -234,7 +234,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
showInfo('已成功开始测试所有通道,请刷新页面查看结果。');
} else {
showError(message);
}

View File

@@ -16,6 +16,7 @@ const OperationSetting = () => {
ChatLink: '',
QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: '',
AutomaticEnableChannelEnabled: '',
ChannelDisableThreshold: 0,
LogConsumeEnabled: '',
DisplayInCurrencyEnabled: '',
@@ -269,6 +270,12 @@ const OperationSetting = () => {
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label='成功时自动启用通道'
name='AutomaticEnableChannelEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('monitor').then();

View File

@@ -60,7 +60,7 @@ const EditChannel = () => {
let localModels = [];
switch (value) {
case 14:
localModels = ['claude-instant-1', 'claude-2'];
localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
break;
case 11:
localModels = ['PaLM-2'];
@@ -343,6 +343,20 @@ const EditChannel = () => {
</Form.Field>
)
}
{
inputs.type === 17 && (
<Form.Field>
<Form.Input
label='插件参数'
name='other'
placeholder={'请输入插件参数,即 X-DashScope-Plugin 请求头的取值'}
onChange={handleInputChange}
value={inputs.other}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Form.Field>
<Form.Dropdown
label='模型'