mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 19:03:43 +08:00 
			
		
		
		
	Compare commits
	
		
			38 Commits
		
	
	
		
			v0.5.7-alp
			...
			v0.5.9-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 0e73418cdf | ||
|  | 9889377f0e | ||
|  | b273464e77 | ||
|  | b4e43d97fd | ||
|  | 3347a44023 | ||
|  | 923e24534b | ||
|  | b4d67ca614 | ||
|  | d85e356b6e | ||
|  | 495fc628e4 | ||
|  | 76f9288c34 | ||
|  | 915d13fdd4 | ||
|  | 969f539777 | ||
|  | 54e5f8ecd2 | ||
|  | 34d517cfa2 | ||
|  | ddcaf95f5f | ||
|  | 1d15157f7d | ||
|  | de7b9710a5 | ||
|  | 58bb3ab6f6 | ||
|  | d306cb5229 | ||
|  | 6c5307d0c4 | ||
|  | 7c4505bdfc | ||
|  | 9d43ec57d8 | ||
|  | e5311892d1 | ||
|  | bc7c9105f4 | ||
|  | 3fe76c8af7 | ||
|  | c70c614018 | ||
|  | 0d87de697c | ||
|  | aec343dc38 | ||
|  | 89d458b9cf | ||
|  | 63fafba112 | ||
|  | a398f35968 | ||
|  | 57aa637c77 | ||
|  | 3b483639a4 | ||
|  | 22980b4c44 | ||
|  | 64cdb7eafb | ||
|  | 824444244b | ||
|  | fbe9985f57 | ||
|  | a27a5bcc06 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -5,4 +5,5 @@ upload | ||||
| *.db | ||||
| build | ||||
| *.db-journal | ||||
| logs | ||||
| logs | ||||
| data | ||||
| @@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | ||||
|  | ||||
| > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. First, fork the code. | ||||
| 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. | ||||
| 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | ||||
|   | ||||
| @@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for | ||||
|  | ||||
| > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. まず、コードをフォークする。 | ||||
| 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 | ||||
| 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | ||||
|   | ||||
							
								
								
									
										31
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								README.md
									
									
									
									
									
								
							| @@ -51,15 +51,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|   <a href="https://iamazing.cn/page/reward">赞赏支持</a> | ||||
| </p> | ||||
|  | ||||
| > **Note** | ||||
| > [!NOTE] | ||||
| > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||
| >  | ||||
| > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||
|  | ||||
| > **Warning** | ||||
| > [!WARNING] | ||||
| > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||
|  | ||||
| > **Warning** | ||||
| > [!WARNING] | ||||
| > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||
|  | ||||
| ## 功能 | ||||
| @@ -75,7 +75,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||
| 2. 支持配置镜像以及众多第三方代理服务: | ||||
|    + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||
|    + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) | ||||
|    + [x] [API2D](https://api2d.com/r/197971) | ||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||
| @@ -92,14 +92,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 12. 支持**用户邀请奖励**。 | ||||
| 13. 支持以美元为单位显示额度。 | ||||
| 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 15. 支持模型映射,重定向用户的请求模型。 | ||||
| 15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 | ||||
| 16. 支持失败自动重试。 | ||||
| 17. 支持绘图接口。 | ||||
| 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 | ||||
| 19. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 20. 支持通过系统访问令牌访问管理 API。 | ||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
| @@ -160,6 +160,19 @@ sudo service nginx restart | ||||
|  | ||||
| 初始账号用户名为 `root`,密码为 `123456`。 | ||||
|  | ||||
|  | ||||
| ### 基于 Docker Compose 进行部署 | ||||
|  | ||||
| > 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 | ||||
|  | ||||
| ```shell | ||||
| # 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 | ||||
| docker-compose up -d | ||||
|  | ||||
| # 查看部署状态 | ||||
| docker-compose ps | ||||
| ``` | ||||
|  | ||||
| ### 手动部署 | ||||
| 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | ||||
|    ```shell | ||||
| @@ -249,6 +262,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
|  | ||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. 首先 fork 一份代码。 | ||||
| 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||
| 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 | ||||
| @@ -352,6 +367,10 @@ graph LR | ||||
| 13. 请求频率限制: | ||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
| 14. 编码器缓存设置: | ||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
|   | ||||
| @@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||
| var DisplayInCurrencyEnabled = true | ||||
| var DisplayTokenStatEnabled = true | ||||
|  | ||||
| var UsingSQLite = false | ||||
|  | ||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||
|  | ||||
| var SessionSecret = uuid.New().String() | ||||
| var SQLitePath = "one-api.db" | ||||
|  | ||||
| var OptionMap map[string]string | ||||
| var OptionMapRWMutex sync.RWMutex | ||||
| @@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
|   | ||||
							
								
								
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package common | ||||
|  | ||||
| var UsingSQLite = false | ||||
| var UsingPostgreSQL = false | ||||
|  | ||||
| var SQLitePath = "one-api.db" | ||||
| @@ -1,11 +1,13 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| @@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		SMTPFrom = SMTPAccount | ||||
| 	} | ||||
| 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | ||||
|  | ||||
| 	// Extract domain from SMTPFrom | ||||
| 	parts := strings.Split(SMTPFrom, "@") | ||||
| 	var domain string | ||||
| 	if len(parts) > 1 { | ||||
| 		domain = parts[1] | ||||
| 	} | ||||
| 	// Generate a unique Message-ID | ||||
| 	buf := make([]byte, 16) | ||||
| 	_, err := rand.Read(buf) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	messageId := fmt.Sprintf("<%x@%s>", buf, domain) | ||||
|  | ||||
| 	mail := []byte(fmt.Sprintf("To: %s\r\n"+ | ||||
| 		"From: %s<%s>\r\n"+ | ||||
| 		"Subject: %s\r\n"+ | ||||
| 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | ||||
| 		"Date: %s\r\n"+ | ||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, content)) | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
| 	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) | ||||
| 	to := strings.Split(receiver, ";") | ||||
| 	var err error | ||||
|  | ||||
| 	if SMTPPort == 465 { | ||||
| 		tlsConfig := &tls.Config{ | ||||
| 			InsecureSkipVerify: true, | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| @@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = json.Unmarshal(requestBody, &v) | ||||
| 	contentType := c.Request.Header.Get("Content-Type") | ||||
| 	if strings.HasPrefix(contentType, "application/json") { | ||||
| 		err = json.Unmarshal(requestBody, &v) | ||||
| 	} else { | ||||
| 		// skip for now | ||||
| 		// TODO: someday non json request have variant model, we will need to implementation this | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -3,8 +3,32 @@ package common | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var DalleSizeRatios = map[string]map[string]float64{ | ||||
| 	"dall-e-2": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1.125, | ||||
| 		"1024x1024": 1.25, | ||||
| 	}, | ||||
| 	"dall-e-3": { | ||||
| 		"1024x1024": 1, | ||||
| 		"1024x1792": 2, | ||||
| 		"1792x1024": 2, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var DalleGenerationImageAmounts = map[string][2]int{ | ||||
| 	"dall-e-2": {1, 10}, | ||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||
| } | ||||
|  | ||||
| var DalleImagePromptLengthLimitations = map[string]int{ | ||||
| 	"dall-e-2": 1000, | ||||
| 	"dall-e-3": 4000, | ||||
| } | ||||
|  | ||||
| // ModelRatio | ||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||
| @@ -19,12 +43,15 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4-32k":                 30, | ||||
| 	"gpt-4-32k-0314":            30, | ||||
| 	"gpt-4-32k-0613":            30, | ||||
| 	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens | ||||
| 	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens | ||||
| 	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-0301":        0.75, | ||||
| 	"gpt-3.5-turbo-0613":        0.75, | ||||
| 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | ||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||
| 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens | ||||
| 	"text-ada-001":              0.2, | ||||
| 	"text-babbage-001":          0.25, | ||||
| 	"text-curie-001":            1, | ||||
| @@ -32,7 +59,11 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-davinci-003":          10, | ||||
| 	"text-davinci-edit-001":     10, | ||||
| 	"code-davinci-edit-001":     10, | ||||
| 	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||
| 	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||
| 	"tts-1":                     7.5, // $0.015 / 1K characters | ||||
| 	"tts-1-1106":                7.5, | ||||
| 	"tts-1-hd":                  15, // $0.030 / 1K characters | ||||
| 	"tts-1-hd-1106":             15, | ||||
| 	"davinci":                   10, | ||||
| 	"curie":                     10, | ||||
| 	"babbage":                   10, | ||||
| @@ -41,13 +72,18 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-search-ada-doc-001":   10, | ||||
| 	"text-moderation-stable":    0.1, | ||||
| 	"text-moderation-latest":    0.1, | ||||
| 	"dall-e":                    8, | ||||
| 	"dall-e-2":                  8,      // $0.016 - $0.020 / image | ||||
| 	"dall-e-3":                  20,     // $0.040 - $0.120 / image | ||||
| 	"claude-instant-1":          0.815,  // $1.63 / 1M tokens | ||||
| 	"claude-2":                  5.51,   // $11.02 / 1M tokens | ||||
| 	"claude-2.0":                5.51,   // $11.02 / 1M tokens | ||||
| 	"claude-2.1":                5.51,   // $11.02 / 1M tokens | ||||
| 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | ||||
| 	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens | ||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                    1, | ||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens | ||||
| @@ -86,9 +122,24 @@ func GetModelRatio(name string) float64 { | ||||
|  | ||||
| func GetCompletionRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "gpt-3.5") { | ||||
| 		if strings.HasSuffix(name, "1106") { | ||||
| 			return 2 | ||||
| 		} | ||||
| 		if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { | ||||
| 			// TODO: clear this after 2023-12-11 | ||||
| 			now := time.Now() | ||||
| 			// https://platform.openai.com/docs/models/continuous-model-upgrades | ||||
| 			// if after 2023-12-11, use 2 | ||||
| 			if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { | ||||
| 				return 2 | ||||
| 			} | ||||
| 		} | ||||
| 		return 1.333333 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-4") { | ||||
| 		if strings.HasSuffix(name, "preview") { | ||||
| 			return 3 | ||||
| 		} | ||||
| 		return 2 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "claude-instant-1") { | ||||
|   | ||||
| @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { | ||||
| func MessageWithRequestId(message string, id string) string { | ||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||
| } | ||||
|  | ||||
| func String2Int(str string) int { | ||||
| 	num, err := strconv.Atoi(str) | ||||
| 	if err != nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|   | ||||
| @@ -5,13 +5,15 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||
| @@ -42,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	} | ||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) | ||||
| 		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) | ||||
| 	} else { | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			requestURL = channel.GetBaseURL() | ||||
| 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { | ||||
| 			requestURL = baseURL | ||||
| 		} | ||||
| 		requestURL += "/v1/chat/completions" | ||||
| 	} | ||||
|  | ||||
| 		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| @@ -70,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	var response TextResponse | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil | ||||
| 	} | ||||
| 	if response.Usage.CompletionTokens == 0 { | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
| 	} | ||||
|   | ||||
| @@ -127,8 +127,8 @@ func DeleteChannel(c *gin.Context) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func DeleteManuallyDisabledChannel(c *gin.Context) { | ||||
| 	rows, err := model.DeleteChannelByStatus(common.ChannelStatusManuallyDisabled) | ||||
| func DeleteDisabledChannel(c *gin.Context) { | ||||
| 	rows, err := model.DeleteDisabledChannel() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -55,12 +55,21 @@ func init() { | ||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| 	openAIModels = []OpenAIModels{ | ||||
| 		{ | ||||
| 			Id:         "dall-e", | ||||
| 			Id:         "dall-e-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e", | ||||
| 			Root:       "dall-e-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "dall-e-3", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e-3", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -72,6 +81,42 @@ func init() { | ||||
| 			Root:       "whisper-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-hd", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-hd", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-hd-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-hd-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo", | ||||
| 			Object:     "model", | ||||
| @@ -117,6 +162,15 @@ func init() { | ||||
| 			Root:       "gpt-3.5-turbo-16k-0613", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-3.5-turbo-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo-instruct", | ||||
| 			Object:     "model", | ||||
| @@ -180,6 +234,24 @@ func init() { | ||||
| 			Root:       "gpt-4-32k-0613", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4-1106-preview", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-4-1106-preview", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4-vision-preview", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-4-vision-preview", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "text-embedding-ada-002", | ||||
| 			Object:     "model", | ||||
| @@ -274,7 +346,7 @@ func init() { | ||||
| 			Id:         "claude-instant-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-instant-1", | ||||
| 			Parent:     nil, | ||||
| @@ -283,11 +355,29 @@ func init() { | ||||
| 			Id:         "claude-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2.1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2.1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2.0", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2.0", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot", | ||||
| 			Object:     "model", | ||||
| @@ -306,6 +396,15 @@ func init() { | ||||
| 			Root:       "ERNIE-Bot-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot-4", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot-4", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "Embedding-V1", | ||||
| 			Object:     "model", | ||||
| @@ -324,6 +423,15 @@ func init() { | ||||
| 			Root:       "PaLM-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_turbo", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_pro", | ||||
| 			Object:     "model", | ||||
|   | ||||
| @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct { | ||||
| func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { | ||||
| 	query := "" | ||||
| 	if len(request.Messages) != 0 { | ||||
| 		query = request.Messages[len(request.Messages)-1].Content | ||||
| 		query = request.Messages[len(request.Messages)-1].StringContent() | ||||
| 	} | ||||
| 	return &AIProxyLibraryRequest{ | ||||
| 		Model:  request.Model, | ||||
|   | ||||
| @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				User: message.StringContent(), | ||||
| 				Bot:  "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			if i == len(request.Messages)-1 { | ||||
| 				prompt = message.Content | ||||
| 				prompt = message.StringContent() | ||||
| 				break | ||||
| 			} | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  request.Messages[i+1].Content, | ||||
| 				User: message.StringContent(), | ||||
| 				Bot:  request.Messages[i+1].StringContent(), | ||||
| 			}) | ||||
| 			i++ | ||||
| 		} | ||||
|   | ||||
| @@ -6,12 +6,12 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| @@ -22,16 +22,41 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
| 	tokenName := c.GetString("token_name") | ||||
|  | ||||
| 	var ttsRequest TextToSpeechRequest | ||||
| 	if relayMode == RelayModeAudioSpeech { | ||||
| 		// Read JSON | ||||
| 		err := common.UnmarshalBodyReusable(c, &ttsRequest) | ||||
| 		// Check if JSON is valid | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "invalid_json", http.StatusBadRequest) | ||||
| 		} | ||||
| 		audioModel = ttsRequest.Model | ||||
| 		// Check if text is too long 4096 | ||||
| 		if len(ttsRequest.Input) > 4096 { | ||||
| 			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| 	modelRatio := common.GetModelRatio(audioModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||
| 	var quota int | ||||
| 	var preConsumedQuota int | ||||
| 	switch relayMode { | ||||
| 	case RelayModeAudioSpeech: | ||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | ||||
| 		quota = preConsumedQuota | ||||
| 	default: | ||||
| 		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) | ||||
| 	} | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	// Check if user quota is enough | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| @@ -66,19 +91,38 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 		query := c.Request.URL.Query() | ||||
| 		apiVersion := query.Get("api-version") | ||||
| 		if apiVersion == "" { | ||||
| 			apiVersion = c.GetString("api_version") | ||||
| 		} | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | ||||
| 	} | ||||
|  | ||||
| 	requestBody := c.Request.Body | ||||
|  | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
|  | ||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		req.Header.Set("api-key", apiKey) | ||||
| 		req.ContentLength = c.Request.ContentLength | ||||
| 	} else { | ||||
| 		req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
|  | ||||
| @@ -95,47 +139,44 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var audioResponse AudioResponse | ||||
|  | ||||
| 	if relayMode != RelayModeAudioSpeech { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		var whisperResponse WhisperResponse | ||||
| 		err = json.Unmarshal(responseBody, &whisperResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		quota = countTokenText(whisperResponse.Text, audioModel) | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		if preConsumedQuota > 0 { | ||||
| 			// we need to roll back the pre-consumed quota | ||||
| 			defer func(ctx context.Context) { | ||||
| 				go func() { | ||||
| 					// negative means add quota back for token & user | ||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||
| 					if err != nil { | ||||
| 						common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) | ||||
| 					} | ||||
| 				}() | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		return relayErrorHandler(resp) | ||||
| 	} | ||||
| 	quotaDelta := quota - preConsumedQuota | ||||
| 	defer func(ctx context.Context) { | ||||
| 		go func() { | ||||
| 			quota := countTokenText(audioResponse.Text, audioModel) | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		}() | ||||
| 		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | ||||
| 	}(c.Request.Context()) | ||||
|  | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &audioResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
|   | ||||
| @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 		} else if message.Role == "assistant" { | ||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||
| 		} else if message.Role == "system" { | ||||
| 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||
| 			if prompt == "" { | ||||
| 				prompt = message.StringContent() | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	prompt += "\n\nAssistant:" | ||||
|   | ||||
| @@ -14,37 +14,71 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	min := common.DalleGenerationImageAmounts[element][0] | ||||
| 	max := common.DalleGenerationImageAmounts[element][1] | ||||
|  | ||||
| 	return value >= min && value <= max | ||||
| } | ||||
|  | ||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	imageModel := "dall-e" | ||||
| 	imageModel := "dall-e-2" | ||||
| 	imageSize := "1024x1024" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest ImageRequest | ||||
| 	if consumeQuota { | ||||
| 		err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Size validation | ||||
| 	if imageRequest.Size != "" { | ||||
| 		imageSize = imageRequest.Size | ||||
| 	} | ||||
|  | ||||
| 	// Model validation | ||||
| 	if imageRequest.Model != "" { | ||||
| 		imageModel = imageRequest.Model | ||||
| 	} | ||||
|  | ||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] | ||||
|  | ||||
| 	// Check if model is supported | ||||
| 	if hasValidSize { | ||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { | ||||
| 			if imageSize == "1024x1024" { | ||||
| 				imageCostRatio *= 2 | ||||
| 			} else { | ||||
| 				imageCostRatio *= 1.5 | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Not "256x256", "512x512", or "1024x1024" | ||||
| 	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | ||||
| 		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	// Check prompt length | ||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | ||||
| 		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// N should between 1 and 10 | ||||
| 	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { | ||||
| 		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	// Number of generated images validation | ||||
| 	if isWithinRange(imageModel, imageRequest.N) == false { | ||||
| 		return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| @@ -61,16 +95,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
|  | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| @@ -87,18 +117,9 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
|  | ||||
| 	sizeRatio := 1.0 | ||||
| 	// Size | ||||
| 	if imageRequest.Size == "256x256" { | ||||
| 		sizeRatio = 1 | ||||
| 	} else if imageRequest.Size == "512x512" { | ||||
| 		sizeRatio = 1.125 | ||||
| 	} else if imageRequest.Size == "1024x1024" { | ||||
| 		sizeRatio = 1.25 | ||||
| 	} | ||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||
|  | ||||
| 	if consumeQuota && userQuota-quota < 0 { | ||||
| 	if userQuota-quota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
|  | ||||
| @@ -127,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	var textResponse ImageResponse | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if consumeQuota { | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 		} | ||||
| 		err = model.CacheUpdateUserQuota(userId) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString("token_name") | ||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 		} | ||||
| 	}(c.Request.Context()) | ||||
|  | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
|  | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
|   | ||||
| @@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		if textResponse.Error.Type != "" { | ||||
| 			return &OpenAIErrorWithStatusCode{ | ||||
| 				OpenAIError: textResponse.Error, | ||||
| 				StatusCode:  resp.StatusCode, | ||||
| 			}, nil | ||||
| 		} | ||||
| 		// Reset response body | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if textResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: textResponse.Error, | ||||
| 			StatusCode:  resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	// Reset response body | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 	// So the httpClient will be confused by the response. | ||||
| @@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err := io.Copy(c.Writer, resp.Body) | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| @@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += countTokenText(choice.Message.Content, model) | ||||
| 			completionTokens += countTokenText(choice.Message.StringContent(), model) | ||||
| 		} | ||||
| 		textResponse.Usage = Usage{ | ||||
| 			PromptTokens:     promptTokens, | ||||
|   | ||||
| @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		palmMessage := PaLMChatMessage{ | ||||
| 			Content: message.Content, | ||||
| 			Content: message.StringContent(), | ||||
| 		} | ||||
| 		if message.Role == "user" { | ||||
| 			palmMessage.Author = "0" | ||||
|   | ||||
| @@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, TencentMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, TencentMessage{ | ||||
| 				Role:    "assistant", | ||||
| @@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | ||||
| 			continue | ||||
| 		} | ||||
| 		messages = append(messages, TencentMessage{ | ||||
| 			Content: message.Content, | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    message.Role, | ||||
| 		}) | ||||
| 	} | ||||
|   | ||||
| @@ -6,13 +6,15 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -31,7 +33,14 @@ var httpClient *http.Client | ||||
| var impatientHTTPClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	httpClient = &http.Client{} | ||||
| 	if common.RelayTimeout == 0 { | ||||
| 		httpClient = &http.Client{} | ||||
| 	} else { | ||||
| 		httpClient = &http.Client{ | ||||
| 			Timeout: time.Duration(common.RelayTimeout) * time.Second, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	impatientHTTPClient = &http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| @@ -42,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
| 	var textRequest GeneralOpenAIRequest | ||||
| 	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { | ||||
| 		err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| @@ -118,12 +124,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	if channelType == common.ChannelTypeOpenAI { | ||||
| 		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||
| 		} | ||||
| 	} | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| @@ -143,7 +144,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
|  | ||||
| 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||
| 			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||
| @@ -156,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "ERNIE-Bot-4": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| @@ -227,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 		preConsumedQuota = 0 | ||||
| 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||
| 	} | ||||
| 	if consumeQuota && preConsumedQuota > 0 { | ||||
| 	if preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| @@ -361,11 +366,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			} | ||||
| 		case APITypeTencent: | ||||
| 			req.Header.Set("Authorization", apiKey) | ||||
| 		case APITypePaLM: | ||||
| 			// do not set Authorization header | ||||
| 		default: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		if isStream && c.Request.Header.Get("Accept") == "" { | ||||
| 			req.Header.Set("Accept", "text/event-stream") | ||||
| 		} | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| @@ -401,39 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	defer func(ctx context.Context) { | ||||
| 		// c.Writer.Flush() | ||||
| 		go func() { | ||||
| 			if consumeQuota { | ||||
| 				quota := 0 | ||||
| 				completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 				promptTokens = textResponse.Usage.PromptTokens | ||||
| 				completionTokens = textResponse.Usage.CompletionTokens | ||||
|  | ||||
| 				quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||
| 				quota = int(float64(quota) * ratio) | ||||
| 				if ratio != 0 && quota <= 0 { | ||||
| 					quota = 1 | ||||
| 				} | ||||
| 				totalTokens := promptTokens + completionTokens | ||||
| 				if totalTokens == 0 { | ||||
| 					// in this case, must be some error happened | ||||
| 					// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 					quota = 0 | ||||
| 				} | ||||
| 				quotaDelta := quota - preConsumedQuota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 				} | ||||
| 				err = model.CacheUpdateUserQuota(userId) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||
| 				} | ||||
| 				if quota != 0 { | ||||
| 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 					model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 				} | ||||
| 			quota := 0 | ||||
| 			completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 			promptTokens = textResponse.Usage.PromptTokens | ||||
| 			completionTokens = textResponse.Usage.CompletionTokens | ||||
| 			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
| 			if ratio != 0 && quota <= 0 { | ||||
| 				quota = 1 | ||||
| 			} | ||||
| 			totalTokens := promptTokens + completionTokens | ||||
| 			if totalTokens == 0 { | ||||
| 				// in this case, must be some error happened | ||||
| 				// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 				quota = 0 | ||||
| 			} | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
|  | ||||
| 		}() | ||||
| 	}(c.Request.Context()) | ||||
| 	switch apiType { | ||||
| @@ -447,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) | ||||
| 			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|   | ||||
| @@ -1,15 +1,18 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| ) | ||||
|  | ||||
| var stopFinishReason = "stop" | ||||
| @@ -84,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int { | ||||
| 	tokenNum := 0 | ||||
| 	for _, message := range messages { | ||||
| 		tokenNum += tokensPerMessage | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Content) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.StringContent()) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Role) | ||||
| 		if message.Name != nil { | ||||
| 			tokenNum += tokensPerName | ||||
| @@ -176,3 +179,40 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr | ||||
| 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func getFullRequestURL(baseURL string, requestURL string, channelType int) string { | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
|  | ||||
| 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||
| 		switch channelType { | ||||
| 		case common.ChannelTypeOpenAI: | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return fullRequestURL | ||||
| } | ||||
|  | ||||
| func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||
| 	// quotaDelta is remaining quota to be consumed | ||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 	if err != nil { | ||||
| 		common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	err = model.CacheUpdateUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		common.SysError("error update user quota cache: " + err.Error()) | ||||
| 	} | ||||
| 	// totalQuota is total quota consumed | ||||
| 	if totalQuota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||
| 	} | ||||
| 	if totalQuota <= 0 { | ||||
| 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "assistant", | ||||
| @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| @@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | ||||
| 	for !stop { | ||||
| 		select { | ||||
| 		case xunfeiResponse = <-dataChan: | ||||
| 			if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
| 			content += xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||
| @@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, | ||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||
| 	} | ||||
| 	domain := "general" | ||||
| 	if apiVersion == "v2.1" { | ||||
| 		domain = "generalv2" | ||||
| 	if apiVersion != "v1.1" { | ||||
| 		domain += strings.Split(apiVersion, ".")[0] | ||||
| 	} | ||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	return domain, authUrl | ||||
|   | ||||
| @@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "user", | ||||
| @@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -12,10 +12,49 @@ import ( | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string  `json:"role"` | ||||
| 	Content string  `json:"content"` | ||||
| 	Content any     `json:"content"` | ||||
| 	Name    *string `json:"name,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageURL struct { | ||||
| 	Url    string `json:"url,omitempty"` | ||||
| 	Detail string `json:"detail,omitempty"` | ||||
| } | ||||
|  | ||||
| type TextContent struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageContent struct { | ||||
| 	Type     string    `json:"type,omitempty"` | ||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) StringContent() string { | ||||
| 	content, ok := m.Content.(string) | ||||
| 	if ok { | ||||
| 		return content | ||||
| 	} | ||||
| 	contentList, ok := m.Content.([]any) | ||||
| 	if ok { | ||||
| 		var contentStr string | ||||
| 		for _, contentItem := range contentList { | ||||
| 			contentMap, ok := contentItem.(map[string]any) | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
| 			if contentMap["type"] == "text" { | ||||
| 				if subStr, ok := contentMap["text"].(string); ok { | ||||
| 					contentStr += subStr | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return contentStr | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	RelayModeUnknown = iota | ||||
| 	RelayModeChatCompletions | ||||
| @@ -24,24 +63,37 @@ const ( | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeEdits | ||||
| 	RelayModeAudio | ||||
| 	RelayModeAudioSpeech | ||||
| 	RelayModeAudioTranscription | ||||
| 	RelayModeAudioTranslation | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/chat | ||||
|  | ||||
| type ResponseFormat struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model       string    `json:"model,omitempty"` | ||||
| 	Messages    []Message `json:"messages,omitempty"` | ||||
| 	Prompt      any       `json:"prompt,omitempty"` | ||||
| 	Stream      bool      `json:"stream,omitempty"` | ||||
| 	MaxTokens   int       `json:"max_tokens,omitempty"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	N           int       `json:"n,omitempty"` | ||||
| 	Input       any       `json:"input,omitempty"` | ||||
| 	Instruction string    `json:"instruction,omitempty"` | ||||
| 	Size        string    `json:"size,omitempty"` | ||||
| 	Functions   any       `json:"functions,omitempty"` | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	Messages         []Message       `json:"messages,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed             float64         `json:"seed,omitempty"` | ||||
| 	Tools            any             `json:"tools,omitempty"` | ||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
| @@ -77,16 +129,30 @@ type TextRequest struct { | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create | ||||
| type ImageRequest struct { | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| 	Model          string `json:"model"` | ||||
| 	Prompt         string `json:"prompt" binding:"required"` | ||||
| 	N              int    `json:"n"` | ||||
| 	Size           string `json:"size"` | ||||
| 	Quality        string `json:"quality"` | ||||
| 	ResponseFormat string `json:"response_format"` | ||||
| 	Style          string `json:"style"` | ||||
| 	User           string `json:"user"` | ||||
| } | ||||
|  | ||||
| type AudioResponse struct { | ||||
| type WhisperResponse struct { | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type TextToSpeechRequest struct { | ||||
| 	Model          string  `json:"model" binding:"required"` | ||||
| 	Input          string  `json:"input" binding:"required"` | ||||
| 	Voice          string  `json:"voice" binding:"required"` | ||||
| 	Speed          float64 `json:"speed"` | ||||
| 	ResponseFormat string  `json:"response_format"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| @@ -183,14 +249,22 @@ func Relay(c *gin.Context) { | ||||
| 		relayMode = RelayModeImagesGenerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||
| 		relayMode = RelayModeEdits | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 		relayMode = RelayModeAudio | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { | ||||
| 		relayMode = RelayModeAudioSpeech | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { | ||||
| 		relayMode = RelayModeAudioTranscription | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 		relayMode = RelayModeAudioTranslation | ||||
| 	} | ||||
| 	var err *OpenAIErrorWithStatusCode | ||||
| 	switch relayMode { | ||||
| 	case RelayModeImagesGenerations: | ||||
| 		err = relayImageHelper(c, relayMode) | ||||
| 	case RelayModeAudio: | ||||
| 	case RelayModeAudioSpeech: | ||||
| 		fallthrough | ||||
| 	case RelayModeAudioTranslation: | ||||
| 		fallthrough | ||||
| 	case RelayModeAudioTranscription: | ||||
| 		err = relayAudioHelper(c, relayMode) | ||||
| 	default: | ||||
| 		err = relayTextHelper(c, relayMode) | ||||
|   | ||||
| @@ -9,21 +9,21 @@ services: | ||||
|     ports: | ||||
|       - "3000:3000" | ||||
|     volumes: | ||||
|       - ./data:/data | ||||
|       - ./data/oneapi:/data | ||||
|       - ./logs:/app/logs | ||||
|     environment: | ||||
|       - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库 | ||||
|       - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库 | ||||
|       - REDIS_CONN_STRING=redis://redis | ||||
|       - SESSION_SECRET=random_string  # 修改为随机字符串 | ||||
|       - TZ=Asia/Shanghai | ||||
| #      - NODE_TYPE=slave  # 多机部署时从节点取消注释该行 | ||||
| #      - SYNC_FREQUENCY=60  # 需要定期从数据库加载数据时取消注释该行 | ||||
| #      - FRONTEND_BASE_URL=https://openai.justsong.cn  # 多机部署时从节点取消注释该行 | ||||
|  | ||||
|     depends_on: | ||||
|       - redis | ||||
|       - db | ||||
|     healthcheck: | ||||
|       test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] | ||||
|       test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] | ||||
|       interval: 30s | ||||
|       timeout: 10s | ||||
|       retries: 3 | ||||
| @@ -32,3 +32,18 @@ services: | ||||
|     image: redis:latest | ||||
|     container_name: redis | ||||
|     restart: always | ||||
|  | ||||
|   db: | ||||
|     image: mysql:8.2.0 | ||||
|     restart: always | ||||
|     container_name: mysql | ||||
|     volumes: | ||||
|       - ./data/mysql:/var/lib/mysql  # 挂载目录,持久化存储 | ||||
|     ports: | ||||
|       - '3306:3306' | ||||
|     environment: | ||||
|       TZ: Asia/Shanghai   # 设置时区 | ||||
|       MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 | ||||
|       MYSQL_USER: oneapi   # 创建专用用户 | ||||
|       MYSQL_PASSWORD: '123456'    # 设置专用用户密码 | ||||
|       MYSQL_DATABASE: one-api   # 自动创建数据库 | ||||
| @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
| 		requestURL := c.Request.URL.String() | ||||
| 		consumeQuota := true | ||||
| 		if strings.HasPrefix(requestURL, "/v1/models") { | ||||
| 			consumeQuota = false | ||||
| 		} | ||||
| 		c.Set("consume_quota", consumeQuota) | ||||
| 		if len(parts) > 1 { | ||||
| 			if model.IsAdmin(token.UserId) { | ||||
| 				c.Set("channelId", parts[1]) | ||||
|   | ||||
| @@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 		} else { | ||||
| 			// Select a channel for the user | ||||
| 			var modelRequest ModelRequest | ||||
| 			var err error | ||||
| 			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 				err = common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			} | ||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||
| 				return | ||||
| @@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) { | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e" | ||||
| 					modelRequest.Model = "dall-e-2" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "whisper-1" | ||||
| 				} | ||||
|   | ||||
| @@ -15,10 +15,17 @@ type Ability struct { | ||||
|  | ||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	ability := Ability{} | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 		trueVal = "true" | ||||
| 	} | ||||
|  | ||||
| 	var err error = nil | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) | ||||
| 	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	if common.UsingSQLite { | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||
| 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	if common.UsingSQLite || common.UsingPostgreSQL { | ||||
| 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||
| 	} else { | ||||
| 		err = channelQuery.Order("RAND()").First(&ability).Error | ||||
|   | ||||
| @@ -21,14 +21,18 @@ var ( | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	var token Token | ||||
| 	if !common.RedisEnabled { | ||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | ||||
| 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||
| 		return &token, err | ||||
| 	} | ||||
| 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) | ||||
| 	if err != nil { | ||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | ||||
| 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|   | ||||
| @@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| } | ||||
|  | ||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| @@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { | ||||
| 	return &channel, err | ||||
| } | ||||
|  | ||||
| func GetRandomChannel() (*Channel, error) { | ||||
| 	channel := Channel{} | ||||
| 	var err error = nil | ||||
| 	if common.UsingSQLite { | ||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error | ||||
| 	} else { | ||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error | ||||
| 	} | ||||
| 	return &channel, err | ||||
| } | ||||
|  | ||||
| func BatchInsertChannels(channels []Channel) error { | ||||
| 	var err error | ||||
| 	err = DB.Create(&channels).Error | ||||
| @@ -181,3 +174,8 @@ func DeleteChannelByStatus(status int64) (int64, error) { | ||||
| 	result := DB.Where("status = ?", status).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|  | ||||
| func DeleteDisabledChannel() (int64, error) { | ||||
| 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|   | ||||
| @@ -94,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	if channel != 0 { | ||||
| 		tx = tx.Where("channel = ?", channel) | ||||
| 		tx = tx.Where("channel_id = ?", channel) | ||||
| 	} | ||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | ||||
| 	return logs, err | ||||
| @@ -151,7 +151,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	if channel != 0 { | ||||
| 		tx = tx.Where("channel = ?", channel) | ||||
| 		tx = tx.Where("channel_id = ?", channel) | ||||
| 	} | ||||
| 	tx.Where("type = ?", LogTypeConsume).Scan("a) | ||||
| 	return quota | ||||
|   | ||||
| @@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			common.SysLog("using PostgreSQL as database") | ||||
| 			common.UsingPostgreSQL = true | ||||
| 			return gorm.Open(postgres.New(postgres.Config{ | ||||
| 				DSN:                  dsn, | ||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
|   | ||||
| @@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 	} | ||||
| 	redemption := &Redemption{} | ||||
|  | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error | ||||
| 		if err != nil { | ||||
| 			return errors.New("无效的兑换码") | ||||
| 		} | ||||
|   | ||||
| @@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) { | ||||
| } | ||||
|  | ||||
| func GetUserGroup(id int) (group string, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error | ||||
| 	groupCol := "`group`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										3
									
								
								pull_request_template.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pull_request_template.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| close #issue_number | ||||
|  | ||||
| 我已确认该 PR 已自测通过,相关截图如下: | ||||
| @@ -74,7 +74,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | ||||
| 			channelRoute.POST("/", controller.AddChannel) | ||||
| 			channelRoute.PUT("/", controller.UpdateChannel) | ||||
| 			channelRoute.DELETE("/manually_disabled", controller.DeleteManuallyDisabledChannel) | ||||
| 			channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) | ||||
| 			channelRoute.DELETE("/:id", controller.DeleteChannel) | ||||
| 		} | ||||
| 		tokenRoute := apiRouter.Group("/token") | ||||
|   | ||||
| @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { | ||||
| 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/transcriptions", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/translations", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/speech", controller.Relay) | ||||
| 		relayV1Router.GET("/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Input, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; | ||||
| import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | ||||
|  | ||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||
| import { renderGroup, renderNumber } from '../helpers/render'; | ||||
| @@ -55,6 +55,7 @@ const ChannelsTable = () => { | ||||
|   const [searchKeyword, setSearchKeyword] = useState(''); | ||||
|   const [searching, setSearching] = useState(false); | ||||
|   const [updatingBalance, setUpdatingBalance] = useState(false); | ||||
|   const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test")); | ||||
|  | ||||
|   const loadChannels = async (startIdx) => { | ||||
|     const res = await API.get(`/api/channel/?p=${startIdx}`); | ||||
| @@ -226,7 +227,6 @@ const ChannelsTable = () => { | ||||
|       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|       showNotice('当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。'); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
| @@ -240,11 +240,11 @@ const ChannelsTable = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const deleteAllManuallyDisabledChannels = async () => { | ||||
|     const res = await API.delete(`/api/channel/manually_disabled`); | ||||
|   const deleteAllDisabledChannels = async () => { | ||||
|     const res = await API.delete(`/api/channel/disabled`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       showSuccess(`已删除所有手动禁用渠道,共计 ${data} 个`); | ||||
|       showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); | ||||
|       await refresh(); | ||||
|     } else { | ||||
|       showError(message); | ||||
| @@ -286,17 +286,15 @@ const ChannelsTable = () => { | ||||
|     if (channels.length === 0) return; | ||||
|     setLoading(true); | ||||
|     let sortedChannels = [...channels]; | ||||
|     if (typeof sortedChannels[0][key] === 'string') { | ||||
|       sortedChannels.sort((a, b) => { | ||||
|     sortedChannels.sort((a, b) => { | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       }); | ||||
|     } else { | ||||
|       sortedChannels.sort((a, b) => { | ||||
|         if (a[key] === b[key]) return 0; | ||||
|         if (a[key] > b[key]) return -1; | ||||
|         if (a[key] < b[key]) return 1; | ||||
|       }); | ||||
|     } | ||||
|       } | ||||
|     }); | ||||
|     if (sortedChannels[0].id === channels[0].id) { | ||||
|       sortedChannels.reverse(); | ||||
|     } | ||||
| @@ -304,6 +302,7 @@ const ChannelsTable = () => { | ||||
|     setLoading(false); | ||||
|   }; | ||||
|  | ||||
|  | ||||
|   return ( | ||||
|     <> | ||||
|       <Form onSubmit={searchChannels}> | ||||
| @@ -317,7 +316,19 @@ const ChannelsTable = () => { | ||||
|           onChange={handleKeywordChange} | ||||
|         /> | ||||
|       </Form> | ||||
|       { | ||||
|         showPrompt && ( | ||||
|           <Message onDismiss={() => { | ||||
|             setShowPrompt(false); | ||||
|             setPromptShown("channel-test"); | ||||
|           }}> | ||||
|             当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo | ||||
|             模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 | ||||
|  | ||||
|             另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 | ||||
|           </Message> | ||||
|         ) | ||||
|       } | ||||
|       <Table basic compact size='small'> | ||||
|         <Table.Header> | ||||
|           <Table.Row> | ||||
| @@ -519,14 +530,14 @@ const ChannelsTable = () => { | ||||
|               <Popup | ||||
|                 trigger={ | ||||
|                   <Button size='small' loading={loading}> | ||||
|                     删除所有手动禁用渠道 | ||||
|                     删除禁用渠道 | ||||
|                   </Button> | ||||
|                 } | ||||
|                 on='click' | ||||
|                 flowing | ||||
|                 hoverable | ||||
|               > | ||||
|                 <Button size='small' loading={loading} negative onClick={deleteAllManuallyDisabledChannels}> | ||||
|                 <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}> | ||||
|                   确认删除 | ||||
|                 </Button> | ||||
|               </Popup> | ||||
|   | ||||
| @@ -130,7 +130,13 @@ const RedemptionsTable = () => { | ||||
|     setLoading(true); | ||||
|     let sortedRedemptions = [...redemptions]; | ||||
|     sortedRedemptions.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       } | ||||
|     }); | ||||
|     if (sortedRedemptions[0].id === redemptions[0].id) { | ||||
|       sortedRedemptions.reverse(); | ||||
|   | ||||
| @@ -228,7 +228,13 @@ const TokensTable = () => { | ||||
|     setLoading(true); | ||||
|     let sortedTokens = [...tokens]; | ||||
|     sortedTokens.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       } | ||||
|     }); | ||||
|     if (sortedTokens[0].id === tokens[0].id) { | ||||
|       sortedTokens.reverse(); | ||||
|   | ||||
| @@ -133,7 +133,13 @@ const UsersTable = () => { | ||||
|     setLoading(true); | ||||
|     let sortedUsers = [...users]; | ||||
|     sortedUsers.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       } | ||||
|     }); | ||||
|     if (sortedUsers[0].id === users[0].id) { | ||||
|       sortedUsers.reverse(); | ||||
|   | ||||
| @@ -186,4 +186,14 @@ export const verifyJSON = (str) => { | ||||
|     return false; | ||||
|   } | ||||
|   return true; | ||||
| }; | ||||
| }; | ||||
|  | ||||
| export function shouldShowPrompt(id) { | ||||
|   let prompt = localStorage.getItem(`prompt-${id}`); | ||||
|   return !prompt; | ||||
|  | ||||
| } | ||||
|  | ||||
| export function setPromptShown(id) { | ||||
|   localStorage.setItem(`prompt-${id}`, 'true'); | ||||
| } | ||||
| @@ -60,19 +60,19 @@ const EditChannel = () => { | ||||
|       let localModels = []; | ||||
|       switch (value) { | ||||
|         case 14: | ||||
|           localModels = ['claude-instant-1', 'claude-2']; | ||||
|           localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; | ||||
|           break; | ||||
|         case 11: | ||||
|           localModels = ['PaLM-2']; | ||||
|           break; | ||||
|         case 15: | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; | ||||
|           break; | ||||
|         case 17: | ||||
|           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; | ||||
|           break; | ||||
|         case 16: | ||||
|           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           break; | ||||
|         case 18: | ||||
|           localModels = ['SparkDesk']; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user