Mirror of https://github.com/songquanpeng/one-api.git, synced 2025-10-30 13:23:42 +08:00.
Compare commits: v0.5.3 ... v0.5.5-alp (33 commits)
| SHA1 |
|---|
| 4335f005a6 |
| fe26a1448d |
| 42451d9d02 |
| 25c4c111ab |
| 0d50ad4b2b |
| 959bcdef88 |
| 39ae8075e4 |
| b57a0eca16 |
| 1b4cc78890 |
| 420c375140 |
| 01863d3e44 |
| d0a0e871e1 |
| bd6fe1e93c |
| c55bb67818 |
| 0f949c3782 |
| a721a5b6f9 |
| 276163affd |
| 621eb91b46 |
| 7e575abb95 |
| 9db93316c4 |
| c3dc315e75 |
| 04acdb1ccb |
| f0d5e102a3 |
| abbf2fded0 |
| ef2c5abb5b |
| 56b5007379 |
| d09d317459 |
| 1c4409ae80 |
| 5ee24e8acf |
| 4f2f911e4d |
| fdb2cccf65 |
| a3e267df7e |
| ac7c0f3a76 |

.gitignore (vendored): 3 lines changed
@@ -4,4 +4,5 @@ upload
 *.exe
 *.db
 build
 *.db-journal
+logs

README.md: 28 lines changed
@@ -68,12 +68,13 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box
    + [x] [Alibaba Tongyi Qianwen series models](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [iFlytek Spark cognitive model](https://www.xfyun.cn/doc/spark/Web.html)
    + [x] [Zhipu ChatGLM series models](https://bigmodel.cn)
+   + [x] [360 Zhinao](https://ai.360.cn)
 2. Support for configuring mirrors and many third-party proxy services:
    + [x] [OpenAI-SB](https://openai-sb.com)
+   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
    + [x] [API2D](https://api2d.com/r/197971)
    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invite code: `OneAPI`)
-   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
    + [x] Custom channels: for example, various unlisted third-party proxy services
 3. Support for accessing multiple channels via **load balancing**.
 4. Support for **stream mode**, enabling a typewriter effect through streamed responses.
@@ -108,6 +109,8 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box
 
 Data is stored on the host in the `/home/ubuntu/data/one-api` directory; make sure this directory exists and is writable, or change it to a suitable one.
 
+If startup fails, add `--privileged=true`; see https://github.com/songquanpeng/one-api/issues/482 for details.
+
 If the image above cannot be pulled, try GitHub's Docker registry instead: replace `justsong/one-api` above with `ghcr.io/songquanpeng/one-api`.
 
 If your traffic is highly concurrent, you **must** set `SQL_DSN`; see the [Environment Variables](#环境变量) section below.
@@ -208,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 
 Remember to adjust the port number, `OPENAI_API_BASE_URL`, and `OPENAI_API_KEY`.
 
+#### QChatGPT - QQ Bot
+Project homepage: https://github.com/RockChinQ/QChatGPT
+
+After deploying per its documentation, edit `config.py`: set the `reverse_proxy` field of `openai_config` to the One API backend address, set `api_key` to a key generated by One API, and set the `model` parameter of `completion_api_params` to a model name supported by One API.
+
+The [Switcher plugin](https://github.com/RockChinQ/Switcher) can be installed to switch models at runtime.
+
 ### Deploying to Third-Party Platforms
 <details>
 <summary><strong>Deploy to Sealos</strong></summary>
@@ -274,8 +284,9 @@ graph LR
 Without that, multiple channels are used in a load-balanced fashion.
 
 ### Environment Variables
-1. `REDIS_CONN_STRING`: when set, Redis is used as the storage for request-rate limiting instead of in-memory storage.
+1. `REDIS_CONN_STRING`: when set, Redis is used as the cache.
    + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+   + If your database access latency is low, there is no need to enable Redis; enabling it can actually introduce stale data.
 2. `SESSION_SECRET`: when set, a fixed session secret is used so that cookies of logged-in users remain valid across restarts.
    + Example: `SESSION_SECRET=random_string`
 3. `SQL_DSN`: when set, the specified database is used instead of SQLite; use MySQL or PostgreSQL.
@@ -302,11 +313,19 @@ graph LR
    + Example: `CHANNEL_TEST_FREQUENCY=1440`
 9. `POLLING_INTERVAL`: the interval, in seconds, between requests when batch-updating channel balances and testing availability; defaults to no interval.
    + Example: `POLLING_INTERVAL=5`
+10. `BATCH_UPDATE_ENABLED`: enables batched aggregation of database updates, which delays user quota updates somewhat; allowed values are `true` and `false`, defaulting to `false` when unset.
+    + Example: `BATCH_UPDATE_ENABLED=true`
+    + If you run into too many database connections, try enabling this option.
+11. `BATCH_UPDATE_INTERVAL`: the aggregation interval for batched updates, in seconds; defaults to `5`.
+    + Example: `BATCH_UPDATE_INTERVAL=5`
+12. Rate limits:
+    + `GLOBAL_API_RATE_LIMIT`: global API rate limit (excluding relay requests), the maximum number of requests per IP within three minutes; defaults to `180`.
+    + `GLOBAL_WEB_RATE_LIMIT`: global web rate limit, the maximum number of requests per IP within three minutes; defaults to `60`.
 
 ### Command-Line Arguments
 1. `--port <port_number>`: the port the server listens on; defaults to `3000`.
    + Example: `--port 3000`
-2. `--log-dir <log_dir>`: the log directory; if unset, logs will not be saved.
+2. `--log-dir <log_dir>`: the log directory; if unset, logs are saved by default to the `logs` folder under the working directory.
    + Example: `--log-dir ./logs`
 3. `--version`: print the version number and exit.
 4. `--help`: show usage help and argument descriptions.
@@ -338,6 +357,7 @@ https://openai.justsong.cn
 5. ChatGPT Next Web reports: `Failed to fetch`
    + Do not set `BASE_URL` when deploying.
    + Check that your API base URL and API key are entered correctly.
+   + Check whether HTTPS is enabled; browsers block HTTP requests on HTTPS domains.
 6. Error: `当前分组负载已饱和,请稍后再试` (the current group's load is saturated, please try again later)
    + The upstream channel returned a 429.
 
@@ -351,4 +371,4 @@ https://openai.justsong.cn
 
 The same applies to derivative projects based on this one.
 
 Under the MIT license, users must bear the risks and liability of using this project themselves; the developers of this open-source project bear no responsibility.
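The `BATCH_UPDATE_ENABLED` option documented above pairs with the `BatchUpdateEnabled` and `BatchUpdateInterval` variables added in the constants hunk below. The batch-update code itself is not part of this diff; the following is a minimal sketch of the aggregation idea, with all names and numbers illustrative:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// quotaDeltas accumulates per-user quota changes in memory so that many
// small UPDATEs collapse into one per user per interval.
var (
	mu          sync.Mutex
	quotaDeltas = map[int]int{}
)

func recordQuotaDelta(userId int, delta int) {
	mu.Lock()
	defer mu.Unlock()
	quotaDeltas[userId] += delta
}

func flushLoop(interval time.Duration) {
	for range time.Tick(interval) {
		mu.Lock()
		pending := quotaDeltas
		quotaDeltas = map[int]int{}
		mu.Unlock()
		for userId, delta := range pending {
			// In the real project this would be a single DB update per user.
			fmt.Printf("UPDATE users SET quota = quota + %d WHERE id = %d\n", delta, userId)
		}
	}
}

func main() {
	go flushLoop(5 * time.Second) // BATCH_UPDATE_INTERVAL=5
	recordQuotaDelta(1, -120)
	recordQuotaDelta(1, -80) // two requests collapse into one update of -200
	time.Sleep(6 * time.Second)
}
```

The tradeoff documented in the README follows directly from this shape: fewer database connections, but quota values lag by up to one flush interval.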
@@ -94,6 +94,13 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
 
 var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 
+var BatchUpdateEnabled = false
+var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
+
+const (
+	RequestIdKey = "X-Oneapi-Request-Id"
+)
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1
@@ -111,10 +118,10 @@ var (
 // All duration's unit is seconds
 // Shouldn't larger then RateLimitKeyExpirationDuration
 var (
-	GlobalApiRateLimitNum            = 180
+	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 	GlobalApiRateLimitDuration int64 = 3 * 60
 
-	GlobalWebRateLimitNum            = 60
+	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 	GlobalWebRateLimitDuration int64 = 3 * 60
 
 	UploadRateLimitNum            = 10
@@ -154,45 +161,53 @@ const (
 )
 
 const (
-	ChannelTypeUnknown   = 0
-	ChannelTypeOpenAI    = 1
-	ChannelTypeAPI2D     = 2
-	ChannelTypeAzure     = 3
-	ChannelTypeCloseAI   = 4
-	ChannelTypeOpenAISB  = 5
-	ChannelTypeOpenAIMax = 6
-	ChannelTypeOhMyGPT   = 7
-	ChannelTypeCustom    = 8
-	ChannelTypeAILS      = 9
-	ChannelTypeAIProxy   = 10
-	ChannelTypePaLM      = 11
-	ChannelTypeAPI2GPT   = 12
-	ChannelTypeAIGC2D    = 13
-	ChannelTypeAnthropic = 14
-	ChannelTypeBaidu     = 15
-	ChannelTypeZhipu     = 16
-	ChannelTypeAli       = 17
-	ChannelTypeXunfei    = 18
+	ChannelTypeUnknown        = 0
+	ChannelTypeOpenAI         = 1
+	ChannelTypeAPI2D          = 2
+	ChannelTypeAzure          = 3
+	ChannelTypeCloseAI        = 4
+	ChannelTypeOpenAISB       = 5
+	ChannelTypeOpenAIMax      = 6
+	ChannelTypeOhMyGPT        = 7
+	ChannelTypeCustom         = 8
+	ChannelTypeAILS           = 9
+	ChannelTypeAIProxy        = 10
+	ChannelTypePaLM           = 11
+	ChannelTypeAPI2GPT        = 12
+	ChannelTypeAIGC2D         = 13
+	ChannelTypeAnthropic      = 14
+	ChannelTypeBaidu          = 15
+	ChannelTypeZhipu          = 16
+	ChannelTypeAli            = 17
+	ChannelTypeXunfei         = 18
+	ChannelType360            = 19
+	ChannelTypeOpenRouter     = 20
+	ChannelTypeAIProxyLibrary = 21
+	ChannelTypeFastGPT        = 22
 )
 
 var ChannelBaseURLs = []string{
-	"",                               // 0
-	"https://api.openai.com",         // 1
-	"https://oa.api2d.net",           // 2
-	"",                               // 3
-	"https://api.closeai-proxy.xyz",  // 4
-	"https://api.openai-sb.com",      // 5
-	"https://api.openaimax.com",      // 6
-	"https://api.ohmygpt.com",        // 7
-	"",                               // 8
-	"https://api.caipacity.com",      // 9
-	"https://api.aiproxy.io",         // 10
-	"",                               // 11
-	"https://api.api2gpt.com",        // 12
-	"https://api.aigc2d.com",         // 13
-	"https://api.anthropic.com",      // 14
-	"https://aip.baidubce.com",       // 15
-	"https://open.bigmodel.cn",       // 16
-	"https://dashscope.aliyuncs.com", // 17
-	"",                               // 18
+	"",                                // 0
+	"https://api.openai.com",          // 1
+	"https://oa.api2d.net",            // 2
+	"",                                // 3
+	"https://api.closeai-proxy.xyz",   // 4
+	"https://api.openai-sb.com",       // 5
+	"https://api.openaimax.com",       // 6
+	"https://api.ohmygpt.com",         // 7
+	"",                                // 8
+	"https://api.caipacity.com",       // 9
+	"https://api.aiproxy.io",          // 10
+	"",                                // 11
+	"https://api.api2gpt.com",         // 12
+	"https://api.aigc2d.com",          // 13
+	"https://api.anthropic.com",       // 14
+	"https://aip.baidubce.com",        // 15
+	"https://open.bigmodel.cn",        // 16
+	"https://dashscope.aliyuncs.com",  // 17
+	"",                                // 18
+	"https://ai.360.cn",               // 19
+	"https://openrouter.ai/api",       // 20
+	"https://api.aiproxy.io",          // 21
+	"https://fastgpt.run/api/openapi", // 22
 }
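`GlobalApiRateLimitNum` and `GlobalWebRateLimitNum` are now read from the environment through `GetOrDefault`, whose signature and tail are visible in the `common/utils.go` hunk further down. A sketch of how such a helper typically behaves, assuming that signature (the project's actual implementation may differ, for example by logging parse errors):

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// GetOrDefault reads an integer from the environment, falling back to
// defaultValue when the variable is unset or not a valid integer.
// Sketch only: matches the signature shown in the diff, not necessarily
// the project's exact implementation.
func GetOrDefault(env string, defaultValue int) int {
	raw := os.Getenv(env)
	if raw == "" {
		return defaultValue
	}
	num, err := strconv.Atoi(raw)
	if err != nil {
		return defaultValue
	}
	return num
}

func main() {
	// With GLOBAL_API_RATE_LIMIT unset, the compiled-in default applies.
	fmt.Println(GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)) // 180
}
```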
@@ -12,7 +12,7 @@ var (
 	Port         = flag.Int("port", 3000, "the listening port")
 	PrintVersion = flag.Bool("version", false, "print version and exit")
 	PrintHelp    = flag.Bool("help", false, "print help and exit")
-	LogDir       = flag.String("log-dir", "", "specify the log directory")
+	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 )
 
 func printHelp() {
@@ -1,29 +1,47 @@
 package common
 
 import (
+	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"log"
 	"os"
 	"path/filepath"
+	"sync"
 	"time"
 )
 
-func SetupGinLog() {
+const (
+	loggerINFO  = "INFO"
+	loggerWarn  = "WARN"
+	loggerError = "ERR"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
 	if *LogDir != "" {
-		commonLogPath := filepath.Join(*LogDir, "common.log")
-		errorLogPath := filepath.Join(*LogDir, "error.log")
-		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+		ok := setupLogLock.TryLock()
+		if !ok {
+			log.Println("setup log is already working")
+			return
+		}
+		defer func() {
+			setupLogLock.Unlock()
+			setupLogWorking = false
+		}()
+		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 		if err != nil {
 			log.Fatal("failed to open log file")
 		}
-		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-		if err != nil {
-			log.Fatal("failed to open log file")
-		}
-		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
-		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
+		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 	}
 }
 
@@ -37,6 +55,36 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 
+func LogInfo(ctx context.Context, msg string) {
+	logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+	logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+	logHelper(ctx, loggerError, msg)
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+	writer := gin.DefaultErrorWriter
+	if level == loggerINFO {
+		writer = gin.DefaultWriter
+	}
+	id := ctx.Value(RequestIdKey)
+	now := time.Now()
+	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+	logCount++ // we don't need accurate count, so no lock here
+	if logCount > maxLogCount && !setupLogWorking {
+		logCount = 0
+		setupLogWorking = true
+		go func() {
+			SetupLogger()
+		}()
+	}
+}
+
 func FatalLog(v ...any) {
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
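The rewritten logger replaces the fixed `common.log`/`error.log` pair with one per-day file, and `logHelper` re-runs `SetupLogger` in the background after roughly a million lines, so a long-running process eventually rolls over to the current day's file. A small sketch of the pieces a caller sees, assuming a context that already carries the request id (see the middleware sketch after the `common/utils.go` hunk below):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

const RequestIdKey = "X-Oneapi-Request-Id" // mirrors the constant added above

func main() {
	// The file name SetupLogger opens for today, e.g. oneapi-20230805.log:
	fmt.Printf("oneapi-%s.log\n", time.Now().Format("20060102"))

	// logHelper reads the id with ctx.Value(RequestIdKey):
	ctx := context.WithValue(context.Background(), RequestIdKey, "req-123")
	fmt.Println(ctx.Value(RequestIdKey)) // what common.LogInfo would embed in the line
}
```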
@@ -13,46 +13,52 @@
 // 1 === $0.002 / 1K tokens
 // 1 === ¥0.014 / 1k tokens
 var ModelRatio = map[string]float64{
-	"gpt-4":                   15,
-	"gpt-4-0314":              15,
-	"gpt-4-0613":              15,
-	"gpt-4-32k":               30,
-	"gpt-4-32k-0314":          30,
-	"gpt-4-32k-0613":          30,
-	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens
-	"gpt-3.5-turbo-0301":      0.75,
-	"gpt-3.5-turbo-0613":      0.75,
-	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
-	"gpt-3.5-turbo-16k-0613":  1.5,
-	"text-ada-001":            0.2,
-	"text-babbage-001":        0.25,
-	"text-curie-001":          1,
-	"text-davinci-002":        10,
-	"text-davinci-003":        10,
-	"text-davinci-edit-001":   10,
-	"code-davinci-edit-001":   10,
-	"whisper-1":               10,
-	"davinci":                 10,
-	"curie":                   10,
-	"babbage":                 10,
-	"ada":                     10,
-	"text-embedding-ada-002":  0.05,
-	"text-search-ada-doc-001": 10,
-	"text-moderation-stable":  0.1,
-	"text-moderation-latest":  0.1,
-	"dall-e":                  8,
-	"claude-instant-1":        0.815,  // $1.63 / 1M tokens
-	"claude-2":                5.51,   // $11.02 / 1M tokens
-	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
-	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
-	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
-	"PaLM-2":                  1,
-	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
-	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
-	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
-	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
-	"qwen-plus-v1":            0.5715, // Same as above
-	"SparkDesk":               0.8572, // TBD
+	"gpt-4":                     15,
+	"gpt-4-0314":                15,
+	"gpt-4-0613":                15,
+	"gpt-4-32k":                 30,
+	"gpt-4-32k-0314":            30,
+	"gpt-4-32k-0613":            30,
+	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens
+	"gpt-3.5-turbo-0301":        0.75,
+	"gpt-3.5-turbo-0613":        0.75,
+	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens
+	"gpt-3.5-turbo-16k-0613":    1.5,
+	"text-ada-001":              0.2,
+	"text-babbage-001":          0.25,
+	"text-curie-001":            1,
+	"text-davinci-002":          10,
+	"text-davinci-003":          10,
+	"text-davinci-edit-001":     10,
+	"code-davinci-edit-001":     10,
+	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"davinci":                   10,
+	"curie":                     10,
+	"babbage":                   10,
+	"ada":                       10,
+	"text-embedding-ada-002":    0.05,
+	"text-search-ada-doc-001":   10,
+	"text-moderation-stable":    0.1,
+	"text-moderation-latest":    0.1,
+	"dall-e":                    8,
+	"claude-instant-1":          0.815,  // $1.63 / 1M tokens
+	"claude-2":                  5.51,   // $11.02 / 1M tokens
+	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
+	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
+	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
+	"PaLM-2":                    1,
+	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
+	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
+	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
+	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens
+	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens
+	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
+	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
+	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
+	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
+	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
+	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
+	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens
 }
 
 func ModelRatio2JSONString() string {
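Per the comment at the top of the file, one ratio unit equals $0.002 per 1K tokens. The new `whisper-1` comment derives its ratio from OpenAI's per-minute price; the chain checks out:

```go
package main

import "fmt"

func main() {
	// One ratio unit === $0.002 per 1K tokens (file header comment).
	// whisper-1: $0.006/minute, taken as ~150 words ≈ 200 tokens per minute,
	// so $0.006 per 200 tokens => $0.03 per 1K tokens.
	pricePer1K := 0.006 * 1000 / 200 // 0.03
	ratio := pricePer1K / 0.002
	fmt.Printf("whisper-1 ratio ≈ %.0f\n", ratio) // 15, matching the new entry
}
```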
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 	return time.Now().Unix()
 }
 
+func GetTimeString() string {
+	now := time.Now()
+	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
 func Max(a int, b int) int {
 	if a >= b {
 		return a
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
 	}
 	return num
 }
+
+func MessageWithRequestId(message string, id string) string {
+	return fmt.Sprintf("%s (request id: %s)", message, id)
+}
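`GetTimeString` and `MessageWithRequestId` pair with the `RequestIdKey` constant added earlier, but none of the wiring is shown in this diff. A hypothetical middleware sketch of how the pieces plausibly fit together:

```go
package middleware

import (
	"context"

	"github.com/gin-gonic/gin"
	"one-api/common"
)

// RequestId is a hypothetical sketch, not code from this diff: it mints an id
// with common.GetTimeString and stores it where common.LogInfo (via ctx.Value)
// and error responses (via common.MessageWithRequestId) can pick it up.
func RequestId() gin.HandlerFunc {
	return func(c *gin.Context) {
		id := common.GetTimeString()
		c.Set(common.RequestIdKey, id)
		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
		c.Request = c.Request.WithContext(ctx)
		c.Header(common.RequestIdKey, id) // echo the id back to the client
		c.Next()
	}
}
```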
@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
 	if err != nil {
 		openAIError := OpenAIError{
 			Message: err.Error(),
-			Type:    "one_api_error",
+			Type:    "upstream_error",
 		}
 		c.JSON(200, gin.H{
 			"error": openAIError,
@@ -14,7 +14,7 @@ import (
 	"time"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
+func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
@@ -24,10 +24,19 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
 		fallthrough
 	case common.ChannelTypeZhipu:
 		fallthrough
+	case common.ChannelTypeAli:
+		fallthrough
+	case common.ChannelType360:
+		fallthrough
 	case common.ChannelTypeXunfei:
 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 	case common.ChannelTypeAzure:
 		request.Model = "gpt-35-turbo"
+		defer func() {
+			if err != nil {
+				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
+			}
+		}()
 	default:
 		request.Model = "gpt-3.5-turbo"
 	}
@@ -174,7 +183,7 @@ func testAllChannels(notify bool) error {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
-			if shouldDisableChannel(openaiErr) {
+			if shouldDisableChannel(openaiErr, -1) {
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
 			channel.UpdateResponseTime(milliseconds)
@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
 	}
 	channel.CreatedTime = common.GetTimestamp()
 	keys := strings.Split(channel.Key, "\n")
-	channels := make([]model.Channel, 0)
+	channels := make([]model.Channel, 0, len(keys))
 	for _, key := range keys {
 		if key == "" {
 			continue
@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 
 func GitHubOAuth(c *gin.Context) {
 	session := sessions.Default(c)
+	state := c.Query("state")
+	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+		c.JSON(http.StatusForbidden, gin.H{
+			"success": false,
+			"message": "state is empty or not same",
+		})
+		return
+	}
 	username := session.Get("username")
 	if username != nil {
 		GitHubBind(c)
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
 	})
 	return
 }
+
+func GenerateOAuthCode(c *gin.Context) {
+	session := sessions.Default(c)
+	state := common.GetRandomString(12)
+	session.Set("oauth_state", state)
+	err := session.Save()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    state,
+	})
+}
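Taken together, these two hunks add CSRF protection to the GitHub login: the frontend first obtains a random `state` from `GenerateOAuthCode` (which also stores it in the session), sends the user to GitHub carrying that `state`, and `GitHubOAuth` now rejects any callback whose `state` does not match the session copy. A sketch of the authorize URL the frontend would build (the client id and state value are illustrative):

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	state := "Ab3dE5gH7jK9"             // as returned by GenerateOAuthCode
	clientId := "your-github-client-id" // placeholder

	q := url.Values{}
	q.Set("client_id", clientId)
	q.Set("state", state) // must round-trip unchanged through GitHub
	fmt.Println("https://github.com/login/oauth/authorize?" + q.Encode())
}
```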
@@ -63,6 +63,15 @@ func init() {
 			Root:       "dall-e",
 			Parent:     nil,
 		},
+		{
+			Id:         "whisper-1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "whisper-1",
+			Parent:     nil,
+		},
 		{
 			Id:         "gpt-3.5-turbo",
 			Object:     "model",
@@ -351,6 +360,15 @@ func init() {
 			Root:       "qwen-plus-v1",
 			Parent:     nil,
 		},
+		{
+			Id:         "text-embedding-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "ali",
+			Permission: permission,
+			Root:       "text-embedding-v1",
+			Parent:     nil,
+		},
 		{
 			Id:         "SparkDesk",
 			Object:     "model",
@@ -360,6 +378,51 @@ func init() {
 			Root:       "SparkDesk",
 			Parent:     nil,
 		},
+		{
+			Id:         "360GPT_S2_V9",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "360GPT_S2_V9",
+			Parent:     nil,
+		},
+		{
+			Id:         "embedding-bert-512-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "embedding-bert-512-v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "embedding_s1_v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "embedding_s1_v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "semantic_similarity_s1_v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "semantic_similarity_s1_v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "360GPT_S2_V9.4",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "360GPT_S2_V9.4",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {
controller/relay-aiproxy.go: new file, 220 lines
@@ -0,0 +1,220 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strconv"
+	"strings"
+)
+
+// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
+
+type AIProxyLibraryRequest struct {
+	Model     string `json:"model"`
+	Query     string `json:"query"`
+	LibraryId string `json:"libraryId"`
+	Stream    bool   `json:"stream"`
+}
+
+type AIProxyLibraryError struct {
+	ErrCode int    `json:"errCode"`
+	Message string `json:"message"`
+}
+
+type AIProxyLibraryDocument struct {
+	Title string `json:"title"`
+	URL   string `json:"url"`
+}
+
+type AIProxyLibraryResponse struct {
+	Success   bool                     `json:"success"`
+	Answer    string                   `json:"answer"`
+	Documents []AIProxyLibraryDocument `json:"documents"`
+	AIProxyLibraryError
+}
+
+type AIProxyLibraryStreamResponse struct {
+	Content   string                   `json:"content"`
+	Finish    bool                     `json:"finish"`
+	Model     string                   `json:"model"`
+	Documents []AIProxyLibraryDocument `json:"documents"`
+}
+
+func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
+	query := ""
+	if len(request.Messages) != 0 {
+		query = request.Messages[len(request.Messages)-1].Content
+	}
+	return &AIProxyLibraryRequest{
+		Model:  request.Model,
+		Stream: request.Stream,
+		Query:  query,
+	}
+}
+
+func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
+	if len(documents) == 0 {
+		return ""
+	}
+	content := "\n\n参考文档:\n"
+	for i, document := range documents {
+		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
+	}
+	return content
+}
+
+func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
+	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: content,
+		},
+		FinishReason: "stop",
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+	}
+	return &fullTextResponse
+}
+
+func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
+	choice.FinishReason = &stopFinishReason
+	return &ChatCompletionsStreamResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "",
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+}
+
+func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = response.Content
+	return &ChatCompletionsStreamResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   response.Model,
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+}
+
+func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var usage Usage
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 5 { // ignore blank line or wrong format
+				continue
+			}
+			if data[:5] != "data:" {
+				continue
+			}
+			data = data[5:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	setEventStreamHeaders(c)
+	var documents []AIProxyLibraryDocument
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
+			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			if len(AIProxyLibraryResponse.Documents) != 0 {
+				documents = AIProxyLibraryResponse.Documents
+			}
+			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			response := documentsAIProxyLibrary(documents)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &usage
+}
+
+func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var AIProxyLibraryResponse AIProxyLibraryResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if AIProxyLibraryResponse.ErrCode != 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: AIProxyLibraryResponse.Message,
+				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
+				Code:    AIProxyLibraryResponse.ErrCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
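The custom `bufio.Scanner` split function above tokenizes the upstream SSE stream one line per token, and only `data:`-prefixed lines are forwarded. A standalone sketch of the same splitting technique on a canned stream (the input is illustrative, not an actual AI Proxy response):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	body := "data: {\"content\":\"hel\"}\n\ndata: {\"content\":\"lo\",\"finish\":true}\n"
	scanner := bufio.NewScanner(strings.NewReader(body))
	// Same split function as aiProxyLibraryStreamHandler: one token per "\n".
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) < 5 || line[:5] != "data:" {
			continue // skip blank lines and non-data fields
		}
		fmt.Println("payload:", strings.TrimSpace(line[5:]))
	}
}
```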
@@ -35,6 +35,29 @@ type AliChatRequest struct {
 	Parameters AliParameters `json:"parameters,omitempty"`
 }
 
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings []AliEmbedding `json:"embeddings"`
+	} `json:"output"`
+	Usage AliUsage `json:"usage"`
+	AliError
+}
+
 type AliError struct {
 	Code      string `json:"code"`
 	Message   string `json:"message"`
@@ -44,6 +67,7 @@ type AliError struct {
 type AliUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
 }
 
 type AliOutput struct {
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 	}
 }
 
+func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
+	return &AliEmbeddingRequest{
+		Model: "text-embedding-v1",
+		Input: struct {
+			Texts []string `json:"texts"`
+		}{
+			Texts: request.ParseInput(),
+		},
+	}
+}
+
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var aliResponse AliEmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+		Model:  "text-embedding-v1",
+		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
+	}
+
+	for _, item := range response.Output.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     item.TextIndex,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 	choice := OpenAITextResponseChoice{
 		Index: 0,
controller/relay-audio.go: new file, 148 lines
@@ -0,0 +1,148 @@
+package controller
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+
+	"github.com/gin-gonic/gin"
+)
+
+func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+	audioModel := "whisper-1"
+
+	tokenId := c.GetInt("token_id")
+	channelType := c.GetInt("channel")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+
+	preConsumedTokens := common.PreConsumedQuota
+	modelRatio := common.GetModelRatio(audioModel)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
+	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+	if err != nil {
+		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota > 100*preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		preConsumedQuota = 0
+	}
+	if preConsumedQuota > 0 {
+		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		if err != nil {
+			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+	}
+
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	if modelMapping != "" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[audioModel] != "" {
+			audioModel = modelMap[audioModel]
+		}
+	}
+
+	baseURL := common.ChannelBaseURLs[channelType]
+	requestURL := c.Request.URL.String()
+
+	if c.GetString("base_url") != "" {
+		baseURL = c.GetString("base_url")
+	}
+
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	requestBody := c.Request.Body
+
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	err = req.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	var audioResponse AudioResponse
+
+	defer func(ctx context.Context) {
+		go func() {
+			quota := countTokenText(audioResponse.Text, audioModel)
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+				model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
+				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+				channelId := c.GetInt("channel_id")
+				model.UpdateChannelUsedQuota(channelId, quota)
+			}
+		}()
+	}(c.Request.Context())
+
+	responseBody, err := io.ReadAll(resp.Body)
+
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	err = json.Unmarshal(responseBody, &audioResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	return nil
+}
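The audio handler above follows the project's usual pre-consume/settle pattern: an estimated quota is debited up front (skipped entirely when the user's balance exceeds 100x the estimate), and once the transcript's real token count is known the difference is settled, with negative deltas acting as refunds. A toy calculation with made-up numbers:

```go
package main

import "fmt"

func main() {
	preConsumedQuota := 750 // hypothetical up-front debit: estimate * modelRatio * groupRatio
	quota := 420            // actual cost derived from the transcribed text

	// PostConsumeTokenQuota receives the delta; negative means refund.
	quotaDelta := quota - preConsumedQuota
	fmt.Println(quotaDelta) // -330
}
```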
@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 }
 
 func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
-	baiduEmbeddingRequest := BaiduEmbeddingRequest{
-		Input: nil,
-	}
-	switch request.Input.(type) {
-	case string:
-		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
-	case []any:
-		for _, item := range request.Input.([]any) {
-			if str, ok := item.(string); ok {
-				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
-			}
-		}
-	}
-	return &baiduEmbeddingRequest
+	return &BaiduEmbeddingRequest{
+		Input: request.ParseInput(),
+	}
 }
 
 func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
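The type switch deleted here has evidently moved into a shared `ParseInput` method on `GeneralOpenAIRequest`, which the new Ali embedding request also calls. That method is not shown in this diff; a sketch of what it presumably does, reconstructed from the removed lines (it assumes the existing `GeneralOpenAIRequest` type from the relay code):

```go
// Presumed shape of the shared helper, reconstructed from the deleted
// Baidu-specific code above; the project's actual implementation may differ.
func (r GeneralOpenAIRequest) ParseInput() []string {
	var input []string
	switch v := r.Input.(type) {
	case string:
		input = []string{v}
	case []any:
		for _, item := range v {
			if str, ok := item.(string); ok {
				input = append(input, str)
			}
		}
	}
	return input
}
```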
@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 	var textResponse ImageResponse
 
-	defer func() {
+	defer func(ctx context.Context) {
 		if consumeQuota {
 			err := model.PostConsumeTokenQuota(tokenId, quota)
 			if err != nil {
@@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
+				model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 		}
-	}()
+	}(c.Request.Context())
 
 	if consumeQuota {
 		responseBody, err := io.ReadAll(resp.Body)
|   | |||||||
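The image relay's deferred billing block now takes the request context as a parameter instead of reaching into the `*gin.Context` later, presumably so the request id stays attached to `RecordConsumeLog` even for work that outlives the handler. A reduced sketch of the pattern; the `"request_id"` key is illustrative:

```go
package main

import (
	"context"
	"fmt"

	"github.com/gin-gonic/gin"
)

// billingExample binds the request-scoped context at the moment the deferred
// closure is declared, so the accounting code never touches the gin context
// again after the handler body finishes.
func billingExample(c *gin.Context) {
	defer func(ctx context.Context) {
		fmt.Println("accounting done for request", ctx.Value("request_id"))
	}(c.Request.Context())
	c.Status(200)
}

func main() {
	r := gin.New()
	r.GET("/demo", billingExample)
	_ = r.Run(":8080")
}
```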
| @@ -2,6 +2,7 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -22,6 +23,7 @@ const ( | |||||||
| 	APITypeZhipu | 	APITypeZhipu | ||||||
| 	APITypeAli | 	APITypeAli | ||||||
| 	APITypeXunfei | 	APITypeXunfei | ||||||
|  | 	APITypeAIProxyLibrary | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var httpClient *http.Client | var httpClient *http.Client | ||||||
| @@ -104,6 +106,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		apiType = APITypeAli | 		apiType = APITypeAli | ||||||
| 	case common.ChannelTypeXunfei: | 	case common.ChannelTypeXunfei: | ||||||
| 		apiType = APITypeXunfei | 		apiType = APITypeXunfei | ||||||
|  | 	case common.ChannelTypeAIProxyLibrary: | ||||||
|  | 		apiType = APITypeAIProxyLibrary | ||||||
| 	} | 	} | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| @@ -171,6 +175,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||||
|  | 		if relayMode == RelayModeEmbeddings { | ||||||
|  | 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||||
|  | 		} | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||||
| 	} | 	} | ||||||
| 	var promptTokens int | 	var promptTokens int | ||||||
| 	var completionTokens int | 	var completionTokens int | ||||||
| @@ -202,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		// in this case, we do not pre-consume quota | 		// in this case, we do not pre-consume quota | ||||||
| 		// because the user has enough quota | 		// because the user has enough quota | ||||||
| 		preConsumedQuota = 0 | 		preConsumedQuota = 0 | ||||||
|  | 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||||
| 	} | 	} | ||||||
| 	if consumeQuota && preConsumedQuota > 0 { | 	if consumeQuota && preConsumedQuota > 0 { | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
| @@ -257,8 +267,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		aliRequest := requestOpenAI2Ali(textRequest) | 		var jsonStr []byte | ||||||
| 		jsonStr, err := json.Marshal(aliRequest) | 		var err error | ||||||
|  | 		switch relayMode { | ||||||
|  | 		case RelayModeEmbeddings: | ||||||
|  | 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) | ||||||
|  | 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||||
|  | 		default: | ||||||
|  | 			aliRequest := requestOpenAI2Ali(textRequest) | ||||||
|  | 			jsonStr, err = json.Marshal(aliRequest) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||||
|  | 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||||
|  | 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| @@ -282,6 +308,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 				req.Header.Set("api-key", apiKey) | 				req.Header.Set("api-key", apiKey) | ||||||
| 			} else { | 			} else { | ||||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||||
|  | 				if channelType == common.ChannelTypeOpenRouter { | ||||||
|  | 					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") | ||||||
|  | 					req.Header.Set("X-Title", "One API") | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		case APITypeClaude: | 		case APITypeClaude: | ||||||
| 			req.Header.Set("x-api-key", apiKey) | 			req.Header.Set("x-api-key", apiKey) | ||||||
| @@ -298,6 +328,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if textRequest.Stream { | 			if textRequest.Stream { | ||||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | 				req.Header.Set("X-DashScope-SSE", "enable") | ||||||
| 			} | 			} | ||||||
|  | 		default: | ||||||
|  | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
| 		} | 		} | ||||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
| @@ -317,8 +349,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
|  |  | ||||||
| 		if resp.StatusCode != http.StatusOK { | 		if resp.StatusCode != http.StatusOK { | ||||||
| 			return errorWrapper( | 			if preConsumedQuota != 0 { | ||||||
| 				fmt.Errorf("bad status code: %d", resp.StatusCode), "bad_status_code", resp.StatusCode) | 				go func(ctx context.Context) { | ||||||
|  | 					// return pre-consumed quota | ||||||
|  | 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||||
|  | 					if err != nil { | ||||||
|  | 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 					} | ||||||
|  | 				}(c.Request.Context()) | ||||||
|  | 			} | ||||||
|  | 			return relayErrorHandler(resp) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -326,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	tokenName := c.GetString("token_name") | 	tokenName := c.GetString("token_name") | ||||||
| 	channelId := c.GetInt("channel_id") | 	channelId := c.GetInt("channel_id") | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func(ctx context.Context) { | ||||||
| 		// c.Writer.Flush() | 		// c.Writer.Flush() | ||||||
| 		go func() { | 		go func() { | ||||||
| 			if consumeQuota { | 			if consumeQuota { | ||||||
| @@ -349,22 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 				quotaDelta := quota - preConsumedQuota | 				quotaDelta := quota - preConsumedQuota | ||||||
| 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("error consuming token remain quota: " + err.Error()) | 					common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||||
| 				} | 				} | ||||||
| 				err = model.CacheUpdateUserQuota(userId) | 				err = model.CacheUpdateUserQuota(userId) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("error update user quota cache: " + err.Error()) | 					common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||||
| 				} | 				} | ||||||
| 				if quota != 0 { | 				if quota != 0 { | ||||||
| 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | 					model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||||
| 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
|  |  | ||||||
| 					model.UpdateChannelUsedQuota(channelId, quota) | 					model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 	}() | 	}(c.Request.Context()) | ||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeOpenAI: | 	case APITypeOpenAI: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| @@ -485,7 +524,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := aliHandler(c, resp) | 			var err *OpenAIErrorWithStatusCode | ||||||
|  | 			var usage *Usage | ||||||
|  | 			switch relayMode { | ||||||
|  | 			case RelayModeEmbeddings: | ||||||
|  | 				err, usage = aliEmbeddingHandler(c, resp) | ||||||
|  | 			default: | ||||||
|  | 				err, usage = aliHandler(c, resp) | ||||||
|  | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -513,6 +559,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} else { | 		} else { | ||||||
| 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := aiProxyLibraryStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := aiProxyLibraryHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
| 	default: | 	default: | ||||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|   | |||||||
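One behavioral fix above deserves a note: when the upstream answers with a non-200 status, the pre-consumed quota is now refunded by posting a negative delta before the response is handed to `relayErrorHandler`. A toy sketch of the reserve/refund/settle arithmetic, with simplified stand-in names:

```go
package main

import "fmt"

// ledger stands in for the token quota store; postConsume mirrors
// model.PostConsumeTokenQuota, which treats a negative delta as a refund.
type ledger struct{ remain int }

func (l *ledger) postConsume(delta int) { l.remain -= delta }

func main() {
	l := &ledger{remain: 1000}
	preConsumed := 300 // reserved before calling the upstream
	l.postConsume(preConsumed)

	upstreamFailed := true
	if upstreamFailed {
		l.postConsume(-preConsumed) // refund, as in PostConsumeTokenQuota(tokenId, -preConsumedQuota)
	} else {
		actual := 420
		l.postConsume(actual - preConsumed) // settle quotaDelta on success
	}
	fmt.Println("remaining quota:", l.remain) // 1000: the reservation came back
}
```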
| @@ -1,16 +1,38 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/pkoukk/tiktoken-go" | 	"github.com/pkoukk/tiktoken-go" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"strconv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var stopFinishReason = "stop" | var stopFinishReason = "stop" | ||||||
|  |  | ||||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||||
|  |  | ||||||
|  | func InitTokenEncoders() { | ||||||
|  | 	common.SysLog("initializing token encoders") | ||||||
|  | 	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | 	for model := range common.ModelRatio { | ||||||
|  | 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||||
|  | 		if err != nil { | ||||||
|  | 			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) | ||||||
|  | 			tokenEncoderMap[model] = fallbackTokenEncoder | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		tokenEncoderMap[model] = tokenEncoder | ||||||
|  | 	} | ||||||
|  | 	common.SysLog("token encoders initialized") | ||||||
|  | } | ||||||
|  |  | ||||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||||
| 	if tokenEncoder, ok := tokenEncoderMap[model]; ok { | 	if tokenEncoder, ok := tokenEncoderMap[model]; ok { | ||||||
| 		return tokenEncoder | 		return tokenEncoder | ||||||
| @@ -95,13 +117,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func shouldDisableChannel(err *OpenAIError) bool { | func shouldDisableChannel(err *OpenAIError, statusCode int) bool { | ||||||
| 	if !common.AutomaticDisableChannelEnabled { | 	if !common.AutomaticDisableChannelEnabled { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
|  | 	if statusCode == http.StatusUnauthorized { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| @@ -115,3 +140,30 @@ func setEventStreamHeaders(c *gin.Context) { | |||||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { | ||||||
|  | 	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ | ||||||
|  | 		StatusCode: resp.StatusCode, | ||||||
|  | 		OpenAIError: OpenAIError{ | ||||||
|  | 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), | ||||||
|  | 			Type:    "upstream_error", | ||||||
|  | 			Code:    "bad_response_status_code", | ||||||
|  | 			Param:   strconv.Itoa(resp.StatusCode), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	var textResponse TextResponse | ||||||
|  | 	err = json.Unmarshal(responseBody, &textResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | ||||||
|  | 	return | ||||||
|  | } | ||||||
|   | |||||||
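`relayErrorHandler` prefers the upstream's own OpenAI-style error body and only keeps the synthetic `upstream_error` when the body cannot be read or parsed. A self-contained sketch of that extraction, using reduced stand-ins for the relay types:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// Reduced stand-ins for the OpenAIError / TextResponse types above.
type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    any    `json:"code"`
}

type errEnvelope struct {
	Error openAIError `json:"error"`
}

func main() {
	body := `{"error":{"message":"You exceeded your current quota","type":"insufficient_quota","code":"insufficient_quota"}}`
	resp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Body:       io.NopCloser(bytes.NewBufferString(body)),
	}
	raw, _ := io.ReadAll(resp.Body)
	_ = resp.Body.Close()

	var e errEnvelope
	if err := json.Unmarshal(raw, &e); err == nil {
		// this is the shape shouldDisableChannel inspects for insufficient_quota etc.
		fmt.Printf("status %d -> type=%s message=%q\n", resp.StatusCode, e.Error.Type, e.Error.Message)
	}
}
```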
| @@ -24,6 +24,7 @@ const ( | |||||||
| 	RelayModeModerations | 	RelayModeModerations | ||||||
| 	RelayModeImagesGenerations | 	RelayModeImagesGenerations | ||||||
| 	RelayModeEdits | 	RelayModeEdits | ||||||
|  | 	RelayModeAudio | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/chat | // https://platform.openai.com/docs/api-reference/chat | ||||||
| @@ -40,6 +41,26 @@ type GeneralOpenAIRequest struct { | |||||||
| 	Input       any       `json:"input,omitempty"` | 	Input       any       `json:"input,omitempty"` | ||||||
| 	Instruction string    `json:"instruction,omitempty"` | 	Instruction string    `json:"instruction,omitempty"` | ||||||
| 	Size        string    `json:"size,omitempty"` | 	Size        string    `json:"size,omitempty"` | ||||||
|  | 	Functions   any       `json:"functions,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r GeneralOpenAIRequest) ParseInput() []string { | ||||||
|  | 	if r.Input == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	var input []string | ||||||
|  | 	switch v := r.Input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		input = []string{v} | ||||||
|  | 	case []any: | ||||||
|  | 		input = make([]string, 0, len(v)) | ||||||
|  | 		for _, item := range v { | ||||||
|  | 			if str, ok := item.(string); ok { | ||||||
|  | 				input = append(input, str) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return input | ||||||
| } | } | ||||||
|  |  | ||||||
| type ChatRequest struct { | type ChatRequest struct { | ||||||
| @@ -62,6 +83,10 @@ type ImageRequest struct { | |||||||
| 	Size   string `json:"size"` | 	Size   string `json:"size"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type AudioResponse struct { | ||||||
|  | 	Text string `json:"text,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type Usage struct { | type Usage struct { | ||||||
| 	PromptTokens     int `json:"prompt_tokens"` | 	PromptTokens     int `json:"prompt_tokens"` | ||||||
| 	CompletionTokens int `json:"completion_tokens"` | 	CompletionTokens int `json:"completion_tokens"` | ||||||
| @@ -158,15 +183,20 @@ func Relay(c *gin.Context) { | |||||||
| 		relayMode = RelayModeImagesGenerations | 		relayMode = RelayModeImagesGenerations | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||||
| 		relayMode = RelayModeEdits | 		relayMode = RelayModeEdits | ||||||
|  | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 		relayMode = RelayModeAudio | ||||||
| 	} | 	} | ||||||
| 	var err *OpenAIErrorWithStatusCode | 	var err *OpenAIErrorWithStatusCode | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case RelayModeImagesGenerations: | 	case RelayModeImagesGenerations: | ||||||
| 		err = relayImageHelper(c, relayMode) | 		err = relayImageHelper(c, relayMode) | ||||||
|  | 	case RelayModeAudio: | ||||||
|  | 		err = relayAudioHelper(c, relayMode) | ||||||
| 	default: | 	default: | ||||||
| 		err = relayTextHelper(c, relayMode) | 		err = relayTextHelper(c, relayMode) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		requestId := c.GetString(common.RequestIdKey) | ||||||
| 		retryTimesStr := c.Query("retry") | 		retryTimesStr := c.Query("retry") | ||||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||||
| 		if retryTimesStr == "" { | 		if retryTimesStr == "" { | ||||||
| @@ -178,14 +208,15 @@ func Relay(c *gin.Context) { | |||||||
| 			if err.StatusCode == http.StatusTooManyRequests { | 			if err.StatusCode == http.StatusTooManyRequests { | ||||||
| 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
| 			} | 			} | ||||||
|  | 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||||
| 			c.JSON(err.StatusCode, gin.H{ | 			c.JSON(err.StatusCode, gin.H{ | ||||||
| 				"error": err.OpenAIError, | 				"error": err.OpenAIError, | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		channelId := c.GetInt("channel_id") | 		channelId := c.GetInt("channel_id") | ||||||
| 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 		if shouldDisableChannel(&err.OpenAIError) { | 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
| 			channelName := c.GetString("channel_name") | 			channelName := c.GetString("channel_name") | ||||||
| 			disableChannel(channelId, channelName, err.Message) | 			disableChannel(channelId, channelName, err.Message) | ||||||
|   | |||||||
| @@ -523,5 +523,6 @@ | |||||||
|   "按照如下格式输入:": "Enter in the following format:", |   "按照如下格式输入:": "Enter in the following format:", | ||||||
|   "模型版本": "Model version", |   "模型版本": "Model version", | ||||||
|   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", |   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", | ||||||
|   "点击查看": "click to view" |   "点击查看": "click to view", | ||||||
|  |   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" | ||||||
| } | } | ||||||
|   | |||||||
main.go (15 changes)
| @@ -21,7 +21,7 @@ var buildFS embed.FS | |||||||
| var indexPage []byte | var indexPage []byte | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| 	common.SetupGinLog() | 	common.SetupLogger() | ||||||
| 	common.SysLog("One API " + common.Version + " started") | 	common.SysLog("One API " + common.Version + " started") | ||||||
| 	if os.Getenv("GIN_MODE") != "debug" { | 	if os.Getenv("GIN_MODE") != "debug" { | ||||||
| 		gin.SetMode(gin.ReleaseMode) | 		gin.SetMode(gin.ReleaseMode) | ||||||
| @@ -77,13 +77,20 @@ func main() { | |||||||
| 		} | 		} | ||||||
| 		go controller.AutomaticallyTestChannels(frequency) | 		go controller.AutomaticallyTestChannels(frequency) | ||||||
| 	} | 	} | ||||||
|  | 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | ||||||
|  | 		common.BatchUpdateEnabled = true | ||||||
|  | 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||||
|  | 		model.InitBatchUpdater() | ||||||
|  | 	} | ||||||
|  | 	controller.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
| 	server := gin.Default() | 	server := gin.New() | ||||||
|  | 	server.Use(gin.Recovery()) | ||||||
| 	// This will cause SSE not to work!!! | 	// This will cause SSE not to work!!! | ||||||
| 	//server.Use(gzip.Gzip(gzip.DefaultCompression)) | 	//server.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||||
| 	server.Use(middleware.CORS()) | 	server.Use(middleware.RequestId()) | ||||||
|  | 	middleware.SetUpLogger(server) | ||||||
| 	// Initialize session store | 	// Initialize session store | ||||||
| 	store := cookie.NewStore([]byte(common.SessionSecret)) | 	store := cookie.NewStore([]byte(common.SessionSecret)) | ||||||
| 	server.Use(sessions.Sessions("session", store)) | 	server.Use(sessions.Sessions("session", store)) | ||||||
|   | |||||||
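Switching from `gin.Default()` to `gin.New()` matters here: `Default` pre-registers gin's stock logger, which knows nothing about the request id. A sketch of the intended registration order; the project's middleware calls are left as comments since they live outside this snippet:

```go
package main

import "github.com/gin-gonic/gin"

func main() {
	server := gin.New()
	server.Use(gin.Recovery())
	// server.Use(middleware.RequestId()) // 1. attach the id before anything logs
	// middleware.SetUpLogger(server)     // 2. the formatter reads the id back out
	_ = server.Run(":3000")
}
```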
| @@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 		key = parts[0] | 		key = parts[0] | ||||||
| 		token, err := model.ValidateUserToken(key) | 		token, err := model.ValidateUserToken(key) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			c.JSON(http.StatusUnauthorized, gin.H{ | 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||||
| 				"error": gin.H{ |  | ||||||
| 					"message": err.Error(), |  | ||||||
| 					"type":    "one_api_error", |  | ||||||
| 				}, |  | ||||||
| 			}) |  | ||||||
| 			c.Abort() |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if !model.CacheIsUserEnabled(token.UserId) { | 		userEnabled, err := model.IsUserEnabled(token.UserId) | ||||||
| 			c.JSON(http.StatusForbidden, gin.H{ | 		if err != nil { | ||||||
| 				"error": gin.H{ | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 					"message": "用户已被封禁", | 			return | ||||||
| 					"type":    "one_api_error", | 		} | ||||||
| 				}, | 		if !userEnabled { | ||||||
| 			}) | 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||||
| 			c.Abort() |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		c.Set("id", token.UserId) | 		c.Set("id", token.UserId) | ||||||
| @@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			if model.IsAdmin(token.UserId) { | 			if model.IsAdmin(token.UserId) { | ||||||
| 				c.Set("channelId", parts[1]) | 				c.Set("channelId", parts[1]) | ||||||
| 			} else { | 			} else { | ||||||
| 				c.JSON(http.StatusForbidden, gin.H{ | 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "普通用户不支持指定渠道", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -25,48 +25,27 @@ func Distribute() func(c *gin.Context) { | |||||||
| 		if ok { | 		if ok { | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) | 			id, err := strconv.Atoi(channelId.(string)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的渠道 ID", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			channel, err = model.GetChannelById(id, true) | 			channel, err = model.GetChannelById(id, true) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的渠道 ID", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if channel.Status != common.ChannelStatusEnabled { | 			if channel.Status != common.ChannelStatusEnabled { | ||||||
| 				c.JSON(http.StatusForbidden, gin.H{ | 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "该渠道已被禁用", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			// Select a channel for the user | 			// Select a channel for the user | ||||||
| 			var modelRequest ModelRequest | 			var modelRequest ModelRequest | ||||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | 			var err error | ||||||
|  | 			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 				err = common.UnmarshalBodyReusable(c, &modelRequest) | ||||||
|  | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的请求", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||||
| @@ -84,6 +63,11 @@ func Distribute() func(c *gin.Context) { | |||||||
| 					modelRequest.Model = "dall-e" | 					modelRequest.Model = "dall-e" | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 				if modelRequest.Model == "" { | ||||||
|  | 					modelRequest.Model = "whisper-1" | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||||
| @@ -91,13 +75,7 @@ func Distribute() func(c *gin.Context) { | |||||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||||
| 					message = "数据库一致性已被破坏,请联系管理员" | 					message = "数据库一致性已被破坏,请联系管理员" | ||||||
| 				} | 				} | ||||||
| 				c.JSON(http.StatusServiceUnavailable, gin.H{ | 				abortWithMessage(c, http.StatusServiceUnavailable, message) | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": message, |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @@ -107,8 +85,13 @@ func Distribute() func(c *gin.Context) { | |||||||
| 		c.Set("model_mapping", channel.ModelMapping) | 		c.Set("model_mapping", channel.ModelMapping) | ||||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||||
| 		c.Set("base_url", channel.BaseURL) | 		c.Set("base_url", channel.BaseURL) | ||||||
| 		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { | 		switch channel.Type { | ||||||
|  | 		case common.ChannelTypeAzure: | ||||||
| 			c.Set("api_version", channel.Other) | 			c.Set("api_version", channel.Other) | ||||||
|  | 		case common.ChannelTypeXunfei: | ||||||
|  | 			c.Set("api_version", channel.Other) | ||||||
|  | 		case common.ChannelTypeAIProxyLibrary: | ||||||
|  | 			c.Set("library_id", channel.Other) | ||||||
| 		} | 		} | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
|   | |||||||
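Audio endpoints receive multipart forms rather than JSON, so the distributor above skips body parsing for `/v1/audio` paths and falls back to a default model. A reduced sketch of the defaulting; the path conditions are simplified and the image default is taken from the surrounding context lines:

```go
package main

import (
	"fmt"
	"strings"
)

// defaultModel sketches the fallbacks applied before channel selection.
func defaultModel(path, requested string) string {
	if requested != "" {
		return requested
	}
	switch {
	case strings.HasPrefix(path, "/v1/audio"):
		return "whisper-1" // body is multipart, never parsed for a model name
	case strings.HasPrefix(path, "/v1/images"):
		return "dall-e"
	}
	return requested
}

func main() {
	fmt.Println(defaultModel("/v1/audio/transcriptions", ""))  // whisper-1
	fmt.Println(defaultModel("/v1/chat/completions", "gpt-4")) // gpt-4
}
```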
middleware/logger.go (new file, +25)
| @@ -0,0 +1,25 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func SetUpLogger(server *gin.Engine) { | ||||||
|  | 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||||
|  | 		var requestID string | ||||||
|  | 		if param.Keys != nil { | ||||||
|  | 			// comma-ok assertion avoids a panic if the request id was never set | ||||||
|  | 			if id, ok := param.Keys[common.RequestIdKey].(string); ok { | ||||||
|  | 				requestID = id | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||||
|  | 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||||
|  | 			requestID, | ||||||
|  | 			param.StatusCode, | ||||||
|  | 			param.Latency, | ||||||
|  | 			param.ClientIP, | ||||||
|  | 			param.Method, | ||||||
|  | 			param.Path, | ||||||
|  | 		) | ||||||
|  | 	})) | ||||||
|  | } | ||||||
middleware/request-id.go (new file, +18)
| @@ -0,0 +1,18 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RequestId() func(c *gin.Context) { | ||||||
|  | 	return func(c *gin.Context) { | ||||||
|  | 		id := common.GetTimeString() + common.GetRandomString(8) | ||||||
|  | 		c.Set(common.RequestIdKey, id) | ||||||
|  | 		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) | ||||||
|  | 		c.Request = c.Request.WithContext(ctx) | ||||||
|  | 		c.Header(common.RequestIdKey, id) | ||||||
|  | 		c.Next() | ||||||
|  | 	} | ||||||
|  | } | ||||||
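`RequestId()` publishes the id three ways: on the gin context, on the request's `context.Context`, and as a response header. A sketch of how a context-aware logger such as `common.LogInfo` can presumably read it back; the key's literal value is an assumption, since `common.RequestIdKey` is not shown in this compare:

```go
package main

import (
	"context"
	"fmt"
)

const requestIdKey = "X-Oneapi-Request-Id" // assumed value of common.RequestIdKey

// logInfo recovers the id that RequestId() attached with context.WithValue.
func logInfo(ctx context.Context, msg string) {
	id, _ := ctx.Value(requestIdKey).(string)
	fmt.Printf("[INFO] %s | %s\n", id, msg)
}

func main() {
	ctx := context.WithValue(context.Background(), requestIdKey, "20230812150405a1b2c3d4")
	logInfo(ctx, "relay started")
}
```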
middleware/utils.go (new file, +17)
| @@ -0,0 +1,17 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||||
|  | 	c.JSON(statusCode, gin.H{ | ||||||
|  | 		"error": gin.H{ | ||||||
|  | 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), | ||||||
|  | 			"type":    "one_api_error", | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	c.Abort() | ||||||
|  | 	common.LogError(c.Request.Context(), message) | ||||||
|  | } | ||||||
| @@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheIsUserEnabled(userId int) bool { | func CacheIsUserEnabled(userId int) (bool, error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return IsUserEnabled(userId) | 		return IsUserEnabled(userId) | ||||||
| 	} | 	} | ||||||
| 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) | 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) | ||||||
| 	if err != nil { | 	if err == nil { | ||||||
| 		status := common.UserStatusDisabled | 		return enabled == "1", nil | ||||||
| 		if IsUserEnabled(userId) { |  | ||||||
| 			status = common.UserStatusEnabled |  | ||||||
| 		} |  | ||||||
| 		enabled = fmt.Sprintf("%d", status) |  | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("Redis set user enabled error: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	return enabled == "1" |  | ||||||
|  | 	userEnabled, err := IsUserEnabled(userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	enabled = "0" | ||||||
|  | 	if userEnabled { | ||||||
|  | 		enabled = "1" | ||||||
|  | 	} | ||||||
|  | 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("Redis set user enabled error: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	return userEnabled, err | ||||||
| } | } | ||||||
|  |  | ||||||
| var group2model2channels map[string]map[string][]*Channel | var group2model2channels map[string]map[string][]*Channel | ||||||
|   | |||||||
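The rewritten `CacheIsUserEnabled` is a cleaner cache-aside: a Redis hit returns immediately, a miss consults the database and writes the result back with a TTL, and errors now propagate instead of a cache miss being conflated with a disabled user. A toy in-memory sketch of the flow; the TTL and the DB stand-in are illustrative:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var cache = map[string]string{} // stand-in for Redis

func redisGet(k string) (string, error) {
	if v, ok := cache[k]; ok {
		return v, nil
	}
	return "", errors.New("cache miss")
}

func redisSet(k, v string, _ time.Duration) { cache[k] = v }

func isUserEnabledDB(id int) (bool, error) { return true, nil } // stand-in for the DB query

func cacheIsUserEnabled(id int) (bool, error) {
	key := fmt.Sprintf("user_enabled:%d", id)
	if v, err := redisGet(key); err == nil {
		return v == "1", nil // cache hit
	}
	enabled, err := isUserEnabledDB(id)
	if err != nil {
		return false, err // DB errors reach the caller now
	}
	v := "0"
	if enabled {
		v = "1"
	}
	redisSet(key, v, 30*time.Second)
	return enabled, nil
}

func main() {
	fmt.Println(cacheIsUserEnabled(42)) // true <nil>, via the DB
	fmt.Println(cacheIsUserEnabled(42)) // true <nil>, via the cache
}
```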
| @@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelUsedQuota(id int, quota int) { | func UpdateChannelUsedQuota(id int, quota int) { | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	updateChannelUsedQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateChannelUsedQuota(id int, quota int) { | ||||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| ) | ) | ||||||
| @@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||||
|  | 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
| 	if !common.LogConsumeEnabled { | 	if !common.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN | |||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to record log: " + err.Error()) | 		common.LogError(ctx, "failed to record log: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 	} | 	} | ||||||
| 	token, err = CacheGetTokenByKey(key) | 	token, err = CacheGetTokenByKey(key) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
|  | 		if token.Status == common.TokenStatusExhausted { | ||||||
|  | 			return nil, errors.New("该令牌额度已用尽") | ||||||
|  | 		} else if token.Status == common.TokenStatusExpired { | ||||||
|  | 			return nil, errors.New("该令牌已过期") | ||||||
|  | 		} | ||||||
| 		if token.Status != common.TokenStatusEnabled { | 		if token.Status != common.TokenStatusEnabled { | ||||||
| 			return nil, errors.New("该令牌状态不可用") | 			return nil, errors.New("该令牌状态不可用") | ||||||
| 		} | 		} | ||||||
| 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | ||||||
| 			token.Status = common.TokenStatusExpired | 			if !common.RedisEnabled { | ||||||
| 			err := token.SelectUpdate() | 				token.Status = common.TokenStatusExpired | ||||||
| 			if err != nil { | 				err := token.SelectUpdate() | ||||||
| 				common.SysError("failed to update token status" + err.Error()) | 				if err != nil { | ||||||
|  | 					common.SysError("failed to update token status: " + err.Error()) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			return nil, errors.New("该令牌已过期") | 			return nil, errors.New("该令牌已过期") | ||||||
| 		} | 		} | ||||||
| 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||||
| 			token.Status = common.TokenStatusExhausted | 			if !common.RedisEnabled { | ||||||
| 			err := token.SelectUpdate() | 				// in this case, we can make sure the token is exhausted | ||||||
| 			if err != nil { | 				token.Status = common.TokenStatusExhausted | ||||||
| 				common.SysError("failed to update token status" + err.Error()) | 				err := token.SelectUpdate() | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to update token status: " + err.Error()) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			return nil, errors.New("该令牌额度已用尽") | 			return nil, errors.New("该令牌额度已用尽") | ||||||
| 		} | 		} | ||||||
| 		go func() { |  | ||||||
| 			token.AccessedTime = common.GetTimestamp() |  | ||||||
| 			err := token.SelectUpdate() |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("failed to update token" + err.Error()) |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 		return token, nil | 		return token, nil | ||||||
| 	} | 	} | ||||||
| 	return nil, errors.New("无效的令牌") | 	return nil, errors.New("无效的令牌") | ||||||
| @@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return increaseTokenQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func increaseTokenQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota": gorm.Expr("remain_quota + ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||||
| 			"used_quota":   gorm.Expr("used_quota - ?", quota), | 			"used_quota":    gorm.Expr("used_quota - ?", quota), | ||||||
|  | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| @@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return decreaseTokenQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decreaseTokenQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota": gorm.Expr("remain_quota - ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||||
| 			"used_quota":   gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
|  | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
|   | |||||||
| @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { | |||||||
| 	return user.Role >= common.RoleAdminUser | 	return user.Role >= common.RoleAdminUser | ||||||
| } | } | ||||||
|  |  | ||||||
| func IsUserEnabled(userId int) bool { | func IsUserEnabled(userId int) (bool, error) { | ||||||
| 	if userId == 0 { | 	if userId == 0 { | ||||||
| 		return false | 		return false, errors.New("user id is empty") | ||||||
| 	} | 	} | ||||||
| 	var user User | 	var user User | ||||||
| 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error | 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("no such user " + err.Error()) | 		return false, err | ||||||
| 		return false |  | ||||||
| 	} | 	} | ||||||
| 	return user.Status == common.UserStatusEnabled | 	return user.Status == common.UserStatusEnabled, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func ValidateAccessToken(token string) (user *User) { | func ValidateAccessToken(token string) (user *User) { | ||||||
| @@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return increaseUserQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func increaseUserQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return decreaseUserQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decreaseUserQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| 			"request_count": gorm.Expr("request_count + ?", 1), | 			"request_count": gorm.Expr("request_count + ?", count), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
model/utils.go (new file, +75)
| @@ -0,0 +1,75 @@ | |||||||
|  | package model | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	BatchUpdateTypeUserQuota = iota | ||||||
|  | 	BatchUpdateTypeTokenQuota | ||||||
|  | 	BatchUpdateTypeUsedQuotaAndRequestCount | ||||||
|  | 	BatchUpdateTypeChannelUsedQuota | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var batchUpdateStores []map[int]int | ||||||
|  | var batchUpdateLocks []sync.Mutex | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
|  | 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | ||||||
|  | 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func InitBatchUpdater() { | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) | ||||||
|  | 			batchUpdate() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func addNewRecord(type_ int, id int, value int) { | ||||||
|  | 	batchUpdateLocks[type_].Lock() | ||||||
|  | 	defer batchUpdateLocks[type_].Unlock() | ||||||
|  | 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||||
|  | 		batchUpdateStores[type_][id] = value | ||||||
|  | 	} else { | ||||||
|  | 		batchUpdateStores[type_][id] += value | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func batchUpdate() { | ||||||
|  | 	common.SysLog("batch update started") | ||||||
|  | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
|  | 		batchUpdateLocks[i].Lock() | ||||||
|  | 		store := batchUpdateStores[i] | ||||||
|  | 		batchUpdateStores[i] = make(map[int]int) | ||||||
|  | 		batchUpdateLocks[i].Unlock() | ||||||
|  |  | ||||||
|  | 		for key, value := range store { | ||||||
|  | 			switch i { | ||||||
|  | 			case BatchUpdateTypeUserQuota: | ||||||
|  | 				err := increaseUserQuota(key, value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to batch update user quota: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 			case BatchUpdateTypeTokenQuota: | ||||||
|  | 				err := increaseTokenQuota(key, value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 			case BatchUpdateTypeUsedQuotaAndRequestCount: | ||||||
|  | 				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect | ||||||
|  | 			case BatchUpdateTypeChannelUsedQuota: | ||||||
|  | 				updateChannelUsedQuota(key, value) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	common.SysLog("batch update finished") | ||||||
|  | } | ||||||
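The batch updater trades write amplification for a bounded delay: per-id deltas accumulate in a locked map, and each flush swaps the map out under the lock so the database sees one `UPDATE` per id per interval. A toy sketch of the coalescing; note that Go's zero value would also let `addNewRecord` drop its `if !ok` branch:

```go
package main

import (
	"fmt"
	"sync"
)

var (
	mu    sync.Mutex
	store = map[int]int{} // id -> accumulated delta; one store per update type in the real code
)

func addNewRecord(id, value int) {
	mu.Lock()
	defer mu.Unlock()
	store[id] += value // missing keys start at the zero value
}

// flush swaps the map out so writers are blocked only for the swap,
// not for the database round-trips.
func flush() map[int]int {
	mu.Lock()
	defer mu.Unlock()
	out := store
	store = map[int]int{}
	return out
}

func main() {
	addNewRecord(7, 120)
	addNewRecord(7, 80)
	addNewRecord(9, 50)
	fmt.Println(flush()) // map[7:200 9:50]: two UPDATEs instead of three
}
```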
| @@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | ||||||
|  | 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) | ||||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | ||||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | ||||||
| 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetRelayRouter(router *gin.Engine) { | func SetRelayRouter(router *gin.Engine) { | ||||||
|  | 	router.Use(middleware.CORS()) | ||||||
| 	// https://platform.openai.com/docs/api-reference/introduction | 	// https://platform.openai.com/docs/api-reference/introduction | ||||||
| 	modelsRouter := router.Group("/v1/models") | 	modelsRouter := router.Group("/v1/models") | ||||||
| 	modelsRouter.Use(middleware.TokenAuth()) | 	modelsRouter.Use(middleware.TokenAuth()) | ||||||
| @@ -26,8 +27,8 @@ func SetRelayRouter(router *gin.Engine) { | |||||||
| 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented) | 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.POST("/embeddings", controller.Relay) | 		relayV1Router.POST("/embeddings", controller.Relay) | ||||||
| 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | ||||||
| 		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) | 		relayV1Router.POST("/audio/transcriptions", controller.Relay) | ||||||
| 		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) | 		relayV1Router.POST("/audio/translations", controller.Relay) | ||||||
| 		relayV1Router.GET("/files", controller.RelayNotImplemented) | 		relayV1Router.GET("/files", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.POST("/files", controller.RelayNotImplemented) | 		relayV1Router.POST("/files", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) | 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; | ||||||
|  |  | ||||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||||
| import { renderGroup, renderNumber } from '../helpers/render'; | import { renderGroup, renderNumber } from '../helpers/render'; | ||||||
| @@ -195,6 +195,7 @@ const ChannelsTable = () => { | |||||||
|       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); |       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|  |       showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   | |||||||
@@ -13,8 +13,8 @@ const GitHubOAuth = () => {

   let navigate = useNavigate();

-  const sendCode = async (code, count) => {
-    const res = await API.get(`/api/oauth/github?code=${code}`);
+  const sendCode = async (code, state, count) => {
+    const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
     const { success, message, data } = res.data;
     if (success) {
       if (message === 'bind') {
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
       count++;
       setPrompt(`出现错误,第 ${count} 次重试中...`);
       await new Promise((resolve) => setTimeout(resolve, count * 2000));
-      await sendCode(code, count);
+      await sendCode(code, state, count);
     }
   };

   useEffect(() => {
     let code = searchParams.get('code');
-    sendCode(code, 0).then();
+    let state = searchParams.get('state');
+    sendCode(code, state, 0).then();
   }, []);

   return (

@@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
 import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 import { UserContext } from '../context/User';
 import { API, getLogo, showError, showSuccess } from '../helpers';
+import { getOAuthState, onGitHubOAuthClicked } from './utils';

 const LoginForm = () => {
   const [inputs, setInputs] = useState({
@@ -31,12 +32,6 @@ const LoginForm = () => {

   const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);

-  const onGitHubOAuthClicked = () => {
-    window.open(
-      `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
-    );
-  };
-
   const onWeChatLoginClicked = () => {
     setShowWeChatLoginModal(true);
   };
@@ -131,7 +126,7 @@ const LoginForm = () => {
                 circular
                 color='black'
                 icon='github'
-                onClick={onGitHubOAuthClicked}
+                onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
               />
             ) : (
               <></>

@@ -324,7 +324,7 @@ const LogsTable = () => {
               .map((log, idx) => {
                 if (log.deleted) return <></>;
                 return (
-                  <Table.Row key={log.created_at}>
+                  <Table.Row key={log.id}>
                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
                     {
                       isAdminUser && (

@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
 import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
 import Turnstile from 'react-turnstile';
 import { UserContext } from '../context/User';
+import { onGitHubOAuthClicked } from './utils';

 const PersonalSetting = () => {
   const [userState, userDispatch] = useContext(UserContext);
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
     }
   };

-  const openGitHubOAuth = () => {
-    window.open(
-      `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
-    );
-  };
-
   const sendVerificationCode = async () => {
     setDisableButton(true);
     if (inputs.email === '') return;
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
       </Modal>
       {
         status.github_oauth && (
-          <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
+          <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
         )
       }
       <Button

web/src/components/utils.js (new file, +20)
@@ -0,0 +1,20 @@
+import { API, showError } from '../helpers';
+
+export async function getOAuthState() {
+  const res = await API.get('/api/oauth/state');
+  const { success, message, data } = res.data;
+  if (success) {
+    return data;
+  } else {
+    showError(message);
+    return '';
+  }
+}
+
+export async function onGitHubOAuthClicked(github_client_id) {
+  const state = await getOAuthState();
+  if (!state) return;
+  window.open(
+    `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
+  );
+}

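This new helper is the client half of an OAuth `state` (CSRF) handshake: it fetches a one-time random state from `/api/oauth/state`, passes it to GitHub, and GitHub echoes it back to the callback, where `GitHubOAuth` forwards it to `/api/oauth/github` for verification. The server half is not part of this compare view; the sketch below is an assumption of how it could be implemented with gin and cookie sessions (the handler names and the `oauth_state` session key are made up for illustration):

```go
package controller

import (
	"crypto/rand"
	"encoding/hex"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)

// GenerateOAuthState is a hypothetical sketch of the endpoint behind
// GET /api/oauth/state: it stores a random nonce in the session and returns
// it, so the callback can later prove the login flow started on this site.
func GenerateOAuthState(c *gin.Context) {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		c.JSON(500, gin.H{"success": false, "message": err.Error()})
		return
	}
	state := hex.EncodeToString(buf)
	session := sessions.Default(c)
	session.Set("oauth_state", state)
	if err := session.Save(); err != nil {
		c.JSON(500, gin.H{"success": false, "message": err.Error()})
		return
	}
	c.JSON(200, gin.H{"success": true, "message": "", "data": state})
}

// In the GitHub callback, the state echoed back by the client must match the
// session value before the code is exchanged for an access token.
func stateMatches(c *gin.Context) bool {
	session := sessions.Default(c)
	saved, _ := session.Get("oauth_state").(string)
	return saved != "" && saved == c.Query("state")
}
```

Tying the nonce to the session means a forged callback URL, which cannot know the victim's session state, fails the comparison before any token exchange happens.
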
@@ -7,7 +7,11 @@ export const CHANNEL_OPTIONS = [
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
+  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
+  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
+  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
+  { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
   { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },

@@ -1,6 +1,6 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
-import { useParams, useNavigate } from 'react-router-dom';
+import { useNavigate, useParams } from 'react-router-dom';
 import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 import { CHANNEL_OPTIONS } from '../../constants';

@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
   'gpt-4-32k-0314': 'gpt-4-32k'
 };

+function type2secretPrompt(type) {
+  // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
+  switch (type) {
+    case 15:
+      return '按照如下格式输入:APIKey|SecretKey';
+    case 18:
+      return '按照如下格式输入:APPID|APISecret|APIKey';
+    case 22:
+      return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
+    default:
+      return '请输入渠道对应的鉴权密钥';
+  }
+}
+
 const EditChannel = () => {
   const params = useParams();
   const navigate = useNavigate();
@@ -19,7 +33,7 @@ const EditChannel = () => {
   const handleCancel = () => {
     navigate('/channel');
   };
-  
+
   const originInputs = {
     name: '',
     type: 1,
@@ -53,7 +67,7 @@ const EditChannel = () => {
           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
           break;
         case 17:
-          localModels = ['qwen-v1', 'qwen-plus-v1'];
+          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
           break;
         case 16:
           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
@@ -61,6 +75,9 @@ const EditChannel = () => {
         case 18:
           localModels = ['SparkDesk'];
           break;
+        case 19:
+          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
+          break;
       }
       setInputs((inputs) => ({ ...inputs, models: localModels }));
     }
@@ -190,6 +207,24 @@ const EditChannel = () => {
     }
   };

+  const addCustomModel = () => {
+    if (customModel.trim() === '') return;
+    if (inputs.models.includes(customModel)) return;
+    let localModels = [...inputs.models];
+    localModels.push(customModel);
+    let localModelOptions = [];
+    localModelOptions.push({
+      key: customModel,
+      text: customModel,
+      value: customModel
+    });
+    setModelOptions(modelOptions => {
+      return [...modelOptions, ...localModelOptions];
+    });
+    setCustomModel('');
+    handleInputChange(null, { name: 'models', value: localModels });
+  };
+
   return (
     <>
       <Segment loading={loading}>
@@ -292,6 +327,20 @@ const EditChannel = () => {
               </Form.Field>
             )
           }
+          {
+            inputs.type === 21 && (
+              <Form.Field>
+                <Form.Input
+                  label='知识库 ID'
+                  name='other'
+                  placeholder={'请输入知识库 ID,例如:123456'}
+                  onChange={handleInputChange}
+                  value={inputs.other}
+                  autoComplete='new-password'
+                />
+              </Form.Field>
+            )
+          }
           <Form.Field>
             <Form.Dropdown
               label='模型'
@@ -319,29 +368,19 @@ const EditChannel = () => {
             }}>清除所有模型</Button>
             <Input
               action={
-                <Button type={'button'} onClick={() => {
-                  if (customModel.trim() === '') return;
-                  if (inputs.models.includes(customModel)) return;
-                  let localModels = [...inputs.models];
-                  localModels.push(customModel);
-                  let localModelOptions = [];
-                  localModelOptions.push({
-                    key: customModel,
-                    text: customModel,
-                    value: customModel
-                  });
-                  setModelOptions(modelOptions => {
-                    return [...modelOptions, ...localModelOptions];
-                  });
-                  setCustomModel('');
-                  handleInputChange(null, { name: 'models', value: localModels });
-                }}>填入</Button>
+                <Button type={'button'} onClick={addCustomModel}>填入</Button>
               }
               placeholder='输入自定义模型名称'
               value={customModel}
               onChange={(e, { value }) => {
                 setCustomModel(value);
               }}
+              onKeyDown={(e) => {
+                if (e.key === 'Enter') {
+                  addCustomModel();
+                  e.preventDefault();
+                }
+              }}
             />
           </div>
           <Form.Field>
@@ -372,7 +411,7 @@ const EditChannel = () => {
                 label='密钥'
                 name='key'
                 required
-                placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
+                placeholder={type2secretPrompt(inputs.type)}
                 onChange={handleInputChange}
                 value={inputs.key}
                 autoComplete='new-password'
@@ -390,7 +429,7 @@ const EditChannel = () => {
             )
           }
           {
-            inputs.type !== 3 && inputs.type !== 8 && (
+            inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
               <Form.Field>
                 <Form.Input
                   label='代理'
@@ -403,6 +442,20 @@ const EditChannel = () => {
               </Form.Field>
             )
           }
+          {
+            inputs.type === 22 && (
+              <Form.Field>
+                <Form.Input
+                  label='私有部署地址'
+                  name='base_url'
+                  placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
+                  onChange={handleInputChange}
+                  value={inputs.base_url}
+                  autoComplete='new-password'
+                />
+              </Form.Field>
+            )
+          }
           <Button onClick={handleCancel}>取消</Button>
           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
         </Form>