mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-23 01:43:42 +08:00 
			
		
		
		
	Compare commits
	
		
			50 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 42569c83c0 | ||
|  | b373882814 | ||
|  | e2ed0399f0 | ||
|  | eed9f5fdf0 | ||
|  | f2c51a494c | ||
|  | 8a4d6f3327 | ||
|  | cf4e33cb12 | ||
|  | 5d60305570 | ||
|  | d062bc60e4 | ||
|  | 39c1882970 | ||
|  | 9c42c7dfd9 | ||
|  | 903aaeded0 | ||
|  | bdd4be562d | ||
|  | 37afb313b5 | ||
|  | c9ebcab8b8 | ||
|  | 86261cc656 | ||
|  | 8491785c9d | ||
|  | e848a3f7fa | ||
|  | 318adf5985 | ||
|  | 965d7fc3d2 | ||
|  | aa3f605894 | ||
|  | 7b8eff1f22 | ||
|  | e80cd508ba | ||
|  | d37f836d53 | ||
|  | e0b2d1ae47 | ||
|  | 797ead686b | ||
|  | 0d22cf9ead | ||
|  | 48989d4a0b | ||
|  | 6227eee5bc | ||
|  | cbf8f07747 | ||
|  | 4a96031ce6 | ||
|  | 92886093ae | ||
|  | 0c022f17cb | ||
|  | 83f95935de | ||
|  | aa03c89133 | ||
|  | 505817ca17 | ||
|  | cb5a3df616 | ||
|  | 7772064d87 | ||
|  | c50c609565 | ||
|  | 498dea2dbb | ||
|  | c725cc8842 | ||
|  | af8908db54 | ||
|  | d8029550f7 | ||
|  | f44fbe3fe7 | ||
|  | 1c8922153d | ||
|  | f3c07e1451 | ||
|  | 40ceb29e54 | ||
|  | 0699ecd0af | ||
|  | ee9e746520 | ||
|  | a763681c2e | 
							
								
								
									
										11
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -7,6 +7,11 @@ on: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   release: | ||||
|     runs-on: ubuntu-latest | ||||
| @@ -18,13 +23,13 @@ jobs: | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|       - name: Build Frontend | ||||
|       - name: Build Frontend (theme default) | ||||
|         env: | ||||
|           CI: "" | ||||
|         run: | | ||||
|           cd web | ||||
|           npm install | ||||
|           REACT_APP_VERSION=$(git describe --tags) npm run build | ||||
|           git describe --tags > VERSION | ||||
|           REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh | ||||
|           cd .. | ||||
|       - name: Set up Go | ||||
|         uses: actions/setup-go@v3 | ||||
|   | ||||
							
								
								
									
										11
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -7,6 +7,11 @@ on: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   release: | ||||
|     runs-on: macos-latest | ||||
| @@ -18,13 +23,13 @@ jobs: | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|       - name: Build Frontend | ||||
|       - name: Build Frontend (theme default) | ||||
|         env: | ||||
|           CI: "" | ||||
|         run: | | ||||
|           cd web | ||||
|           npm install | ||||
|           REACT_APP_VERSION=$(git describe --tags) npm run build | ||||
|           git describe --tags > VERSION | ||||
|           REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh | ||||
|           cd .. | ||||
|       - name: Set up Go | ||||
|         uses: actions/setup-go@v3 | ||||
|   | ||||
							
								
								
									
										11
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -7,6 +7,11 @@ on: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   release: | ||||
|     runs-on: windows-latest | ||||
| @@ -21,14 +26,14 @@ jobs: | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|       - name: Build Frontend | ||||
|       - name: Build Frontend (theme default) | ||||
|         env: | ||||
|           CI: "" | ||||
|         run: | | ||||
|           cd web | ||||
|           cd web/default | ||||
|           npm install | ||||
|           REACT_APP_VERSION=$(git describe --tags) npm run build | ||||
|           cd .. | ||||
|           cd ../.. | ||||
|       - name: Set up Go | ||||
|         uses: actions/setup-go@v3 | ||||
|         with: | ||||
|   | ||||
							
								
								
									
										17
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								Dockerfile
									
									
									
									
									
								
							| @@ -1,10 +1,15 @@ | ||||
| FROM node:16 as builder | ||||
|  | ||||
| WORKDIR /build | ||||
| COPY web/package.json . | ||||
| RUN npm install | ||||
| COPY ./web . | ||||
| WORKDIR /web | ||||
| COPY ./VERSION . | ||||
| COPY ./web . | ||||
|  | ||||
| WORKDIR /web/default | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| WORKDIR /web/berry | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| FROM golang AS builder2 | ||||
| @@ -17,7 +22,7 @@ WORKDIR /build | ||||
| ADD go.mod go.sum ./ | ||||
| RUN go mod download | ||||
| COPY . . | ||||
| COPY --from=builder /build/build ./web/build | ||||
| COPY --from=builder /web/build ./web/build | ||||
| RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||
|  | ||||
| FROM alpine | ||||
| @@ -30,4 +35,4 @@ RUN apk update \ | ||||
| COPY --from=builder2 /build/one-api / | ||||
| EXPOSE 3000 | ||||
| WORKDIR /data | ||||
| ENTRYPOINT ["/one-api"] | ||||
| ENTRYPOINT ["/one-api"] | ||||
| @@ -3,7 +3,7 @@ | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
|   | ||||
| @@ -4,7 +4,7 @@ | ||||
|  | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
| @@ -99,6 +99,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
|  | ||||
| ## 部署 | ||||
| ### 基于 Docker 进行部署 | ||||
| @@ -366,6 +367,8 @@ graph LR | ||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
| @@ -411,6 +414,9 @@ https://openai.justsong.cn | ||||
| 8. 升级之前数据库需要做变更吗? | ||||
|    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||
|    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||
| 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? | ||||
|    + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 | ||||
|    + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 | ||||
|  | ||||
| ## 相关项目 | ||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
|   | ||||
							
								
								
									
										127
									
								
								common/config/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								common/config/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,127 @@ | ||||
| package config | ||||
|  | ||||
| import ( | ||||
| 	"one-api/common/helper" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| ) | ||||
|  | ||||
| var SystemName = "One API" | ||||
| var ServerAddress = "http://localhost:3000" | ||||
| var Footer = "" | ||||
| var Logo = "" | ||||
| var TopUpLink = "" | ||||
| var ChatLink = "" | ||||
| var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||
| var DisplayInCurrencyEnabled = true | ||||
| var DisplayTokenStatEnabled = true | ||||
|  | ||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||
|  | ||||
| var SessionSecret = uuid.New().String() | ||||
|  | ||||
| var OptionMap map[string]string | ||||
| var OptionMapRWMutex sync.RWMutex | ||||
|  | ||||
| var ItemsPerPage = 10 | ||||
| var MaxRecentItems = 100 | ||||
|  | ||||
| var PasswordLoginEnabled = true | ||||
| var PasswordRegisterEnabled = true | ||||
| var EmailVerificationEnabled = false | ||||
| var GitHubOAuthEnabled = false | ||||
| var WeChatAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
|  | ||||
| var EmailDomainRestrictionEnabled = false | ||||
| var EmailDomainWhitelist = []string{ | ||||
| 	"gmail.com", | ||||
| 	"163.com", | ||||
| 	"126.com", | ||||
| 	"qq.com", | ||||
| 	"outlook.com", | ||||
| 	"hotmail.com", | ||||
| 	"icloud.com", | ||||
| 	"yahoo.com", | ||||
| 	"foxmail.com", | ||||
| } | ||||
|  | ||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||
|  | ||||
| var LogConsumeEnabled = true | ||||
|  | ||||
| var SMTPServer = "" | ||||
| var SMTPPort = 587 | ||||
| var SMTPAccount = "" | ||||
| var SMTPFrom = "" | ||||
| var SMTPToken = "" | ||||
|  | ||||
| var GitHubClientId = "" | ||||
| var GitHubClientSecret = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
|  | ||||
| var TurnstileSiteKey = "" | ||||
| var TurnstileSecretKey = "" | ||||
|  | ||||
| var QuotaForNewUser = 0 | ||||
| var QuotaForInviter = 0 | ||||
| var QuotaForInvitee = 0 | ||||
| var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var AutomaticEnableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| var RetryTimes = 0 | ||||
|  | ||||
| var RootUserEmail = "" | ||||
|  | ||||
| var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
|  | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second | ||||
|  | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||
|  | ||||
| var Theme = helper.GetOrDefaultEnvString("THEME", "default") | ||||
| var ValidThemes = map[string]bool{ | ||||
| 	"default": true, | ||||
| 	"berry":   true, | ||||
| } | ||||
|  | ||||
| // All duration's unit is seconds | ||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | ||||
| var ( | ||||
| 	GlobalApiRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	GlobalWebRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	UploadRateLimitNum            = 10 | ||||
| 	UploadRateLimitDuration int64 = 60 | ||||
|  | ||||
| 	DownloadRateLimitNum            = 10 | ||||
| 	DownloadRateLimitDuration int64 = 60 | ||||
|  | ||||
| 	CriticalRateLimitNum            = 20 | ||||
| 	CriticalRateLimitDuration int64 = 20 * 60 | ||||
| ) | ||||
|  | ||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||
| @@ -1,106 +1,9 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| ) | ||||
| import "time" | ||||
|  | ||||
| var StartTime = time.Now().Unix() // unit: second | ||||
| var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change | ||||
| var SystemName = "One API" | ||||
| var ServerAddress = "http://localhost:3000" | ||||
| var Footer = "" | ||||
| var Logo = "" | ||||
| var TopUpLink = "" | ||||
| var ChatLink = "" | ||||
| var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||
| var DisplayInCurrencyEnabled = true | ||||
| var DisplayTokenStatEnabled = true | ||||
|  | ||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||
|  | ||||
| var SessionSecret = uuid.New().String() | ||||
|  | ||||
| var OptionMap map[string]string | ||||
| var OptionMapRWMutex sync.RWMutex | ||||
|  | ||||
| var ItemsPerPage = 10 | ||||
| var MaxRecentItems = 100 | ||||
|  | ||||
| var PasswordLoginEnabled = true | ||||
| var PasswordRegisterEnabled = true | ||||
| var EmailVerificationEnabled = false | ||||
| var GitHubOAuthEnabled = false | ||||
| var WeChatAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
|  | ||||
| var EmailDomainRestrictionEnabled = false | ||||
| var EmailDomainWhitelist = []string{ | ||||
| 	"gmail.com", | ||||
| 	"163.com", | ||||
| 	"126.com", | ||||
| 	"qq.com", | ||||
| 	"outlook.com", | ||||
| 	"hotmail.com", | ||||
| 	"icloud.com", | ||||
| 	"yahoo.com", | ||||
| 	"foxmail.com", | ||||
| } | ||||
|  | ||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||
|  | ||||
| var LogConsumeEnabled = true | ||||
|  | ||||
| var SMTPServer = "" | ||||
| var SMTPPort = 587 | ||||
| var SMTPAccount = "" | ||||
| var SMTPFrom = "" | ||||
| var SMTPToken = "" | ||||
|  | ||||
| var GitHubClientId = "" | ||||
| var GitHubClientSecret = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
|  | ||||
| var TurnstileSiteKey = "" | ||||
| var TurnstileSecretKey = "" | ||||
|  | ||||
| var QuotaForNewUser = 0 | ||||
| var QuotaForInviter = 0 | ||||
| var QuotaForInvitee = 0 | ||||
| var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var AutomaticEnableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| var RetryTimes = 0 | ||||
|  | ||||
| var RootUserEmail = "" | ||||
|  | ||||
| var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
|  | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||
|  | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	RoleGuestUser  = 0 | ||||
| @@ -109,34 +12,6 @@ const ( | ||||
| 	RoleRootUser   = 100 | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	FileUploadPermission    = RoleGuestUser | ||||
| 	FileDownloadPermission  = RoleGuestUser | ||||
| 	ImageUploadPermission   = RoleGuestUser | ||||
| 	ImageDownloadPermission = RoleGuestUser | ||||
| ) | ||||
|  | ||||
| // All duration's unit is seconds | ||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | ||||
| var ( | ||||
| 	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	UploadRateLimitNum            = 10 | ||||
| 	UploadRateLimitDuration int64 = 60 | ||||
|  | ||||
| 	DownloadRateLimitNum            = 10 | ||||
| 	DownloadRateLimitDuration int64 = 60 | ||||
|  | ||||
| 	CriticalRateLimitNum            = 20 | ||||
| 	CriticalRateLimitDuration int64 = 20 * 60 | ||||
| ) | ||||
|  | ||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||
|  | ||||
| const ( | ||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	UserStatusDisabled = 2 // also don't use 0 | ||||
| @@ -191,29 +66,29 @@ const ( | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                                  // 0 | ||||
| 	"https://api.openai.com",            // 1 | ||||
| 	"https://oa.api2d.net",              // 2 | ||||
| 	"",                                  // 3 | ||||
| 	"https://api.closeai-proxy.xyz",     // 4 | ||||
| 	"https://api.openai-sb.com",         // 5 | ||||
| 	"https://api.openaimax.com",         // 6 | ||||
| 	"https://api.ohmygpt.com",           // 7 | ||||
| 	"",                                  // 8 | ||||
| 	"https://api.caipacity.com",         // 9 | ||||
| 	"https://api.aiproxy.io",            // 10 | ||||
| 	"",                                  // 11 | ||||
| 	"https://api.api2gpt.com",           // 12 | ||||
| 	"https://api.aigc2d.com",            // 13 | ||||
| 	"https://api.anthropic.com",         // 14 | ||||
| 	"https://aip.baidubce.com",          // 15 | ||||
| 	"https://open.bigmodel.cn",          // 16 | ||||
| 	"https://dashscope.aliyuncs.com",    // 17 | ||||
| 	"",                                  // 18 | ||||
| 	"https://ai.360.cn",                 // 19 | ||||
| 	"https://openrouter.ai/api",         // 20 | ||||
| 	"https://api.aiproxy.io",            // 21 | ||||
| 	"https://fastgpt.run/api/openapi",   // 22 | ||||
| 	"https://hunyuan.cloud.tencent.com", //23 | ||||
| 	"",                                  //24 | ||||
| 	"",                              // 0 | ||||
| 	"https://api.openai.com",        // 1 | ||||
| 	"https://oa.api2d.net",          // 2 | ||||
| 	"",                              // 3 | ||||
| 	"https://api.closeai-proxy.xyz", // 4 | ||||
| 	"https://api.openai-sb.com",     // 5 | ||||
| 	"https://api.openaimax.com",     // 6 | ||||
| 	"https://api.ohmygpt.com",       // 7 | ||||
| 	"",                              // 8 | ||||
| 	"https://api.caipacity.com",     // 9 | ||||
| 	"https://api.aiproxy.io",        // 10 | ||||
| 	"https://generativelanguage.googleapis.com", // 11 | ||||
| 	"https://api.api2gpt.com",                   // 12 | ||||
| 	"https://api.aigc2d.com",                    // 13 | ||||
| 	"https://api.anthropic.com",                 // 14 | ||||
| 	"https://aip.baidubce.com",                  // 15 | ||||
| 	"https://open.bigmodel.cn",                  // 16 | ||||
| 	"https://dashscope.aliyuncs.com",            // 17 | ||||
| 	"",                                          // 18 | ||||
| 	"https://ai.360.cn",                         // 19 | ||||
| 	"https://openrouter.ai/api",                 // 20 | ||||
| 	"https://api.aiproxy.io",                    // 21 | ||||
| 	"https://fastgpt.run/api/openapi",           // 22 | ||||
| 	"https://hunyuan.cloud.tencent.com",         // 23 | ||||
| 	"https://generativelanguage.googleapis.com", // 24 | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| package common | ||||
|  | ||||
| import "one-api/common/helper" | ||||
|  | ||||
| var UsingSQLite = false | ||||
| var UsingPostgreSQL = false | ||||
|  | ||||
| var SQLitePath = "one-api.db" | ||||
| var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) | ||||
| var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) | ||||
|   | ||||
| @@ -6,18 +6,19 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"net/smtp" | ||||
| 	"one-api/common/config" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| 	if SMTPFrom == "" { // for compatibility | ||||
| 		SMTPFrom = SMTPAccount | ||||
| 	if config.SMTPFrom == "" { // for compatibility | ||||
| 		config.SMTPFrom = config.SMTPAccount | ||||
| 	} | ||||
| 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | ||||
|  | ||||
| 	// Extract domain from SMTPFrom | ||||
| 	parts := strings.Split(SMTPFrom, "@") | ||||
| 	parts := strings.Split(config.SMTPFrom, "@") | ||||
| 	var domain string | ||||
| 	if len(parts) > 1 { | ||||
| 		domain = parts[1] | ||||
| @@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | ||||
| 		"Date: %s\r\n"+ | ||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
| 	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) | ||||
| 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
| 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) | ||||
| 	to := strings.Split(receiver, ";") | ||||
|  | ||||
| 	if SMTPPort == 465 { | ||||
| 	if config.SMTPPort == 465 { | ||||
| 		tlsConfig := &tls.Config{ | ||||
| 			InsecureSkipVerify: true, | ||||
| 			ServerName:         SMTPServer, | ||||
| 			ServerName:         config.SMTPServer, | ||||
| 		} | ||||
| 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) | ||||
| 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		client, err := smtp.NewClient(conn, SMTPServer) | ||||
| 		client, err := smtp.NewClient(conn, config.SMTPServer) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		if err = client.Auth(auth); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if err = client.Mail(SMTPFrom); err != nil { | ||||
| 		if err = client.Mail(config.SMTPFrom); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		receiverEmails := strings.Split(receiver, ";") | ||||
| @@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) | ||||
| 		err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|   | ||||
| @@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func SetEventStreamHeaders(c *gin.Context) { | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| package common | ||||
|  | ||||
| import "encoding/json" | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| var GroupRatio = map[string]float64{ | ||||
| 	"default": 1, | ||||
| @@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{ | ||||
| func GroupRatio2JSONString() string { | ||||
| 	jsonBytes, err := json.Marshal(GroupRatio) | ||||
| 	if err != nil { | ||||
| 		SysError("error marshalling model ratio: " + err.Error()) | ||||
| 		logger.SysError("error marshalling model ratio: " + err.Error()) | ||||
| 	} | ||||
| 	return string(jsonBytes) | ||||
| } | ||||
| @@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error { | ||||
| func GetGroupRatio(name string) float64 { | ||||
| 	ratio, ok := GroupRatio[name] | ||||
| 	if !ok { | ||||
| 		SysError("group ratio not found: " + name) | ||||
| 		logger.SysError("group ratio not found: " + name) | ||||
| 		return 1 | ||||
| 	} | ||||
| 	return ratio | ||||
|   | ||||
							
								
								
									
										224
									
								
								common/helper/helper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										224
									
								
								common/helper/helper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,224 @@ | ||||
| package helper | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/google/uuid" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"one-api/common/logger" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func OpenBrowser(url string) { | ||||
| 	var err error | ||||
|  | ||||
| 	switch runtime.GOOS { | ||||
| 	case "linux": | ||||
| 		err = exec.Command("xdg-open", url).Start() | ||||
| 	case "windows": | ||||
| 		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() | ||||
| 	case "darwin": | ||||
| 		err = exec.Command("open", url).Start() | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetIp() (ip string) { | ||||
| 	ips, err := net.InterfaceAddrs() | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return ip | ||||
| 	} | ||||
|  | ||||
| 	for _, a := range ips { | ||||
| 		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { | ||||
| 			if ipNet.IP.To4() != nil { | ||||
| 				ip = ipNet.IP.String() | ||||
| 				if strings.HasPrefix(ip, "10") { | ||||
| 					return | ||||
| 				} | ||||
| 				if strings.HasPrefix(ip, "172") { | ||||
| 					return | ||||
| 				} | ||||
| 				if strings.HasPrefix(ip, "192.168") { | ||||
| 					return | ||||
| 				} | ||||
| 				ip = "" | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| var sizeKB = 1024 | ||||
| var sizeMB = sizeKB * 1024 | ||||
| var sizeGB = sizeMB * 1024 | ||||
|  | ||||
| func Bytes2Size(num int64) string { | ||||
| 	numStr := "" | ||||
| 	unit := "B" | ||||
| 	if num/int64(sizeGB) > 1 { | ||||
| 		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) | ||||
| 		unit = "GB" | ||||
| 	} else if num/int64(sizeMB) > 1 { | ||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) | ||||
| 		unit = "MB" | ||||
| 	} else if num/int64(sizeKB) > 1 { | ||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) | ||||
| 		unit = "KB" | ||||
| 	} else { | ||||
| 		numStr = fmt.Sprintf("%d", num) | ||||
| 	} | ||||
| 	return numStr + " " + unit | ||||
| } | ||||
|  | ||||
| func Seconds2Time(num int) (time string) { | ||||
| 	if num/31104000 > 0 { | ||||
| 		time += strconv.Itoa(num/31104000) + " 年 " | ||||
| 		num %= 31104000 | ||||
| 	} | ||||
| 	if num/2592000 > 0 { | ||||
| 		time += strconv.Itoa(num/2592000) + " 个月 " | ||||
| 		num %= 2592000 | ||||
| 	} | ||||
| 	if num/86400 > 0 { | ||||
| 		time += strconv.Itoa(num/86400) + " 天 " | ||||
| 		num %= 86400 | ||||
| 	} | ||||
| 	if num/3600 > 0 { | ||||
| 		time += strconv.Itoa(num/3600) + " 小时 " | ||||
| 		num %= 3600 | ||||
| 	} | ||||
| 	if num/60 > 0 { | ||||
| 		time += strconv.Itoa(num/60) + " 分钟 " | ||||
| 		num %= 60 | ||||
| 	} | ||||
| 	time += strconv.Itoa(num) + " 秒" | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func Interface2String(inter interface{}) string { | ||||
| 	switch inter.(type) { | ||||
| 	case string: | ||||
| 		return inter.(string) | ||||
| 	case int: | ||||
| 		return fmt.Sprintf("%d", inter.(int)) | ||||
| 	case float64: | ||||
| 		return fmt.Sprintf("%f", inter.(float64)) | ||||
| 	} | ||||
| 	return "Not Implemented" | ||||
| } | ||||
|  | ||||
| func UnescapeHTML(x string) interface{} { | ||||
| 	return template.HTML(x) | ||||
| } | ||||
|  | ||||
| func IntMax(a int, b int) int { | ||||
| 	if a >= b { | ||||
| 		return a | ||||
| 	} else { | ||||
| 		return b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUUID() string { | ||||
| 	code := uuid.New().String() | ||||
| 	code = strings.Replace(code, "-", "", -1) | ||||
| 	return code | ||||
| } | ||||
|  | ||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||
|  | ||||
| func init() { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| } | ||||
|  | ||||
| func GenerateKey() string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, 48) | ||||
| 	for i := 0; i < 16; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	uuid_ := GetUUID() | ||||
| 	for i := 0; i < 32; i++ { | ||||
| 		c := uuid_[i] | ||||
| 		if i%2 == 0 && c >= 'a' && c <= 'z' { | ||||
| 			c = c - 'a' + 'A' | ||||
| 		} | ||||
| 		key[i+16] = c | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetRandomString(length int) string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, length) | ||||
| 	for i := 0; i < length; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetTimestamp() int64 { | ||||
| 	return time.Now().Unix() | ||||
| } | ||||
|  | ||||
| func GetTimeString() string { | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| } | ||||
|  | ||||
| func Max(a int, b int) int { | ||||
| 	if a >= b { | ||||
| 		return a | ||||
| 	} else { | ||||
| 		return b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetOrDefaultEnvInt(env string, defaultValue int) int { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.Atoi(os.Getenv(env)) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func GetOrDefaultEnvString(env string, defaultValue string) string { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) | ||||
| } | ||||
|  | ||||
| func AssignOrDefault(value string, defaultValue string) string { | ||||
| 	if len(value) != 0 { | ||||
| 		return value | ||||
| 	} | ||||
| 	return defaultValue | ||||
| } | ||||
|  | ||||
| func MessageWithRequestId(message string, id string) string { | ||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||
| } | ||||
|  | ||||
| func String2Int(str string) int { | ||||
| 	num, err := strconv.Atoi(str) | ||||
| 	if err != nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
| @@ -1,6 +1,8 @@ | ||||
| package image | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/base64" | ||||
| 	"image" | ||||
| 	_ "image/gif" | ||||
| 	_ "image/jpeg" | ||||
| @@ -8,11 +10,30 @@ import ( | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	_ "golang.org/x/image/webp" | ||||
| ) | ||||
|  | ||||
| // Regex to match data URL pattern | ||||
| var	dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) | ||||
|  | ||||
| func IsImageUrl(url string) (bool, error) { | ||||
| 	resp, err := http.Head(url) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 	return true, nil | ||||
| } | ||||
|  | ||||
| func GetImageSizeFromUrl(url string) (width int, height int, err error) { | ||||
| 	isImage, err := IsImageUrl(url) | ||||
| 	if !isImage { | ||||
| 		return | ||||
| 	} | ||||
| 	resp, err := http.Get(url) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| @@ -25,17 +46,60 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { | ||||
| 	return img.Width, img.Height, nil | ||||
| } | ||||
|  | ||||
| func GetImageFromUrl(url string) (mimeType string, data string, err error) { | ||||
| 	// Check if the URL is a data URL | ||||
| 	matches := dataURLPattern.FindStringSubmatch(url) | ||||
| 	if len(matches) == 3 { | ||||
| 		// URL is a data URL | ||||
| 		mimeType = "image/" + matches[1] | ||||
| 		data = matches[2] | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	isImage, err := IsImageUrl(url) | ||||
| 	if !isImage { | ||||
| 		return | ||||
| 	} | ||||
| 	resp, err := http.Get(url) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	buffer := bytes.NewBuffer(nil) | ||||
| 	_, err = buffer.ReadFrom(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	mimeType = resp.Header.Get("Content-Type") | ||||
| 	data = base64.StdEncoding.EncodeToString(buffer.Bytes()) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	reg = regexp.MustCompile(`data:image/([^;]+);base64,`) | ||||
| ) | ||||
|  | ||||
| var readerPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		return &bytes.Reader{} | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { | ||||
| 	encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") | ||||
| 	base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) | ||||
| 	img, _, err := image.DecodeConfig(base64) | ||||
| 	decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, "")) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	reader := readerPool.Get().(*bytes.Reader) | ||||
| 	defer readerPool.Put(reader) | ||||
| 	reader.Reset(decoded) | ||||
|  | ||||
| 	img, _, err := image.DecodeConfig(reader) | ||||
| 	if err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	return img.Width, img.Height, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestGetImageSizeFromBase64(t *testing.T) { | ||||
| 	for i, c := range cases { | ||||
| 		t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { | ||||
| 			resp, err := http.Get(c.url) | ||||
| 			assert.NoError(t, err) | ||||
| 			defer resp.Body.Close() | ||||
| 			data, err := io.ReadAll(resp.Body) | ||||
| 			assert.NoError(t, err) | ||||
| 			encoded := base64.StdEncoding.EncodeToString(data) | ||||
| 			width, height, err := img.GetImageSizeFromBase64(encoded) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, c.width, width) | ||||
| 			assert.Equal(t, c.height, height) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| ) | ||||
| @@ -37,9 +39,9 @@ func init() { | ||||
|  | ||||
| 	if os.Getenv("SESSION_SECRET") != "" { | ||||
| 		if os.Getenv("SESSION_SECRET") == "random_string" { | ||||
| 			SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | ||||
| 			logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | ||||
| 		} else { | ||||
| 			SessionSecret = os.Getenv("SESSION_SECRET") | ||||
| 			config.SessionSecret = os.Getenv("SESSION_SECRET") | ||||
| 		} | ||||
| 	} | ||||
| 	if os.Getenv("SQLITE_PATH") != "" { | ||||
| @@ -57,5 +59,6 @@ func init() { | ||||
| 				log.Fatal(err) | ||||
| 			} | ||||
| 		} | ||||
| 		logger.LogDir = *LogDir | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										7
									
								
								common/logger/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								common/logger/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package logger | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
|  | ||||
| var LogDir string | ||||
| @@ -1,4 +1,4 @@ | ||||
| package common | ||||
| package logger | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| @@ -25,7 +25,7 @@ var setupLogLock sync.Mutex | ||||
| var setupLogWorking bool | ||||
| 
 | ||||
| func SetupLogger() { | ||||
| 	if *LogDir != "" { | ||||
| 	if LogDir != "" { | ||||
| 		ok := setupLogLock.TryLock() | ||||
| 		if !ok { | ||||
| 			log.Println("setup log is already working") | ||||
| @@ -35,7 +35,7 @@ func SetupLogger() { | ||||
| 			setupLogLock.Unlock() | ||||
| 			setupLogWorking = false | ||||
| 		}() | ||||
| 		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 		logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||
| 		if err != nil { | ||||
| 			log.Fatal("failed to open log file") | ||||
| @@ -55,18 +55,30 @@ func SysError(s string) { | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||
| } | ||||
| 
 | ||||
| func LogInfo(ctx context.Context, msg string) { | ||||
| func Info(ctx context.Context, msg string) { | ||||
| 	logHelper(ctx, loggerINFO, msg) | ||||
| } | ||||
| 
 | ||||
| func LogWarn(ctx context.Context, msg string) { | ||||
| func Warn(ctx context.Context, msg string) { | ||||
| 	logHelper(ctx, loggerWarn, msg) | ||||
| } | ||||
| 
 | ||||
| func LogError(ctx context.Context, msg string) { | ||||
| func Error(ctx context.Context, msg string) { | ||||
| 	logHelper(ctx, loggerError, msg) | ||||
| } | ||||
| 
 | ||||
| func Infof(ctx context.Context, format string, a ...any) { | ||||
| 	Info(ctx, fmt.Sprintf(format, a)) | ||||
| } | ||||
| 
 | ||||
| func Warnf(ctx context.Context, format string, a ...any) { | ||||
| 	Warn(ctx, fmt.Sprintf(format, a)) | ||||
| } | ||||
| 
 | ||||
| func Errorf(ctx context.Context, format string, a ...any) { | ||||
| 	Error(ctx, fmt.Sprintf(format, a)) | ||||
| } | ||||
| 
 | ||||
| func logHelper(ctx context.Context, level string, msg string) { | ||||
| 	writer := gin.DefaultErrorWriter | ||||
| 	if level == loggerINFO { | ||||
| @@ -90,11 +102,3 @@ func FatalLog(v ...any) { | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||
| 	os.Exit(1) | ||||
| } | ||||
| 
 | ||||
| func LogQuota(quota int) string { | ||||
| 	if DisplayInCurrencyEnabled { | ||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) | ||||
| 	} else { | ||||
| 		return fmt.Sprintf("%d 点额度", quota) | ||||
| 	} | ||||
| } | ||||
| @@ -2,6 +2,7 @@ package common | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"one-api/common/logger" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| @@ -52,6 +53,8 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||
| 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens | ||||
| 	"davinci-002":               1,    // $0.002 / 1K tokens | ||||
| 	"babbage-002":               0.2,  // $0.0004 / 1K tokens | ||||
| 	"text-ada-001":              0.2, | ||||
| 	"text-babbage-001":          0.25, | ||||
| 	"text-curie-001":            1, | ||||
| @@ -84,6 +87,7 @@ var ModelRatio = map[string]float64{ | ||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                    1, | ||||
| 	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-pro-vision":         1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||
| @@ -104,7 +108,7 @@ var ModelRatio = map[string]float64{ | ||||
| func ModelRatio2JSONString() string { | ||||
| 	jsonBytes, err := json.Marshal(ModelRatio) | ||||
| 	if err != nil { | ||||
| 		SysError("error marshalling model ratio: " + err.Error()) | ||||
| 		logger.SysError("error marshalling model ratio: " + err.Error()) | ||||
| 	} | ||||
| 	return string(jsonBytes) | ||||
| } | ||||
| @@ -115,9 +119,12 @@ func UpdateModelRatioByJSONString(jsonStr string) error { | ||||
| } | ||||
|  | ||||
| func GetModelRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { | ||||
| 		name = strings.TrimSuffix(name, "-internet") | ||||
| 	} | ||||
| 	ratio, ok := ModelRatio[name] | ||||
| 	if !ok { | ||||
| 		SysError("model ratio not found: " + name) | ||||
| 		logger.SysError("model ratio not found: " + name) | ||||
| 		return 30 | ||||
| 	} | ||||
| 	return ratio | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package common | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"one-api/common/logger" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
| @@ -14,18 +15,18 @@ var RedisEnabled = true | ||||
| func InitRedisClient() (err error) { | ||||
| 	if os.Getenv("REDIS_CONN_STRING") == "" { | ||||
| 		RedisEnabled = false | ||||
| 		SysLog("REDIS_CONN_STRING not set, Redis is not enabled") | ||||
| 		logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") | ||||
| 		return nil | ||||
| 	} | ||||
| 	if os.Getenv("SYNC_FREQUENCY") == "" { | ||||
| 		RedisEnabled = false | ||||
| 		SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||
| 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||
| 		return nil | ||||
| 	} | ||||
| 	SysLog("Redis is enabled") | ||||
| 	logger.SysLog("Redis is enabled") | ||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||
| 	if err != nil { | ||||
| 		FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 	} | ||||
| 	RDB = redis.NewClient(opt) | ||||
|  | ||||
| @@ -34,7 +35,7 @@ func InitRedisClient() (err error) { | ||||
|  | ||||
| 	_, err = RDB.Ping(ctx).Result() | ||||
| 	if err != nil { | ||||
| 		FatalLog("Redis ping test failed: " + err.Error()) | ||||
| 		logger.FatalLog("Redis ping test failed: " + err.Error()) | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| @@ -42,7 +43,7 @@ func InitRedisClient() (err error) { | ||||
| func ParseRedisOption() *redis.Options { | ||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||
| 	if err != nil { | ||||
| 		FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 	} | ||||
| 	return opt | ||||
| } | ||||
|   | ||||
							
								
								
									
										205
									
								
								common/utils.go
									
									
									
									
									
								
							
							
						
						
									
										205
									
								
								common/utils.go
									
									
									
									
									
								
							| @@ -2,208 +2,13 @@ package common | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/google/uuid" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"one-api/common/config" | ||||
| ) | ||||
|  | ||||
| func OpenBrowser(url string) { | ||||
| 	var err error | ||||
|  | ||||
| 	switch runtime.GOOS { | ||||
| 	case "linux": | ||||
| 		err = exec.Command("xdg-open", url).Start() | ||||
| 	case "windows": | ||||
| 		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() | ||||
| 	case "darwin": | ||||
| 		err = exec.Command("open", url).Start() | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetIp() (ip string) { | ||||
| 	ips, err := net.InterfaceAddrs() | ||||
| 	if err != nil { | ||||
| 		log.Println(err) | ||||
| 		return ip | ||||
| 	} | ||||
|  | ||||
| 	for _, a := range ips { | ||||
| 		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { | ||||
| 			if ipNet.IP.To4() != nil { | ||||
| 				ip = ipNet.IP.String() | ||||
| 				if strings.HasPrefix(ip, "10") { | ||||
| 					return | ||||
| 				} | ||||
| 				if strings.HasPrefix(ip, "172") { | ||||
| 					return | ||||
| 				} | ||||
| 				if strings.HasPrefix(ip, "192.168") { | ||||
| 					return | ||||
| 				} | ||||
| 				ip = "" | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| var sizeKB = 1024 | ||||
| var sizeMB = sizeKB * 1024 | ||||
| var sizeGB = sizeMB * 1024 | ||||
|  | ||||
| func Bytes2Size(num int64) string { | ||||
| 	numStr := "" | ||||
| 	unit := "B" | ||||
| 	if num/int64(sizeGB) > 1 { | ||||
| 		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) | ||||
| 		unit = "GB" | ||||
| 	} else if num/int64(sizeMB) > 1 { | ||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) | ||||
| 		unit = "MB" | ||||
| 	} else if num/int64(sizeKB) > 1 { | ||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) | ||||
| 		unit = "KB" | ||||
| func LogQuota(quota int) string { | ||||
| 	if config.DisplayInCurrencyEnabled { | ||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) | ||||
| 	} else { | ||||
| 		numStr = fmt.Sprintf("%d", num) | ||||
| 	} | ||||
| 	return numStr + " " + unit | ||||
| } | ||||
|  | ||||
| func Seconds2Time(num int) (time string) { | ||||
| 	if num/31104000 > 0 { | ||||
| 		time += strconv.Itoa(num/31104000) + " 年 " | ||||
| 		num %= 31104000 | ||||
| 	} | ||||
| 	if num/2592000 > 0 { | ||||
| 		time += strconv.Itoa(num/2592000) + " 个月 " | ||||
| 		num %= 2592000 | ||||
| 	} | ||||
| 	if num/86400 > 0 { | ||||
| 		time += strconv.Itoa(num/86400) + " 天 " | ||||
| 		num %= 86400 | ||||
| 	} | ||||
| 	if num/3600 > 0 { | ||||
| 		time += strconv.Itoa(num/3600) + " 小时 " | ||||
| 		num %= 3600 | ||||
| 	} | ||||
| 	if num/60 > 0 { | ||||
| 		time += strconv.Itoa(num/60) + " 分钟 " | ||||
| 		num %= 60 | ||||
| 	} | ||||
| 	time += strconv.Itoa(num) + " 秒" | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func Interface2String(inter interface{}) string { | ||||
| 	switch inter.(type) { | ||||
| 	case string: | ||||
| 		return inter.(string) | ||||
| 	case int: | ||||
| 		return fmt.Sprintf("%d", inter.(int)) | ||||
| 	case float64: | ||||
| 		return fmt.Sprintf("%f", inter.(float64)) | ||||
| 	} | ||||
| 	return "Not Implemented" | ||||
| } | ||||
|  | ||||
| func UnescapeHTML(x string) interface{} { | ||||
| 	return template.HTML(x) | ||||
| } | ||||
|  | ||||
| func IntMax(a int, b int) int { | ||||
| 	if a >= b { | ||||
| 		return a | ||||
| 	} else { | ||||
| 		return b | ||||
| 		return fmt.Sprintf("%d 点额度", quota) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUUID() string { | ||||
| 	code := uuid.New().String() | ||||
| 	code = strings.Replace(code, "-", "", -1) | ||||
| 	return code | ||||
| } | ||||
|  | ||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||
|  | ||||
| func init() { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| } | ||||
|  | ||||
| func GenerateKey() string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, 48) | ||||
| 	for i := 0; i < 16; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	uuid_ := GetUUID() | ||||
| 	for i := 0; i < 32; i++ { | ||||
| 		c := uuid_[i] | ||||
| 		if i%2 == 0 && c >= 'a' && c <= 'z' { | ||||
| 			c = c - 'a' + 'A' | ||||
| 		} | ||||
| 		key[i+16] = c | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetRandomString(length int) string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, length) | ||||
| 	for i := 0; i < length; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetTimestamp() int64 { | ||||
| 	return time.Now().Unix() | ||||
| } | ||||
|  | ||||
| func GetTimeString() string { | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| } | ||||
|  | ||||
| func Max(a int, b int) int { | ||||
| 	if a >= b { | ||||
| 		return a | ||||
| 	} else { | ||||
| 		return b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetOrDefault(env string, defaultValue int) int { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.Atoi(os.Getenv(env)) | ||||
| 	if err != nil { | ||||
| 		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func MessageWithRequestId(message string, id string) string { | ||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||
| } | ||||
|  | ||||
| func String2Int(str string) int { | ||||
| 	num, err := strconv.Atoi(str) | ||||
| 	if err != nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|   | ||||
| @@ -2,8 +2,9 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| func GetSubscription(c *gin.Context) { | ||||
| @@ -12,7 +13,7 @@ func GetSubscription(c *gin.Context) { | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	var expiredTime int64 | ||||
| 	if common.DisplayTokenStatEnabled { | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		expiredTime = token.ExpiredTime | ||||
| @@ -27,19 +28,19 @@ func GetSubscription(c *gin.Context) { | ||||
| 		expiredTime = 0 | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		openAIError := OpenAIError{ | ||||
| 		Error := openai.Error{ | ||||
| 			Message: err.Error(), | ||||
| 			Type:    "upstream_error", | ||||
| 		} | ||||
| 		c.JSON(200, gin.H{ | ||||
| 			"error": openAIError, | ||||
| 			"error": Error, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	quota := remainQuota + usedQuota | ||||
| 	amount := float64(quota) | ||||
| 	if common.DisplayInCurrencyEnabled { | ||||
| 		amount /= common.QuotaPerUnit | ||||
| 	if config.DisplayInCurrencyEnabled { | ||||
| 		amount /= config.QuotaPerUnit | ||||
| 	} | ||||
| 	if token != nil && token.UnlimitedQuota { | ||||
| 		amount = 100000000 | ||||
| @@ -60,7 +61,7 @@ func GetUsage(c *gin.Context) { | ||||
| 	var quota int | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	if common.DisplayTokenStatEnabled { | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		quota = token.UsedQuota | ||||
| @@ -69,18 +70,18 @@ func GetUsage(c *gin.Context) { | ||||
| 		quota, err = model.GetUserUsedQuota(userId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		openAIError := OpenAIError{ | ||||
| 		Error := openai.Error{ | ||||
| 			Message: err.Error(), | ||||
| 			Type:    "one_api_error", | ||||
| 		} | ||||
| 		c.JSON(200, gin.H{ | ||||
| 			"error": openAIError, | ||||
| 			"error": Error, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	amount := float64(quota) | ||||
| 	if common.DisplayInCurrencyEnabled { | ||||
| 		amount /= common.QuotaPerUnit | ||||
| 	if config.DisplayInCurrencyEnabled { | ||||
| 		amount /= config.QuotaPerUnit | ||||
| 	} | ||||
| 	usage := OpenAIUsageResponse{ | ||||
| 		Object:     "list", | ||||
|   | ||||
| @@ -7,7 +7,10 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/util" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| @@ -92,7 +95,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| 	for k := range headers { | ||||
| 		req.Header.Add(k, headers.Get(k)) | ||||
| 	} | ||||
| 	res, err := httpClient.Do(req) | ||||
| 	res, err := util.HTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -313,7 +316,7 @@ func updateAllChannelsBalance() error { | ||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | ||||
| 			} | ||||
| 		} | ||||
| 		time.Sleep(common.RequestInterval) | ||||
| 		time.Sleep(config.RequestInterval) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { | ||||
| func AutomaticallyUpdateChannels(frequency int) { | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||
| 		common.SysLog("updating all channels") | ||||
| 		logger.SysLog("updating all channels") | ||||
| 		_ = updateAllChannelsBalance() | ||||
| 		common.SysLog("channels update done") | ||||
| 		logger.SysLog("channels update done") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -8,7 +8,11 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/util" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @@ -16,7 +20,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||
| func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) { | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		fallthrough | ||||
| @@ -46,13 +50,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	} | ||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) | ||||
| 		requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) | ||||
| 	} else { | ||||
| 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { | ||||
| 			requestURL = baseURL | ||||
| 		} | ||||
|  | ||||
| 		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) | ||||
| 		requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(request) | ||||
| 	if err != nil { | ||||
| @@ -68,12 +72,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	resp, err := util.HTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	var response TextResponse | ||||
| 	var response openai.SlimTextResponse | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| @@ -91,12 +95,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func buildTestRequest() *ChatRequest { | ||||
| 	testRequest := &ChatRequest{ | ||||
| func buildTestRequest() *openai.ChatRequest { | ||||
| 	testRequest := &openai.ChatRequest{ | ||||
| 		Model:     "", // this will be set later | ||||
| 		MaxTokens: 1, | ||||
| 	} | ||||
| 	testMessage := Message{ | ||||
| 	testMessage := openai.Message{ | ||||
| 		Role:    "user", | ||||
| 		Content: "hi", | ||||
| 	} | ||||
| @@ -148,12 +152,12 @@ var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if common.RootUserEmail == "" { | ||||
| 		common.RootUserEmail = model.GetRootUserEmail() | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||
| 	err := common.SendEmail(subject, config.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -174,8 +178,8 @@ func enableChannel(channelId int, channelName string) { | ||||
| } | ||||
|  | ||||
| func testAllChannels(notify bool) error { | ||||
| 	if common.RootUserEmail == "" { | ||||
| 		common.RootUserEmail = model.GetRootUserEmail() | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	testAllChannelsLock.Lock() | ||||
| 	if testAllChannelsRunning { | ||||
| @@ -189,7 +193,7 @@ func testAllChannels(notify bool) error { | ||||
| 		return err | ||||
| 	} | ||||
| 	testRequest := buildTestRequest() | ||||
| 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000) | ||||
| 	var disableThreshold = int64(config.ChannelDisableThreshold * 1000) | ||||
| 	if disableThreshold == 0 { | ||||
| 		disableThreshold = 10000000 // a impossible value | ||||
| 	} | ||||
| @@ -204,22 +208,22 @@ func testAllChannels(notify bool) error { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { | ||||
| 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { | ||||
| 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | ||||
| 				enableChannel(channel.Id, channel.Name) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
| 			time.Sleep(common.RequestInterval) | ||||
| 			time.Sleep(config.RequestInterval) | ||||
| 		} | ||||
| 		testAllChannelsLock.Lock() | ||||
| 		testAllChannelsRunning = false | ||||
| 		testAllChannelsLock.Unlock() | ||||
| 		if notify { | ||||
| 			err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | ||||
| 			err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | ||||
| 			if err != nil { | ||||
| 				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| @@ -245,8 +249,8 @@ func TestAllChannels(c *gin.Context) { | ||||
| func AutomaticallyTestChannels(frequency int) { | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||
| 		common.SysLog("testing all channels") | ||||
| 		logger.SysLog("testing all channels") | ||||
| 		_ = testAllChannels(false) | ||||
| 		common.SysLog("channel test finished") | ||||
| 		logger.SysLog("channel test finished") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -3,7 +3,8 @@ package controller | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| @@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) | ||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	channel.CreatedTime = common.GetTimestamp() | ||||
| 	channel.CreatedTime = helper.GetTimestamp() | ||||
| 	keys := strings.Split(channel.Key, "\n") | ||||
| 	channels := make([]model.Channel, 0, len(keys)) | ||||
| 	for _, key := range keys { | ||||
|   | ||||
| @@ -9,6 +9,9 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| @@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("无效的参数") | ||||
| 	} | ||||
| 	values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} | ||||
| 	values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} | ||||
| 	jsonData, err := json.Marshal(values) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		common.SysLog(err.Error()) | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| @@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | ||||
| 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | ||||
| 	res2, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		common.SysLog(err.Error()) | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res2.Body.Close() | ||||
| @@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !common.GitHubOAuthEnabled { | ||||
| 	if !config.GitHubOAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 GitHub 登录以及注册", | ||||
| @@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if common.RegisterEnabled { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			if githubUser.Name != "" { | ||||
| 				user.DisplayName = githubUser.Name | ||||
| @@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GitHubBind(c *gin.Context) { | ||||
| 	if !common.GitHubOAuthEnabled { | ||||
| 	if !config.GitHubOAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 GitHub 登录以及注册", | ||||
| @@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) { | ||||
|  | ||||
| func GenerateOAuthCode(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	state := common.GetRandomString(12) | ||||
| 	state := helper.GetRandomString(12) | ||||
| 	session.Set("oauth_state", state) | ||||
| 	err := session.Save() | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -3,7 +3,7 @@ package controller | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| ) | ||||
| @@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) { | ||||
| 	tokenName := c.Query("token_name") | ||||
| 	modelName := c.Query("model_name") | ||||
| 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) | ||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) { | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
| 	tokenName := c.Query("token_name") | ||||
| 	modelName := c.Query("model_name") | ||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| @@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) { | ||||
| 		"data": gin.H{ | ||||
| 			"version":             common.Version, | ||||
| 			"start_time":          common.StartTime, | ||||
| 			"email_verification":  common.EmailVerificationEnabled, | ||||
| 			"github_oauth":        common.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    common.GitHubClientId, | ||||
| 			"system_name":         common.SystemName, | ||||
| 			"logo":                common.Logo, | ||||
| 			"footer_html":         common.Footer, | ||||
| 			"wechat_qrcode":       common.WeChatAccountQRCodeImageURL, | ||||
| 			"wechat_login":        common.WeChatAuthEnabled, | ||||
| 			"server_address":      common.ServerAddress, | ||||
| 			"turnstile_check":     common.TurnstileCheckEnabled, | ||||
| 			"turnstile_site_key":  common.TurnstileSiteKey, | ||||
| 			"top_up_link":         common.TopUpLink, | ||||
| 			"chat_link":           common.ChatLink, | ||||
| 			"quota_per_unit":      common.QuotaPerUnit, | ||||
| 			"display_in_currency": common.DisplayInCurrencyEnabled, | ||||
| 			"email_verification":  config.EmailVerificationEnabled, | ||||
| 			"github_oauth":        config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    config.GitHubClientId, | ||||
| 			"system_name":         config.SystemName, | ||||
| 			"logo":                config.Logo, | ||||
| 			"footer_html":         config.Footer, | ||||
| 			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL, | ||||
| 			"wechat_login":        config.WeChatAuthEnabled, | ||||
| 			"server_address":      config.ServerAddress, | ||||
| 			"turnstile_check":     config.TurnstileCheckEnabled, | ||||
| 			"turnstile_site_key":  config.TurnstileSiteKey, | ||||
| 			"top_up_link":         config.TopUpLink, | ||||
| 			"chat_link":           config.ChatLink, | ||||
| 			"quota_per_unit":      config.QuotaPerUnit, | ||||
| 			"display_in_currency": config.DisplayInCurrencyEnabled, | ||||
| 		}, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetNotice(c *gin.Context) { | ||||
| 	common.OptionMapRWMutex.RLock() | ||||
| 	defer common.OptionMapRWMutex.RUnlock() | ||||
| 	config.OptionMapRWMutex.RLock() | ||||
| 	defer config.OptionMapRWMutex.RUnlock() | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    common.OptionMap["Notice"], | ||||
| 		"data":    config.OptionMap["Notice"], | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetAbout(c *gin.Context) { | ||||
| 	common.OptionMapRWMutex.RLock() | ||||
| 	defer common.OptionMapRWMutex.RUnlock() | ||||
| 	config.OptionMapRWMutex.RLock() | ||||
| 	defer config.OptionMapRWMutex.RUnlock() | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    common.OptionMap["About"], | ||||
| 		"data":    config.OptionMap["About"], | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetHomePageContent(c *gin.Context) { | ||||
| 	common.OptionMapRWMutex.RLock() | ||||
| 	defer common.OptionMapRWMutex.RUnlock() | ||||
| 	config.OptionMapRWMutex.RLock() | ||||
| 	defer config.OptionMapRWMutex.RUnlock() | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    common.OptionMap["HomePageContent"], | ||||
| 		"data":    config.OptionMap["HomePageContent"], | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if common.EmailDomainRestrictionEnabled { | ||||
| 	if config.EmailDomainRestrictionEnabled { | ||||
| 		allowed := false | ||||
| 		for _, domain := range common.EmailDomainWhitelist { | ||||
| 		for _, domain := range config.EmailDomainWhitelist { | ||||
| 			if strings.HasSuffix(email, "@"+domain) { | ||||
| 				allowed = true | ||||
| 				break | ||||
| @@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) { | ||||
| 	} | ||||
| 	code := common.GenerateVerificationCode(6) | ||||
| 	common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) | ||||
| 	subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) | ||||
| 	subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) | ||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | ||||
| 		"<p>您的验证码为: <strong>%s</strong></p>"+ | ||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) | ||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) { | ||||
| 	} | ||||
| 	code := common.GenerateVerificationCode(0) | ||||
| 	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) | ||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | ||||
| 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | ||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) | ||||
| 	subject := fmt.Sprintf("%s密码重置", config.SystemName) | ||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
| @@ -2,8 +2,8 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/models/list | ||||
| @@ -342,6 +342,24 @@ func init() { | ||||
| 			Root:       "code-davinci-edit-001", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "davinci-002", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "davinci-002", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "babbage-002", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "babbage-002", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-instant-1", | ||||
| 			Object:     "model", | ||||
| @@ -418,7 +436,7 @@ func init() { | ||||
| 			Id:         "PaLM-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "google", | ||||
| 			OwnedBy:    "google palm", | ||||
| 			Permission: permission, | ||||
| 			Root:       "PaLM-2", | ||||
| 			Parent:     nil, | ||||
| @@ -427,11 +445,20 @@ func init() { | ||||
| 			Id:         "gemini-pro", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "google", | ||||
| 			OwnedBy:    "google gemini", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gemini-pro", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gemini-pro-vision", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "google gemini", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gemini-pro-vision", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_turbo", | ||||
| 			Object:     "model", | ||||
| @@ -586,14 +613,14 @@ func RetrieveModel(c *gin.Context) { | ||||
| 	if model, ok := openAIModelsMap[modelId]; ok { | ||||
| 		c.JSON(200, model) | ||||
| 	} else { | ||||
| 		openAIError := OpenAIError{ | ||||
| 		Error := openai.Error{ | ||||
| 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | ||||
| 			Type:    "invalid_request_error", | ||||
| 			Param:   "model", | ||||
| 			Code:    "model_not_found", | ||||
| 		} | ||||
| 		c.JSON(200, gin.H{ | ||||
| 			"error": openAIError, | ||||
| 			"error": Error, | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -3,7 +3,8 @@ package controller | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| @@ -12,17 +13,17 @@ import ( | ||||
|  | ||||
| func GetOptions(c *gin.Context) { | ||||
| 	var options []*model.Option | ||||
| 	common.OptionMapRWMutex.Lock() | ||||
| 	for k, v := range common.OptionMap { | ||||
| 	config.OptionMapRWMutex.Lock() | ||||
| 	for k, v := range config.OptionMap { | ||||
| 		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { | ||||
| 			continue | ||||
| 		} | ||||
| 		options = append(options, &model.Option{ | ||||
| 			Key:   k, | ||||
| 			Value: common.Interface2String(v), | ||||
| 			Value: helper.Interface2String(v), | ||||
| 		}) | ||||
| 	} | ||||
| 	common.OptionMapRWMutex.Unlock() | ||||
| 	config.OptionMapRWMutex.Unlock() | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| @@ -42,8 +43,16 @@ func UpdateOption(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	switch option.Key { | ||||
| 	case "Theme": | ||||
| 		if !config.ValidThemes[option.Value] { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无效的主题", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	case "GitHubOAuthEnabled": | ||||
| 		if option.Value == "true" && common.GitHubClientId == "" { | ||||
| 		if option.Value == "true" && config.GitHubClientId == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||
| @@ -51,7 +60,7 @@ func UpdateOption(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	case "EmailDomainRestrictionEnabled": | ||||
| 		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { | ||||
| 		if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | ||||
| @@ -59,7 +68,7 @@ func UpdateOption(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	case "WeChatAuthEnabled": | ||||
| 		if option.Value == "true" && common.WeChatServerAddress == "" { | ||||
| 		if option.Value == "true" && config.WeChatServerAddress == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用微信登录,请先填入微信登录相关配置信息!", | ||||
| @@ -67,7 +76,7 @@ func UpdateOption(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	case "TurnstileCheckEnabled": | ||||
| 		if option.Value == "true" && common.TurnstileSiteKey == "" { | ||||
| 		if option.Value == "true" && config.TurnstileSiteKey == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", | ||||
|   | ||||
| @@ -3,7 +3,8 @@ package controller | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| ) | ||||
| @@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) { | ||||
| 	} | ||||
| 	var keys []string | ||||
| 	for i := 0; i < redemption.Count; i++ { | ||||
| 		key := common.GetUUID() | ||||
| 		key := helper.GetUUID() | ||||
| 		cleanRedemption := model.Redemption{ | ||||
| 			UserId:      c.GetInt("id"), | ||||
| 			Name:        redemption.Name, | ||||
| 			Key:         key, | ||||
| 			CreatedTime: common.GetTimestamp(), | ||||
| 			CreatedTime: helper.GetTimestamp(), | ||||
| 			Quota:       redemption.Quota, | ||||
| 		} | ||||
| 		err = cleanRedemption.Insert() | ||||
|   | ||||
| @@ -1,314 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
|  | ||||
| type AliMessage struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| } | ||||
|  | ||||
| type AliInput struct { | ||||
| 	//Prompt   string       `json:"prompt"` | ||||
| 	Messages []AliMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type AliParameters struct { | ||||
| 	TopP         float64 `json:"top_p,omitempty"` | ||||
| 	TopK         int     `json:"top_k,omitempty"` | ||||
| 	Seed         uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliChatRequest struct { | ||||
| 	Model      string        `json:"model"` | ||||
| 	Input      AliInput      `json:"input"` | ||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliEmbeddingRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input struct { | ||||
| 		Texts []string `json:"texts"` | ||||
| 	} `json:"input"` | ||||
| 	Parameters *struct { | ||||
| 		TextType string `json:"text_type,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliEmbedding struct { | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	TextIndex int       `json:"text_index"` | ||||
| } | ||||
|  | ||||
| type AliEmbeddingResponse struct { | ||||
| 	Output struct { | ||||
| 		Embeddings []AliEmbedding `json:"embeddings"` | ||||
| 	} `json:"output"` | ||||
| 	Usage AliUsage `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| type AliError struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type AliUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type AliOutput struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type AliChatResponse struct { | ||||
| 	Output AliOutput `json:"output"` | ||||
| 	Usage  AliUsage  `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		messages = append(messages, AliMessage{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    strings.ToLower(message.Role), | ||||
| 		}) | ||||
| 	} | ||||
| 	return &AliChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Input: AliInput{ | ||||
| 			Messages: messages, | ||||
| 		}, | ||||
| 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||
| 		//	TopP: request.TopP, | ||||
| 		//	TopK: 50, | ||||
| 		//	//Seed:         0, | ||||
| 		//	//EnableSearch: false, | ||||
| 		//}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { | ||||
| 	return &AliEmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| 			Texts: request.ParseInput(), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliEmbeddingResponse | ||||
| 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Usage:  Usage{TotalTokens: response.Usage.TotalTokens}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range response.Output.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     item.TextIndex, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage: Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	if aliResponse.Output.FinishReason != "null" { | ||||
| 		finishReason := aliResponse.Output.FinishReason | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse AliChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if aliResponse.Usage.OutputTokens != 0 { | ||||
| 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||
| 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -1,287 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/product/1729/97732 | ||||
|  | ||||
| type TencentMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type TencentChatRequest struct { | ||||
| 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||
| 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||
| 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||
| 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||
| 	Timestamp int64 `json:"timestamp"` | ||||
| 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||
| 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||
| 	Expired int64  `json:"expired"` | ||||
| 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||
| 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||
| 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||
| 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||
| 	Temperature float64 `json:"temperature"` | ||||
| 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||
| 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||
| 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||
| 	TopP float64 `json:"top_p"` | ||||
| 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||
| 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||
| 	Stream int `json:"stream"` | ||||
| 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||
| 	// 输入 content 总数最大支持 3000 token。 | ||||
| 	Messages []TencentMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type TencentError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type TencentUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type TencentResponseChoices struct { | ||||
| 	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||
| 	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| 	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| } | ||||
|  | ||||
| type TencentChatResponse struct { | ||||
| 	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 | ||||
| 	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string                   `json:"id,omitempty"`      // 会话 id | ||||
| 	Usage   Usage                    `json:"usage,omitempty"`   // token 数量 | ||||
| 	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string                   `json:"note,omitempty"`    // 注释 | ||||
| 	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | ||||
| 	messages := make([]TencentMessage, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, TencentMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, TencentMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} | ||||
| 		messages = append(messages, TencentMessage{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    message.Role, | ||||
| 		}) | ||||
| 	} | ||||
| 	stream := 0 | ||||
| 	if request.Stream { | ||||
| 		stream = 1 | ||||
| 	} | ||||
| 	return &TencentChatRequest{ | ||||
| 		Timestamp:   common.GetTimestamp(), | ||||
| 		Expired:     common.GetTimestamp() + 24*60*60, | ||||
| 		QueryID:     common.GetUUID(), | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Stream:      stream, | ||||
| 		Messages:    messages, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	if len(response.Choices) > 0 { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 			Index: 0, | ||||
| 			Message: Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: response.Choices[0].Messages.Content, | ||||
| 			}, | ||||
| 			FinishReason: response.Choices[0].FinishReason, | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "tencent-hunyuan", | ||||
| 	} | ||||
| 	if len(TencentResponse.Choices) > 0 { | ||||
| 		var choice ChatCompletionsStreamResponseChoice | ||||
| 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | ||||
| 		if TencentResponse.Choices[0].FinishReason == "stop" { | ||||
| 			choice.FinishReason = &stopFinishReason | ||||
| 		} | ||||
| 		response.Choices = append(response.Choices, choice) | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	var responseText string | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var TencentResponse TencentChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||
| 			if len(response.Choices) != 0 { | ||||
| 				responseText += response.Choices[0].Delta.Content | ||||
| 			} | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var TencentResponse TencentChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if TencentResponse.Error.Code != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: TencentResponse.Error.Message, | ||||
| 				Code:    TencentResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseTencent2OpenAI(&TencentResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { | ||||
| 	parts := strings.Split(config, "|") | ||||
| 	if len(parts) != 3 { | ||||
| 		err = errors.New("invalid tencent config") | ||||
| 		return | ||||
| 	} | ||||
| 	appId, err = strconv.ParseInt(parts[0], 10, 64) | ||||
| 	secretId = parts[1] | ||||
| 	secretKey = parts[2] | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func getTencentSign(req TencentChatRequest, secretKey string) string { | ||||
| 	params := make([]string, 0) | ||||
| 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||
| 	params = append(params, "secret_id="+req.SecretId) | ||||
| 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) | ||||
| 	params = append(params, "query_id="+req.QueryID) | ||||
| 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) | ||||
| 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) | ||||
| 	params = append(params, "stream="+strconv.Itoa(req.Stream)) | ||||
| 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) | ||||
|  | ||||
| 	var messageStr string | ||||
| 	for _, msg := range req.Messages { | ||||
| 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) | ||||
| 	} | ||||
| 	messageStr = strings.TrimSuffix(messageStr, ",") | ||||
| 	params = append(params, "messages=["+messageStr+"]") | ||||
|  | ||||
| 	sort.Sort(sort.StringSlice(params)) | ||||
| 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") | ||||
| 	mac := hmac.New(sha1.New, []byte(secretKey)) | ||||
| 	signURL := url | ||||
| 	mac.Write([]byte(signURL)) | ||||
| 	sign := mac.Sum([]byte(nil)) | ||||
| 	return base64.StdEncoding.EncodeToString(sign) | ||||
| } | ||||
| @@ -1,695 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	APITypeOpenAI = iota | ||||
| 	APITypeClaude | ||||
| 	APITypePaLM | ||||
| 	APITypeBaidu | ||||
| 	APITypeZhipu | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| 	APITypeAIProxyLibrary | ||||
| 	APITypeTencent | ||||
| 	APITypeGemini | ||||
| ) | ||||
|  | ||||
| var httpClient *http.Client | ||||
| var impatientHTTPClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	if common.RelayTimeout == 0 { | ||||
| 		httpClient = &http.Client{} | ||||
| 	} else { | ||||
| 		httpClient = &http.Client{ | ||||
| 			Timeout: time.Duration(common.RelayTimeout) * time.Second, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	impatientHTTPClient = &http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
| 	var textRequest GeneralOpenAIRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { | ||||
| 		return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| 	} | ||||
| 	if relayMode == RelayModeEmbeddings && textRequest.Model == "" { | ||||
| 		textRequest.Model = c.Param("model") | ||||
| 	} | ||||
| 	// request validation | ||||
| 	if textRequest.Model == "" { | ||||
| 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	switch relayMode { | ||||
| 	case RelayModeCompletions: | ||||
| 		if textRequest.Prompt == "" { | ||||
| 			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeChatCompletions: | ||||
| 		if textRequest.Messages == nil || len(textRequest.Messages) == 0 { | ||||
| 			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeEmbeddings: | ||||
| 	case RelayModeModerations: | ||||
| 		if textRequest.Input == "" { | ||||
| 			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeEdits: | ||||
| 		if textRequest.Instruction == "" { | ||||
| 			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" && modelMapping != "{}" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[textRequest.Model] != "" { | ||||
| 			textRequest.Model = modelMap[textRequest.Model] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
| 	apiType := APITypeOpenAI | ||||
| 	switch channelType { | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		apiType = APITypeClaude | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		apiType = APITypeBaidu | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		apiType = APITypePaLM | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		apiType = APITypeZhipu | ||||
| 	case common.ChannelTypeAli: | ||||
| 		apiType = APITypeAli | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		apiType = APITypeXunfei | ||||
| 	case common.ChannelTypeAIProxyLibrary: | ||||
| 		apiType = APITypeAIProxyLibrary | ||||
| 	case common.ChannelTypeTencent: | ||||
| 		apiType = APITypeTencent | ||||
| 	case common.ChannelTypeGemini: | ||||
| 		apiType = APITypeGemini | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 			apiVersion := GetAPIVersion(c) | ||||
| 			requestURL := strings.Split(requestURL, "?")[0] | ||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 			baseURL = c.GetString("base_url") | ||||
| 			task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 			model_ := textRequest.Model | ||||
| 			model_ = strings.Replace(model_, ".", "", -1) | ||||
| 			// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
|  | ||||
| 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||
| 			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		switch textRequest.Model { | ||||
| 		case "ERNIE-Bot": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "ERNIE-Bot-4": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		var err error | ||||
| 		if apiKey, err = getBaiduAccessToken(apiKey); err != nil { | ||||
| 			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		fullRequestURL += "?access_token=" + apiKey | ||||
| 	case APITypePaLM: | ||||
| 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?key=" + apiKey | ||||
| 	case APITypeGemini: | ||||
| 		requestBaseURL := "https://generativelanguage.googleapis.com" | ||||
| 		if baseURL != "" { | ||||
| 			requestBaseURL = baseURL | ||||
| 		} | ||||
| 		version := "v1" | ||||
| 		if c.GetString("api_version") != "" { | ||||
| 			version = c.GetString("api_version") | ||||
| 		} | ||||
| 		action := "generateContent" | ||||
| 		if textRequest.Stream { | ||||
| 			action = "streamGenerateContent" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?key=" + apiKey | ||||
| 	case APITypeZhipu: | ||||
| 		method := "invoke" | ||||
| 		if textRequest.Stream { | ||||
| 			method = "sse-invoke" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||
| 	case APITypeAli: | ||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||
| 		if relayMode == RelayModeEmbeddings { | ||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||
| 		} | ||||
| 	case APITypeTencent: | ||||
| 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||
| 	} | ||||
| 	var promptTokens int | ||||
| 	var completionTokens int | ||||
| 	switch relayMode { | ||||
| 	case RelayModeChatCompletions: | ||||
| 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) | ||||
| 	case RelayModeCompletions: | ||||
| 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) | ||||
| 	case RelayModeModerations: | ||||
| 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model) | ||||
| 	} | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| 	if textRequest.MaxTokens != 0 { | ||||
| 		preConsumedTokens = promptTokens + textRequest.MaxTokens | ||||
| 	} | ||||
| 	modelRatio := common.GetModelRatio(textRequest.Model) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota > 100*preConsumedQuota { | ||||
| 		// in this case, we do not pre-consume quota | ||||
| 		// because the user has enough quota | ||||
| 		preConsumedQuota = 0 | ||||
| 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||
| 	} | ||||
| 	if preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 	} | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
| 	switch apiType { | ||||
| 	case APITypeClaude: | ||||
| 		claudeRequest := requestOpenAI2Claude(textRequest) | ||||
| 		jsonStr, err := json.Marshal(claudeRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeBaidu: | ||||
| 		var jsonData []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||
| 		default: | ||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	case APITypePaLM: | ||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||
| 		jsonStr, err := json.Marshal(palmRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeGemini: | ||||
| 		geminiChatRequest := requestOpenAI2Gemini(textRequest) | ||||
| 		jsonStr, err := json.Marshal(geminiChatRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeZhipu: | ||||
| 		zhipuRequest := requestOpenAI2Zhipu(textRequest) | ||||
| 		jsonStr, err := json.Marshal(zhipuRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAli: | ||||
| 		var jsonStr []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) | ||||
| 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||
| 		default: | ||||
| 			aliRequest := requestOpenAI2Ali(textRequest) | ||||
| 			jsonStr, err = json.Marshal(aliRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeTencent: | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		appId, secretId, secretKey, err := parseTencentConfig(apiKey) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		tencentRequest := requestOpenAI2Tencent(textRequest) | ||||
| 		tencentRequest.AppId = appId | ||||
| 		tencentRequest.SecretId = secretId | ||||
| 		jsonStr, err := json.Marshal(tencentRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		sign := getTencentSign(*tencentRequest, secretKey) | ||||
| 		c.Request.Header.Set("Authorization", sign) | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||
| 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
| 	var resp *http.Response | ||||
| 	isStream := textRequest.Stream | ||||
|  | ||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		switch apiType { | ||||
| 		case APITypeOpenAI: | ||||
| 			if channelType == common.ChannelTypeAzure { | ||||
| 				req.Header.Set("api-key", apiKey) | ||||
| 			} else { | ||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 				if channelType == common.ChannelTypeOpenRouter { | ||||
| 					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") | ||||
| 					req.Header.Set("X-Title", "One API") | ||||
| 				} | ||||
| 			} | ||||
| 		case APITypeClaude: | ||||
| 			req.Header.Set("x-api-key", apiKey) | ||||
| 			anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 			if anthropicVersion == "" { | ||||
| 				anthropicVersion = "2023-06-01" | ||||
| 			} | ||||
| 			req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 		case APITypeZhipu: | ||||
| 			token := getZhipuToken(apiKey) | ||||
| 			req.Header.Set("Authorization", token) | ||||
| 		case APITypeAli: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 			if textRequest.Stream { | ||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 			} | ||||
| 			if c.GetString("plugin") != "" { | ||||
| 				req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) | ||||
| 			} | ||||
| 		case APITypeTencent: | ||||
| 			req.Header.Set("Authorization", apiKey) | ||||
| 		case APITypePaLM: | ||||
| 			// do not set Authorization header | ||||
| 		case APITypeGemini: | ||||
| 			// do not set Authorization header | ||||
| 		default: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		if isStream && c.Request.Header.Get("Accept") == "" { | ||||
| 			req.Header.Set("Accept", "text/event-stream") | ||||
| 		} | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = req.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = c.Request.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
|  | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			if preConsumedQuota != 0 { | ||||
| 				go func(ctx context.Context) { | ||||
| 					// return pre-consumed quota | ||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||
| 					if err != nil { | ||||
| 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 					} | ||||
| 				}(c.Request.Context()) | ||||
| 			} | ||||
| 			return relayErrorHandler(resp) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var textResponse TextResponse | ||||
| 	tokenName := c.GetString("token_name") | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		// c.Writer.Flush() | ||||
| 		go func() { | ||||
| 			quota := 0 | ||||
| 			completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 			promptTokens = textResponse.Usage.PromptTokens | ||||
| 			completionTokens = textResponse.Usage.CompletionTokens | ||||
| 			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
| 			if ratio != 0 && quota <= 0 { | ||||
| 				quota = 1 | ||||
| 			} | ||||
| 			totalTokens := promptTokens + completionTokens | ||||
| 			if totalTokens == 0 { | ||||
| 				// in this case, must be some error happened | ||||
| 				// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 				quota = 0 | ||||
| 			} | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
|  | ||||
| 		}() | ||||
| 	}(c.Request.Context()) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if isStream { | ||||
| 			err, responseText := openaiStreamHandler(c, resp, relayMode) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		if isStream { | ||||
| 			err, responseText := claudeStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		if isStream { | ||||
| 			err, usage := baiduStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = baiduEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = baiduHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypePaLM: | ||||
| 		if textRequest.Stream { // PaLM2 API does not support stream | ||||
| 			err, responseText := palmStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeGemini: | ||||
| 		if textRequest.Stream { | ||||
| 			err, responseText := geminiChatStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeZhipu: | ||||
| 		if isStream { | ||||
| 			err, usage := zhipuStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := zhipuHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeAli: | ||||
| 		if isStream { | ||||
| 			err, usage := aliStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = aliEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = aliHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeXunfei: | ||||
| 		auth := c.Request.Header.Get("Authorization") | ||||
| 		auth = strings.TrimPrefix(auth, "Bearer ") | ||||
| 		splits := strings.Split(auth, "|") | ||||
| 		if len(splits) != 3 { | ||||
| 			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||
| 		} | ||||
| 		var err *OpenAIErrorWithStatusCode | ||||
| 		var usage *Usage | ||||
| 		if isStream { | ||||
| 			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 		} else { | ||||
| 			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if usage != nil { | ||||
| 			textResponse.Usage = *usage | ||||
| 		} | ||||
| 		return nil | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		if isStream { | ||||
| 			err, usage := aiProxyLibraryStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := aiProxyLibraryHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeTencent: | ||||
| 		if isStream { | ||||
| 			err, responseText := tencentStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := tencentHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	default: | ||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 	} | ||||
| } | ||||
| @@ -2,316 +2,57 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string  `json:"role"` | ||||
| 	Content any     `json:"content"` | ||||
| 	Name    *string `json:"name,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageURL struct { | ||||
| 	Url    string `json:"url,omitempty"` | ||||
| 	Detail string `json:"detail,omitempty"` | ||||
| } | ||||
|  | ||||
| type TextContent struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageContent struct { | ||||
| 	Type     string    `json:"type,omitempty"` | ||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) StringContent() string { | ||||
| 	content, ok := m.Content.(string) | ||||
| 	if ok { | ||||
| 		return content | ||||
| 	} | ||||
| 	contentList, ok := m.Content.([]any) | ||||
| 	if ok { | ||||
| 		var contentStr string | ||||
| 		for _, contentItem := range contentList { | ||||
| 			contentMap, ok := contentItem.(map[string]any) | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
| 			if contentMap["type"] == "text" { | ||||
| 				if subStr, ok := contentMap["text"].(string); ok { | ||||
| 					contentStr += subStr | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return contentStr | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	RelayModeUnknown = iota | ||||
| 	RelayModeChatCompletions | ||||
| 	RelayModeCompletions | ||||
| 	RelayModeEmbeddings | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeEdits | ||||
| 	RelayModeAudioSpeech | ||||
| 	RelayModeAudioTranscription | ||||
| 	RelayModeAudioTranslation | ||||
| 	"net/http" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"one-api/relay/controller" | ||||
| 	"one-api/relay/util" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/chat | ||||
|  | ||||
| type ResponseFormat struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	Messages         []Message       `json:"messages,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed             float64         `json:"seed,omitempty"` | ||||
| 	Tools            any             `json:"tools,omitempty"` | ||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
| 	if r.Input == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	var input []string | ||||
| 	switch r.Input.(type) { | ||||
| 	case string: | ||||
| 		input = []string{r.Input.(string)} | ||||
| 	case []any: | ||||
| 		input = make([]string, 0, len(r.Input.([]any))) | ||||
| 		for _, item := range r.Input.([]any) { | ||||
| 			if str, ok := item.(string); ok { | ||||
| 				input = append(input, str) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return input | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| } | ||||
|  | ||||
| type TextRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	Prompt    string    `json:"prompt"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create | ||||
| type ImageRequest struct { | ||||
| 	Model          string `json:"model"` | ||||
| 	Prompt         string `json:"prompt" binding:"required"` | ||||
| 	N              int    `json:"n,omitempty"` | ||||
| 	Size           string `json:"size,omitempty"` | ||||
| 	Quality        string `json:"quality,omitempty"` | ||||
| 	ResponseFormat string `json:"response_format,omitempty"` | ||||
| 	Style          string `json:"style,omitempty"` | ||||
| 	User           string `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| type WhisperJSONResponse struct { | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type WhisperVerboseJSONResponse struct { | ||||
| 	Task     string    `json:"task,omitempty"` | ||||
| 	Language string    `json:"language,omitempty"` | ||||
| 	Duration float64   `json:"duration,omitempty"` | ||||
| 	Text     string    `json:"text,omitempty"` | ||||
| 	Segments []Segment `json:"segments,omitempty"` | ||||
| } | ||||
|  | ||||
| type Segment struct { | ||||
| 	Id               int     `json:"id"` | ||||
| 	Seek             int     `json:"seek"` | ||||
| 	Start            float64 `json:"start"` | ||||
| 	End              float64 `json:"end"` | ||||
| 	Text             string  `json:"text"` | ||||
| 	Tokens           []int   `json:"tokens"` | ||||
| 	Temperature      float64 `json:"temperature"` | ||||
| 	AvgLogprob       float64 `json:"avg_logprob"` | ||||
| 	CompressionRatio float64 `json:"compression_ratio"` | ||||
| 	NoSpeechProb     float64 `json:"no_speech_prob"` | ||||
| } | ||||
|  | ||||
| type TextToSpeechRequest struct { | ||||
| 	Model          string  `json:"model" binding:"required"` | ||||
| 	Input          string  `json:"input" binding:"required"` | ||||
| 	Voice          string  `json:"voice" binding:"required"` | ||||
| 	Speed          float64 `json:"speed"` | ||||
| 	ResponseFormat string  `json:"response_format"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| 	TotalTokens      int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type OpenAIError struct { | ||||
| 	Message string `json:"message"` | ||||
| 	Type    string `json:"type"` | ||||
| 	Param   string `json:"param"` | ||||
| 	Code    any    `json:"code"` | ||||
| } | ||||
|  | ||||
| type OpenAIErrorWithStatusCode struct { | ||||
| 	OpenAIError | ||||
| 	StatusCode int `json:"status_code"` | ||||
| } | ||||
|  | ||||
| type TextResponse struct { | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| 	Error   OpenAIError `json:"error"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponseChoice struct { | ||||
| 	Index        int `json:"index"` | ||||
| 	Message      `json:"message"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponse struct { | ||||
| 	Id      string                     `json:"id"` | ||||
| 	Object  string                     `json:"object"` | ||||
| 	Created int64                      `json:"created"` | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponseItem struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Index     int       `json:"index"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponse struct { | ||||
| 	Object string                        `json:"object"` | ||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||
| 	Model  string                        `json:"model"` | ||||
| 	Usage  `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ImageResponse struct { | ||||
| 	Created int `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		Url string `json:"url"` | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponseChoice struct { | ||||
| 	Delta struct { | ||||
| 		Content string `json:"content"` | ||||
| 	} `json:"delta"` | ||||
| 	FinishReason *string `json:"finish_reason,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponse struct { | ||||
| 	Id      string                                `json:"id"` | ||||
| 	Object  string                                `json:"object"` | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
| 	Choices []struct { | ||||
| 		Text         string `json:"text"` | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 	} `json:"choices"` | ||||
| } | ||||
|  | ||||
| func Relay(c *gin.Context) { | ||||
| 	relayMode := RelayModeUnknown | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | ||||
| 		relayMode = RelayModeChatCompletions | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | ||||
| 		relayMode = RelayModeCompletions | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		relayMode = RelayModeModerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 		relayMode = RelayModeImagesGenerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||
| 		relayMode = RelayModeEdits | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { | ||||
| 		relayMode = RelayModeAudioSpeech | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { | ||||
| 		relayMode = RelayModeAudioTranscription | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 		relayMode = RelayModeAudioTranslation | ||||
| 	} | ||||
| 	var err *OpenAIErrorWithStatusCode | ||||
| 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) | ||||
| 	var err *openai.ErrorWithStatusCode | ||||
| 	switch relayMode { | ||||
| 	case RelayModeImagesGenerations: | ||||
| 		err = relayImageHelper(c, relayMode) | ||||
| 	case RelayModeAudioSpeech: | ||||
| 	case constant.RelayModeImagesGenerations: | ||||
| 		err = controller.RelayImageHelper(c, relayMode) | ||||
| 	case constant.RelayModeAudioSpeech: | ||||
| 		fallthrough | ||||
| 	case RelayModeAudioTranslation: | ||||
| 	case constant.RelayModeAudioTranslation: | ||||
| 		fallthrough | ||||
| 	case RelayModeAudioTranscription: | ||||
| 		err = relayAudioHelper(c, relayMode) | ||||
| 	case constant.RelayModeAudioTranscription: | ||||
| 		err = controller.RelayAudioHelper(c, relayMode) | ||||
| 	default: | ||||
| 		err = relayTextHelper(c, relayMode) | ||||
| 		err = controller.RelayTextHelper(c, relayMode) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		requestId := c.GetString(common.RequestIdKey) | ||||
| 		requestId := c.GetString(logger.RequestIdKey) | ||||
| 		retryTimesStr := c.Query("retry") | ||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||
| 		if retryTimesStr == "" { | ||||
| 			retryTimes = common.RetryTimes | ||||
| 			retryTimes = config.RetryTimes | ||||
| 		} | ||||
| 		if retryTimes > 0 { | ||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||
| 		} else { | ||||
| 			if err.StatusCode == http.StatusTooManyRequests { | ||||
| 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 				err.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 			} | ||||
| 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||
| 			err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) | ||||
| 			c.JSON(err.StatusCode, gin.H{ | ||||
| 				"error": err.OpenAIError, | ||||
| 				"error": err.Error, | ||||
| 			}) | ||||
| 		} | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 		logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||
| 		if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			channelName := c.GetString("channel_name") | ||||
| 			disableChannel(channelId, channelName, err.Message) | ||||
| @@ -320,7 +61,7 @@ func Relay(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func RelayNotImplemented(c *gin.Context) { | ||||
| 	err := OpenAIError{ | ||||
| 	err := openai.Error{ | ||||
| 		Message: "API not implemented", | ||||
| 		Type:    "one_api_error", | ||||
| 		Param:   "", | ||||
| @@ -332,7 +73,7 @@ func RelayNotImplemented(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func RelayNotFound(c *gin.Context) { | ||||
| 	err := OpenAIError{ | ||||
| 	err := openai.Error{ | ||||
| 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | ||||
| 		Type:    "invalid_request_error", | ||||
| 		Param:   "", | ||||
|   | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| ) | ||||
| @@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -119,9 +121,9 @@ func AddToken(c *gin.Context) { | ||||
| 	cleanToken := model.Token{ | ||||
| 		UserId:         c.GetInt("id"), | ||||
| 		Name:           token.Name, | ||||
| 		Key:            common.GenerateKey(), | ||||
| 		CreatedTime:    common.GetTimestamp(), | ||||
| 		AccessedTime:   common.GetTimestamp(), | ||||
| 		Key:            helper.GenerateKey(), | ||||
| 		CreatedTime:    helper.GetTimestamp(), | ||||
| 		AccessedTime:   helper.GetTimestamp(), | ||||
| 		ExpiredTime:    token.ExpiredTime, | ||||
| 		RemainQuota:    token.RemainQuota, | ||||
| 		UnlimitedQuota: token.UnlimitedQuota, | ||||
| @@ -187,7 +189,7 @@ func UpdateToken(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if token.Status == common.TokenStatusEnabled { | ||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | ||||
|   | ||||
| @@ -5,8 +5,11 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -18,7 +21,7 @@ type LoginRequest struct { | ||||
| } | ||||
|  | ||||
| func Login(c *gin.Context) { | ||||
| 	if !common.PasswordLoginEnabled { | ||||
| 	if !config.PasswordLoginEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员关闭了密码登录", | ||||
| 			"success": false, | ||||
| @@ -105,14 +108,14 @@ func Logout(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func Register(c *gin.Context) { | ||||
| 	if !common.RegisterEnabled { | ||||
| 	if !config.RegisterEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员关闭了新用户注册", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if !common.PasswordRegisterEnabled { | ||||
| 	if !config.PasswordRegisterEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", | ||||
| 			"success": false, | ||||
| @@ -135,7 +138,7 @@ func Register(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if common.EmailVerificationEnabled { | ||||
| 	if config.EmailVerificationEnabled { | ||||
| 		if user.Email == "" || user.VerificationCode == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| @@ -159,7 +162,7 @@ func Register(c *gin.Context) { | ||||
| 		DisplayName: user.Username, | ||||
| 		InviterId:   inviterId, | ||||
| 	} | ||||
| 	if common.EmailVerificationEnabled { | ||||
| 	if config.EmailVerificationEnabled { | ||||
| 		cleanUser.Email = user.Email | ||||
| 	} | ||||
| 	if err := cleanUser.Insert(inviterId); err != nil { | ||||
| @@ -181,7 +184,7 @@ func GetAllUsers(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -248,6 +251,29 @@ func GetUser(c *gin.Context) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetUserDashboard(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	now := time.Now() | ||||
| 	startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() | ||||
| 	endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() | ||||
|  | ||||
| 	dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay)) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "无法获取统计信息", | ||||
| 			"data":    nil, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    dashboards, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GenerateAccessToken(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	user, err := model.GetUserById(id, true) | ||||
| @@ -258,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.AccessToken = common.GetUUID() | ||||
| 	user.AccessToken = helper.GetUUID() | ||||
|  | ||||
| 	if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -295,7 +321,7 @@ func GetAffCode(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if user.AffCode == "" { | ||||
| 		user.AffCode = common.GetRandomString(4) | ||||
| 		user.AffCode = helper.GetRandomString(4) | ||||
| 		if err := user.Update(false); err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| @@ -702,7 +728,7 @@ func EmailBind(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if user.Role == common.RoleRootUser { | ||||
| 		common.RootUserEmail = email | ||||
| 		config.RootUserEmail = email | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| @@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) { | ||||
| 	if code == "" { | ||||
| 		return "", errors.New("无效的参数") | ||||
| 	} | ||||
| 	req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) | ||||
| 	req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", common.WeChatServerToken) | ||||
| 	req.Header.Set("Authorization", config.WeChatServerToken) | ||||
| 	client := http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| @@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) { | ||||
| } | ||||
|  | ||||
| func WeChatAuth(c *gin.Context) { | ||||
| 	if !common.WeChatAuthEnabled { | ||||
| 	if !config.WeChatAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员未开启通过微信登录以及注册", | ||||
| 			"success": false, | ||||
| @@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if common.RegisterEnabled { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			user.DisplayName = "WeChat User" | ||||
| 			user.Role = common.RoleCommonUser | ||||
| @@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func WeChatBind(c *gin.Context) { | ||||
| 	if !common.WeChatAuthEnabled { | ||||
| 	if !config.WeChatAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员未开启通过微信登录以及注册", | ||||
| 			"success": false, | ||||
|   | ||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							| @@ -16,7 +16,7 @@ require ( | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||
| 	github.com/stretchr/testify v1.8.3 | ||||
| 	golang.org/x/crypto v0.14.0 | ||||
| 	golang.org/x/crypto v0.17.0 | ||||
| 	golang.org/x/image v0.14.0 | ||||
| 	gorm.io/driver/mysql v1.4.3 | ||||
| 	gorm.io/driver/postgres v1.5.2 | ||||
| @@ -58,7 +58,7 @@ require ( | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/net v0.17.0 // indirect | ||||
| 	golang.org/x/sys v0.13.0 // indirect | ||||
| 	golang.org/x/sys v0.15.0 // indirect | ||||
| 	golang.org/x/text v0.14.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
|   | ||||
							
								
								
									
										8
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								go.sum
									
									
									
									
									
								
							| @@ -150,8 +150,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu | ||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= | ||||
| golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= | ||||
| golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= | ||||
| golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= | ||||
| golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= | ||||
| golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| @@ -164,8 +164,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc | ||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= | ||||
| golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= | ||||
| golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
|   | ||||
							
								
								
									
										249
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										249
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -86,6 +86,7 @@ | ||||
|   "该令牌已过期": "The token has expired", | ||||
|   "该令牌额度已用尽": "The token quota has been used up", | ||||
|   "无效的令牌": "Invalid token", | ||||
|   "令牌验证失败": "Token verification failed", | ||||
|   "id 或 userId 为空!": "id or userId is empty!", | ||||
|   "quota 不能为负数!": "quota cannot be negative!", | ||||
|   "令牌额度不足": "Insufficient token quota", | ||||
| @@ -458,6 +459,7 @@ | ||||
|   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", | ||||
|   "用户名称": "User Name", | ||||
|   "令牌名称": "Token Name", | ||||
|   "默认令牌": "Default Token", | ||||
|   "留空则查询全部用户": "Leave blank to query all users", | ||||
|   "留空则查询全部令牌": "Leave blank to query all tokens", | ||||
|   "模型名称": "Model Name", | ||||
| @@ -526,5 +528,250 @@ | ||||
|   "模型版本": "Model version", | ||||
|   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", | ||||
|   "点击查看": "click to view", | ||||
|   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" | ||||
|   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!", | ||||
|   "处理中...": "Processing...", | ||||
|   "绑定成功!": "Binding successful!", | ||||
|   "登录成功!": "Login successful!", | ||||
|   "操作失败,重定向至登录界面中...": "Operation failed, redirecting to login screen...", | ||||
|   "出现错误,第 ${count} 次重试中...": "An error occurred, retrying ${count}...", | ||||
|   "首页": "Home", | ||||
|   "渠道": "Channel", | ||||
|   "令牌": "API Keys", | ||||
|   "兑换": "Redeem", | ||||
|   "充值": "Recharge", | ||||
|   "用户": "Users", | ||||
|   "日志": "Logs", | ||||
|   "设置": "Settings", | ||||
|   "关于": "About", | ||||
|   "聊天": "Chat", | ||||
|   "注销成功!": "Logout successful!", | ||||
|   "注销": "Log out", | ||||
|   "登录": "Log in", | ||||
|   "注册": "Sign up", | ||||
|   "加载{name}中...": "Loading {name}...", | ||||
|   "未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!", | ||||
|   "请立刻修改默认密码!": "Please change the default password immediately!", | ||||
|   "欢迎回来": "Welcome back", | ||||
|   "没有账户?": "No account?", | ||||
|   "立刻注册": "Sign up now", | ||||
|   "用户名": "Username", | ||||
|   "密码": "Password", | ||||
|   "正在登录……": "Logging in...", | ||||
|   "忘记密码": "Forgot password", | ||||
|   "其他方式": "Other methods", | ||||
|   "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code with WeChat, follow the official account and enter 'verification code' to get the verification code (valid within three minutes)", | ||||
|   "验证码": "Verification code", | ||||
|   "全部用户": "All users", | ||||
|   "当前用户": "Current user", | ||||
|   "全部": "All", | ||||
|   "消费": "Consumption", | ||||
|   "管理": "Management", | ||||
|   "系统": "System", | ||||
|   "未知": "Unknown", | ||||
|   "其他模型": "Other models", | ||||
|   "复制成功": "Copy successful", | ||||
|   "使用明细": "Usages", | ||||
|   "刷新": "Refresh", | ||||
|   "收起面板": "Collapse panel", | ||||
|   "展开面板": "Expand panel", | ||||
|   "显示查询选项": "Show search options", | ||||
|   "隐藏查询选项": "Hide search options", | ||||
|   "用户名称": "User name", | ||||
|   "可选值": "Optional values", | ||||
|   "渠道 ID": "Channel ID", | ||||
|   "令牌名称": "Key name", | ||||
|   "模型名称": "Model name", | ||||
|   "起始时间": "Start time", | ||||
|   "结束时间": "End time", | ||||
|   "查询": "Query", | ||||
|   "隐藏条形图": "Hide bar chart", | ||||
|   "显示条形图": "Show bar chart", | ||||
|   "折线条形图只展示最新50条数据": "Line and bar charts only show the latest 50 pieces of data", | ||||
|   "总消耗": "Total consumption", | ||||
|   "总共调用了 {payload[0].value} 次": "A total of {payload[0].value} calls were made", | ||||
|   "{model.name}: {model.value} 次": "{model.name}: {model.value} times", | ||||
|   "总共调用了 {payload[0].value} 次 {payload[0].name}": "A total of {payload[0].value} {payload[0].name} calls were made", | ||||
|   "总消耗额度": "Total consumption limit", | ||||
|   "暂无数据": "No data available", | ||||
|   "更多数据统计图形即将到来,敬请期待!": "More data statistics graphics are coming soon, stay tuned!", | ||||
|   "复制用户名": "Copy username", | ||||
|   "{`共 ${counts} 条数据`}": "{`A total of ${counts} pieces of data`}", | ||||
|   "共 0 条数据": "A total of 0 pieces of data", | ||||
|   "选择明细分类": "Select detail category", | ||||
|   "模型倍率": "model rate", | ||||
|   "分组倍率": "group rate", | ||||
|   "新密码已复制到剪贴板:": "New password has been copied to the clipboard:", | ||||
|   "密码重置确认": "Password reset confirmation", | ||||
|   "邮箱地址": "Email address", | ||||
|   "新密码": "New password", | ||||
|   "密码已复制到剪贴板:": "Password has been copied to the clipboard:", | ||||
|   "密码重置完成": "Password reset complete", | ||||
|   "提交": "Submit", | ||||
|   "返回登录": "Return to login", | ||||
|   "请稍后重试,浏览器环境检查未通过": "Please try again later, browser environment check failed", | ||||
|   "重置邮件发送成功,请检查邮箱!": "Reset email sent successfully, please check your email!", | ||||
|   "密码重置": "Password reset", | ||||
|   "重试": "Retry", | ||||
|   "组": "Group", | ||||
|   "令牌已重置并已复制到剪贴板": "Token has been reset and copied to the clipboard", | ||||
|   "邀请链接已复制到剪切板": "Invitation link has been copied to the clipboard", | ||||
|   "系统令牌已复制到剪切板": "System token has been copied to the clipboard", | ||||
|   "请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!", | ||||
|   "账户已删除!": "Account has been deleted!", | ||||
|   "微信账户绑定成功!": "WeChat account binding successful!", | ||||
|   "请稍后几秒重试,Turnstile 正在检查用户环境!": "Please try again in a few seconds, Turnstile is checking the user environment!", | ||||
|   "验证码发送成功,请检查邮箱!": "Verification code sent successfully, please check your email!", | ||||
|   "邮箱账户绑定成功!": "Email account binding successful!", | ||||
|   "个人信息": "Personal information", | ||||
|   "编辑个人信息": "Edit personal information", | ||||
|   "生成系统访问令牌": "Generate system access token", | ||||
|   "复制邀请链接": "Copy invitation link", | ||||
|   "删除个人帐户": "Delete personal account", | ||||
|   "普通用户": "Regular user", | ||||
|   "管理员": "Administrator", | ||||
|   "超级管理员": "Super administrator", | ||||
|   "显示名称": "Display name", | ||||
|   "GitHub 账号": "GitHub account", | ||||
|   "微信账号": "WeChat account", | ||||
|   "修改个人信息只允许在电脑端进行。生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。": "Modifying personal information is only allowed on a computer. The generated token is for system management, not for requesting OpenAI related services. Please be aware.", | ||||
|   "可用模型": "Available models", | ||||
|   "账号绑定": "Account binding", | ||||
|   "绑定微信": "Bind WeChat", | ||||
|   "绑定 GitHub": "Bind GitHub", | ||||
|   "绑定邮箱": "Bind Email", | ||||
|   "绑定": "Bind", | ||||
|   "绑定邮箱地址": "Bind email address", | ||||
|   "输入邮箱地址": "Enter email address", | ||||
|   "重新发送": "Resend", | ||||
|   "获取验证码": "Get verification code", | ||||
|   "确认绑定": "Confirm binding", | ||||
|   "取消": "Cancel", | ||||
|   "危险操作": "Dangerous operation", | ||||
|   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your own account, all data will be cleared and cannot be recovered", | ||||
|   "输入你的账户名": "Enter your account name", | ||||
|   "以确认删除": "To confirm deletion", | ||||
|   "确认删除": "Confirm deletion", | ||||
|   "未使用": "Not used", | ||||
|   "已禁用": "Disabled", | ||||
|   "已使用": "Used", | ||||
|   "未知状态": "Unknown status", | ||||
|   "操作成功完成!": "Operation successfully completed!", | ||||
|   "搜索兑换码的 ID 和名称 ...": "Search for the ID and name of the redemption code ...", | ||||
|   "名称": "Name", | ||||
|   "状态": "Status", | ||||
|   "额度": "Quota", | ||||
|   "创建时间": "Creation time", | ||||
|   "兑换时间": "Redemption time", | ||||
|   "操作": "Operation", | ||||
|   "尚未兑换": "Not yet redeemed", | ||||
|   "已复制到剪贴板!": "Copied to clipboard!", | ||||
|   "无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually. The redemption code has been filled in the search box.", | ||||
|   "复制": "Copy", | ||||
|   "删除": "Delete", | ||||
|   "禁用": "Disable", | ||||
|   "启用": "Enable", | ||||
|   "编辑": "Edit", | ||||
|   "添加新的兑换码": "Add new redemption code", | ||||
|   "密码长度不得小于 8 位!": "Password length must not be less than 8 characters!", | ||||
|   "两次输入的密码不一致": "The two passwords entered do not match", | ||||
|   "注册成功!": "Registration successful!", | ||||
|   "请填写注册邮箱!": "Please fill in the registration email!", | ||||
|   "请在${verificationTimeout}秒后再试": "Please try again after ${verificationTimeout} seconds", | ||||
|   "验证码发送成功,请检查你的邮箱!": "Verification code sent successfully, please check your email!", | ||||
|   "已有账户?": "Already have an account?", | ||||
|   "请输入用户名(最长 12 位)": "Please enter a username (up to 12 characters)", | ||||
|   "请输入密码(最短 8 位,最长 20 位)": "Please enter a password (minimum 8 characters, maximum 20 characters)", | ||||
|   "请再次输入密码": "Please enter the password again", | ||||
|   "请输入邮箱地址": "Please enter an email address", | ||||
|   "秒后可重发": "Can be resent after seconds", | ||||
|   "请输入邮箱验证码": "Please enter the email verification code", | ||||
|   "已过期": "Expired", | ||||
|   "已启用": "Enabled", | ||||
|   "已耗尽": "Exhausted", | ||||
|   "无": "None", | ||||
|   "令牌密钥": "API Key", | ||||
|   "令牌状态": "Key status", | ||||
|   "已用额度": "Used quota", | ||||
|   "剩余额度": "Remaining quota", | ||||
|   "过期时间": "Expiration time", | ||||
|   "你确定要删除这个令牌吗?": "Are you sure you want to delete this key?", | ||||
|   "无法复制到剪贴板,请手动复制,已将令牌密钥填入搜索框": "Unable to copy to clipboard, please copy manually. The key key has been filled in the search box.", | ||||
|   "无限制": "Unlimited", | ||||
|   "永不过期": "Never expires", | ||||
|   "使用 API 访问令牌进行服务鉴权和计费。": "Use API Key for service authentication and billing.", | ||||
|   "API 访问令牌关系到您的个人利益,请妥善留存,不要与其他人共享,也不要保存在客户端代码中。": "API Key is related to your personal interests. Please keep it properly. Do not share it with others or save it in client code.", | ||||
|   "创建令牌": "Create Key", | ||||
|   "什么都还没有,快去创建一个令牌开始使用吧!": "Nothing yet, go create a key to start using!", | ||||
|   "你确定要删除该令牌吗": "Are you sure you want to delete this key", | ||||
|   "导出令牌信息": "Export key information", | ||||
|   "错误:未登录或登录已过期,请重新登录!": "Error: Not logged in or login has expired, please log in again!", | ||||
|   "错误:请求次数过多,请稍后再试!": "Error: Too many requests, please try again later!", | ||||
|   "错误:服务器内部错误,请联系管理员!": "Error: Server internal error, please contact the online customer service!", | ||||
|   "本站仅作演示之用,无服务端!": "This site is for demonstration purposes only, no server!", | ||||
|   "错误:": "Error:", | ||||
|   "加载首页内容失败...": "Failed to load homepage content...", | ||||
|   "系统状况": "System status", | ||||
|   "系统信息": "System information", | ||||
|   "系统信息总览": "System information overview", | ||||
|   "名称:": "Name:", | ||||
|   "版本:": "Version:", | ||||
|   "源码:": "Source code:", | ||||
|   "启动时间:": "Startup time:", | ||||
|   "系统配置": "System configuration", | ||||
|   "系统配置总览": "System configuration overview", | ||||
|   "邮箱验证:": "Email verification:", | ||||
|   "未启用": "Not enabled", | ||||
|   "Turnstile 用户校验:": "Turnstile user verification:", | ||||
|   "页面不存在": "Page does not exist", | ||||
|   "请检查你的浏览器地址是否正确": "Please check if your browser address is correct", | ||||
|   "个人设置": "Personal settings", | ||||
|   "运营设置": "Operations settings", | ||||
|   "系统设置": "System settings", | ||||
|   "其他设置": "Other settings", | ||||
|   "默认令牌": "Default key", | ||||
|   "过期时间必须在当前时间之后!": "Expiration time must be after the current time!", | ||||
|   "额度必须大于等于 0!": "Quota must be greater than or equal to 0!", | ||||
|   "过期时间格式错误!": "Expiration time format error!", | ||||
|   "创建令牌数量必须大于等于 1!": "The number of keys to create must be greater than or equal to 1!", | ||||
|   "令牌修改成功": "API Key modification successful", | ||||
|   "令牌创建成功": "API Key creation successful", | ||||
|   "更新令牌信息": "Update key information", | ||||
|   "创建新的令牌": "Create a new key", | ||||
|   "请输入名称": "Please enter a name", | ||||
|   "请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means unlimited", | ||||
|   "无限额度": "Unlimited quota", | ||||
|   "注意:启用无限额度后,已用额度将不再进行计算。": "Note: After enabling unlimited quota, the used quota will no longer be calculated.", | ||||
|   "等于": "Equals", | ||||
|   "请输入额度(单位:token)": "Please enter the quota (unit: token)", | ||||
|   "创建令牌数量": "Create key quantity", | ||||
|   "请输入令牌数量": "Please enter the number of keys", | ||||
|   "注意:令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note: The quota of the key is only used to limit the maximum quota usage of the key itself, and the actual usage is subject to the remaining quota of the account.", | ||||
|   "我的令牌": "My keys", | ||||
|   "请输入额度兑换码!": "Please enter the redeem code!", | ||||
|   "充值成功!": "Recharge successful!", | ||||
|   "请求失败": "Request failed", | ||||
|   "超级管理员未设置充值链接!": "The super administrator did not set a recharge link!", | ||||
|   "充值额度": "Recharge quota", | ||||
|   "兑换中...": "Redeeming...", | ||||
|   "请点击充值以获取额度兑换码。": "Please click recharge to get the quota redemption code.", | ||||
|   "用户信息更新成功!": "User information updated successfully!", | ||||
|   "更新用户信息": "Update user information", | ||||
|   "请输入新的用户名": "Please enter a new username", | ||||
|   "请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characters", | ||||
|   "请输入新的显示名称": "Please enter a new display name", | ||||
|   "分组": "Group", | ||||
|   "请选择分组": "Please select a group", | ||||
|   "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group rate on the system settings page to add a new group:", | ||||
|   "请输入新的剩余额度": "Please enter a new remaining quota", | ||||
|   "已绑定的 GitHub 账户": "Bound GitHub account", | ||||
|   "此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only, users need to bind through the relevant binding button on the personal settings page, cannot be directly modified", | ||||
|   "已绑定的微信账户": "Bound WeChat account", | ||||
|   "已绑定的邮箱账户": "Bound email account", | ||||
|   "新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using the shortcut key Shift + F5", | ||||
|   "无法正常连接至服务器!": "Unable to connect to the server normally!", | ||||
|   "提示:": "Input:", | ||||
|   "补全:": "Output:", | ||||
|   "搜索令牌名称": "Search key name", | ||||
|   "测试所有渠道": "Test all channels", | ||||
|   "更新已启用渠道余额": "Update the balance of enabled channels" | ||||
| } | ||||
|   | ||||
							
								
								
									
										53
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								main.go
									
									
									
									
									
								
							| @@ -7,82 +7,83 @@ import ( | ||||
| 	"github.com/gin-contrib/sessions/cookie" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/controller" | ||||
| 	"one-api/middleware" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/router" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| //go:embed web/build | ||||
| //go:embed web/build/* | ||||
| var buildFS embed.FS | ||||
|  | ||||
| //go:embed web/build/index.html | ||||
| var indexPage []byte | ||||
|  | ||||
| func main() { | ||||
| 	common.SetupLogger() | ||||
| 	common.SysLog("One API " + common.Version + " started") | ||||
| 	logger.SetupLogger() | ||||
| 	logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) | ||||
| 	if os.Getenv("GIN_MODE") != "debug" { | ||||
| 		gin.SetMode(gin.ReleaseMode) | ||||
| 	} | ||||
| 	if common.DebugEnabled { | ||||
| 		common.SysLog("running in debug mode") | ||||
| 	if config.DebugEnabled { | ||||
| 		logger.SysLog("running in debug mode") | ||||
| 	} | ||||
| 	// Initialize SQL Database | ||||
| 	err := model.InitDB() | ||||
| 	if err != nil { | ||||
| 		common.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		err := model.CloseDB() | ||||
| 		if err != nil { | ||||
| 			common.FatalLog("failed to close database: " + err.Error()) | ||||
| 			logger.FatalLog("failed to close database: " + err.Error()) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// Initialize Redis | ||||
| 	err = common.InitRedisClient() | ||||
| 	if err != nil { | ||||
| 		common.FatalLog("failed to initialize Redis: " + err.Error()) | ||||
| 		logger.FatalLog("failed to initialize Redis: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	// Initialize options | ||||
| 	model.InitOptionMap() | ||||
| 	logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) | ||||
| 	if common.RedisEnabled { | ||||
| 		// for compatibility with old versions | ||||
| 		common.MemoryCacheEnabled = true | ||||
| 		config.MemoryCacheEnabled = true | ||||
| 	} | ||||
| 	if common.MemoryCacheEnabled { | ||||
| 		common.SysLog("memory cache enabled") | ||||
| 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | ||||
| 	if config.MemoryCacheEnabled { | ||||
| 		logger.SysLog("memory cache enabled") | ||||
| 		logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) | ||||
| 		model.InitChannelCache() | ||||
| 	} | ||||
| 	if common.MemoryCacheEnabled { | ||||
| 		go model.SyncOptions(common.SyncFrequency) | ||||
| 		go model.SyncChannelCache(common.SyncFrequency) | ||||
| 	if config.MemoryCacheEnabled { | ||||
| 		go model.SyncOptions(config.SyncFrequency) | ||||
| 		go model.SyncChannelCache(config.SyncFrequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| 			common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) | ||||
| 			logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		go controller.AutomaticallyUpdateChannels(frequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| 			common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) | ||||
| 			logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		go controller.AutomaticallyTestChannels(frequency) | ||||
| 	} | ||||
| 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | ||||
| 		common.BatchUpdateEnabled = true | ||||
| 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||
| 		config.BatchUpdateEnabled = true | ||||
| 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | ||||
| 		model.InitBatchUpdater() | ||||
| 	} | ||||
| 	controller.InitTokenEncoders() | ||||
| 	openai.InitTokenEncoders() | ||||
|  | ||||
| 	// Initialize HTTP server | ||||
| 	server := gin.New() | ||||
| @@ -92,16 +93,16 @@ func main() { | ||||
| 	server.Use(middleware.RequestId()) | ||||
| 	middleware.SetUpLogger(server) | ||||
| 	// Initialize session store | ||||
| 	store := cookie.NewStore([]byte(common.SessionSecret)) | ||||
| 	store := cookie.NewStore([]byte(config.SessionSecret)) | ||||
| 	server.Use(sessions.Sessions("session", store)) | ||||
|  | ||||
| 	router.SetRouter(server, buildFS, indexPage) | ||||
| 	router.SetRouter(server, buildFS) | ||||
| 	var port = os.Getenv("PORT") | ||||
| 	if port == "" { | ||||
| 		port = strconv.Itoa(*common.Port) | ||||
| 	} | ||||
| 	err = server.Run(":" + port) | ||||
| 	if err != nil { | ||||
| 		common.FatalLog("failed to start HTTP server: " + err.Error()) | ||||
| 		logger.FatalLog("failed to start HTTP server: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| @@ -69,7 +70,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				if channel != nil { | ||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
| 				} | ||||
| 				abortWithMessage(c, http.StatusServiceUnavailable, message) | ||||
|   | ||||
| @@ -3,14 +3,14 @@ package middleware | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| func SetUpLogger(server *gin.Engine) { | ||||
| 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||
| 		var requestID string | ||||
| 		if param.Keys != nil { | ||||
| 			requestID = param.Keys[common.RequestIdKey].(string) | ||||
| 			requestID = param.Keys[logger.RequestIdKey].(string) | ||||
| 		} | ||||
| 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||
| 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| @@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st | ||||
| 	} | ||||
| 	if listLength < int64(maxRequestNum) { | ||||
| 		rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | ||||
| 		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | ||||
| 		rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | ||||
| 	} else { | ||||
| 		oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | ||||
| 		oldTime, err := time.Parse(timeFormat, oldTimeStr) | ||||
| @@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st | ||||
| 		// time.Since will return negative number! | ||||
| 		// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows | ||||
| 		if int64(nowTime.Sub(oldTime).Seconds()) < duration { | ||||
| 			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | ||||
| 			rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | ||||
| 			c.Status(http.StatusTooManyRequests) | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 		} else { | ||||
| 			rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | ||||
| 			rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) | ||||
| 			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | ||||
| 			rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi | ||||
| 		} | ||||
| 	} else { | ||||
| 		// It's safe to call multi times. | ||||
| 		inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) | ||||
| 		inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) | ||||
| 		return func(c *gin.Context) { | ||||
| 			memoryRateLimiter(c, maxRequestNum, duration, mark) | ||||
| 		} | ||||
| @@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi | ||||
| } | ||||
|  | ||||
| func GlobalWebRateLimit() func(c *gin.Context) { | ||||
| 	return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") | ||||
| 	return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") | ||||
| } | ||||
|  | ||||
| func GlobalAPIRateLimit() func(c *gin.Context) { | ||||
| 	return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") | ||||
| 	return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") | ||||
| } | ||||
|  | ||||
| func CriticalRateLimit() func(c *gin.Context) { | ||||
| 	return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") | ||||
| 	return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") | ||||
| } | ||||
|  | ||||
| func DownloadRateLimit() func(c *gin.Context) { | ||||
| 	return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") | ||||
| 	return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") | ||||
| } | ||||
|  | ||||
| func UploadRateLimit() func(c *gin.Context) { | ||||
| 	return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") | ||||
| 	return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") | ||||
| } | ||||
|   | ||||
| @@ -4,14 +4,16 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/logger" | ||||
| 	"runtime/debug" | ||||
| ) | ||||
|  | ||||
| func RelayPanicRecover() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		defer func() { | ||||
| 			if err := recover(); err != nil { | ||||
| 				common.SysError(fmt.Sprintf("panic detected: %v", err)) | ||||
| 				logger.SysError(fmt.Sprintf("panic detected: %v", err)) | ||||
| 				logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||
| 					"error": gin.H{ | ||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | ||||
|   | ||||
| @@ -3,16 +3,17 @@ package middleware | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| func RequestId() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		id := common.GetTimeString() + common.GetRandomString(8) | ||||
| 		c.Set(common.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) | ||||
| 		id := helper.GetTimeString() + helper.GetRandomString(8) | ||||
| 		c.Set(logger.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
| 		c.Header(common.RequestIdKey, id) | ||||
| 		c.Header(logger.RequestIdKey, id) | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -6,7 +6,8 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| type turnstileCheckResponse struct { | ||||
| @@ -15,7 +16,7 @@ type turnstileCheckResponse struct { | ||||
|  | ||||
| func TurnstileCheck() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		if common.TurnstileCheckEnabled { | ||||
| 		if config.TurnstileCheckEnabled { | ||||
| 			session := sessions.Default(c) | ||||
| 			turnstileChecked := session.Get("turnstile") | ||||
| 			if turnstileChecked != nil { | ||||
| @@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc { | ||||
| 				return | ||||
| 			} | ||||
| 			rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ | ||||
| 				"secret":   {common.TurnstileSecretKey}, | ||||
| 				"secret":   {config.TurnstileSecretKey}, | ||||
| 				"response": {response}, | ||||
| 				"remoteip": {c.ClientIP()}, | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				common.SysError(err.Error()) | ||||
| 				logger.SysError(err.Error()) | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { | ||||
| 			var res turnstileCheckResponse | ||||
| 			err = json.NewDecoder(rawRes.Body).Decode(&res) | ||||
| 			if err != nil { | ||||
| 				common.SysError(err.Error()) | ||||
| 				logger.SysError(err.Error()) | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
| @@ -2,16 +2,17 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| 	c.JSON(statusCode, gin.H{ | ||||
| 		"error": gin.H{ | ||||
| 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), | ||||
| 			"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), | ||||
| 			"type":    "one_api_error", | ||||
| 		}, | ||||
| 	}) | ||||
| 	c.Abort() | ||||
| 	common.LogError(c.Request.Context(), message) | ||||
| 	logger.Error(c.Request.Context(), message) | ||||
| } | ||||
|   | ||||
| @@ -6,6 +6,8 @@ import ( | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| @@ -14,10 +16,10 @@ import ( | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	TokenCacheSeconds         = common.SyncFrequency | ||||
| 	UserId2GroupCacheSeconds  = common.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = common.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = common.SyncFrequency | ||||
| 	TokenCacheSeconds         = config.SyncFrequency | ||||
| 	UserId2GroupCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = config.SyncFrequency | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| @@ -42,7 +44,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set token error: " + err.Error()) | ||||
| 			logger.SysError("Redis set token error: " + err.Error()) | ||||
| 		} | ||||
| 		return &token, nil | ||||
| 	} | ||||
| @@ -62,7 +64,7 @@ func CacheGetUserGroup(id int) (group string, err error) { | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user group error: " + err.Error()) | ||||
| 			logger.SysError("Redis set user group error: " + err.Error()) | ||||
| 		} | ||||
| 	} | ||||
| 	return group, err | ||||
| @@ -80,7 +82,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user quota error: " + err.Error()) | ||||
| 			logger.SysError("Redis set user quota error: " + err.Error()) | ||||
| 		} | ||||
| 		return quota, err | ||||
| 	} | ||||
| @@ -127,7 +129,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||
| 	if err != nil { | ||||
| 		common.SysError("Redis set user enabled error: " + err.Error()) | ||||
| 		logger.SysError("Redis set user enabled error: " + err.Error()) | ||||
| 	} | ||||
| 	return userEnabled, err | ||||
| } | ||||
| @@ -178,19 +180,19 @@ func InitChannelCache() { | ||||
| 	channelSyncLock.Lock() | ||||
| 	group2model2channels = newGroup2model2channels | ||||
| 	channelSyncLock.Unlock() | ||||
| 	common.SysLog("channels synced from database") | ||||
| 	logger.SysLog("channels synced from database") | ||||
| } | ||||
|  | ||||
| func SyncChannelCache(frequency int) { | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Second) | ||||
| 		common.SysLog("syncing channels from database") | ||||
| 		logger.SysLog("syncing channels from database") | ||||
| 		InitChannelCache() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	if !common.MemoryCacheEnabled { | ||||
| 	if !config.MemoryCacheEnabled { | ||||
| 		return GetRandomSatisfiedChannel(group, model) | ||||
| 	} | ||||
| 	channelSyncLock.RLock() | ||||
|   | ||||
| @@ -1,8 +1,13 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| type Channel struct { | ||||
| @@ -42,7 +47,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| @@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string { | ||||
| 	return *channel.BaseURL | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetModelMapping() string { | ||||
| 	if channel.ModelMapping == nil { | ||||
| 		return "" | ||||
| func (channel *Channel) GetModelMapping() map[string]string { | ||||
| 	if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return *channel.ModelMapping | ||||
| 	modelMapping := make(map[string]string) | ||||
| 	err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) | ||||
| 		return nil | ||||
| 	} | ||||
| 	return modelMapping | ||||
| } | ||||
|  | ||||
| func (channel *Channel) Insert() error { | ||||
| @@ -116,21 +127,21 @@ func (channel *Channel) Update() error { | ||||
|  | ||||
| func (channel *Channel) UpdateResponseTime(responseTime int64) { | ||||
| 	err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ | ||||
| 		TestTime:     common.GetTimestamp(), | ||||
| 		TestTime:     helper.GetTimestamp(), | ||||
| 		ResponseTime: int(responseTime), | ||||
| 	}).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update response time: " + err.Error()) | ||||
| 		logger.SysError("failed to update response time: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (channel *Channel) UpdateBalance(balance float64) { | ||||
| 	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ | ||||
| 		BalanceUpdatedTime: common.GetTimestamp(), | ||||
| 		BalanceUpdatedTime: helper.GetTimestamp(), | ||||
| 		Balance:            balance, | ||||
| 	}).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update balance: " + err.Error()) | ||||
| 		logger.SysError("failed to update balance: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -147,16 +158,16 @@ func (channel *Channel) Delete() error { | ||||
| func UpdateChannelStatusById(id int, status int) { | ||||
| 	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update ability status: " + err.Error()) | ||||
| 		logger.SysError("failed to update ability status: " + err.Error()) | ||||
| 	} | ||||
| 	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update channel status: " + err.Error()) | ||||
| 		logger.SysError("failed to update channel status: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func UpdateChannelUsedQuota(id int, quota int) { | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||
| 		return | ||||
| 	} | ||||
| @@ -166,7 +177,7 @@ func UpdateChannelUsedQuota(id int, quota int) { | ||||
| func updateChannelUsedQuota(id int, quota int) { | ||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||
| 		logger.SysError("failed to update channel used quota: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										65
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										65
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,14 +3,18 @@ package model | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type Log struct { | ||||
| 	Id               int    `json:"id;index:idx_created_at_id,priority:1"` | ||||
| 	Id               int    `json:"id"` | ||||
| 	UserId           int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | ||||
| 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||
| 	Content          string `json:"content"` | ||||
| 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||
| @@ -31,31 +35,31 @@ const ( | ||||
| ) | ||||
|  | ||||
| func RecordLog(userId int, logType int, content string) { | ||||
| 	if logType == LogTypeConsume && !common.LogConsumeEnabled { | ||||
| 	if logType == LogTypeConsume && !config.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	log := &Log{ | ||||
| 		UserId:    userId, | ||||
| 		Username:  GetUsernameById(userId), | ||||
| 		CreatedAt: common.GetTimestamp(), | ||||
| 		CreatedAt: helper.GetTimestamp(), | ||||
| 		Type:      logType, | ||||
| 		Content:   content, | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to record log: " + err.Error()) | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||
| 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| 	if !common.LogConsumeEnabled { | ||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| 	if !config.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	log := &Log{ | ||||
| 		UserId:           userId, | ||||
| 		Username:         GetUsernameById(userId), | ||||
| 		CreatedAt:        common.GetTimestamp(), | ||||
| 		CreatedAt:        helper.GetTimestamp(), | ||||
| 		Type:             LogTypeConsume, | ||||
| 		Content:          content, | ||||
| 		PromptTokens:     promptTokens, | ||||
| @@ -67,7 +71,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		common.LogError(ctx, "failed to record log: "+err.Error()) | ||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -124,12 +128,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int | ||||
| } | ||||
|  | ||||
| func SearchAllLogs(keyword string) (logs []*Log, err error) { | ||||
| 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error | ||||
| 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error | ||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| @@ -182,3 +186,40 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||
| 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|  | ||||
| type LogStatistic struct { | ||||
| 	Day              string `gorm:"column:day"` | ||||
| 	ModelName        string `gorm:"column:model_name"` | ||||
| 	RequestCount     int    `gorm:"column:request_count"` | ||||
| 	Quota            int    `gorm:"column:quota"` | ||||
| 	PromptTokens     int    `gorm:"column:prompt_tokens"` | ||||
| 	CompletionTokens int    `gorm:"column:completion_tokens"` | ||||
| } | ||||
|  | ||||
| func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) { | ||||
| 	groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day" | ||||
|  | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day" | ||||
| 	} | ||||
|  | ||||
| 	if common.UsingSQLite { | ||||
| 		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Raw(` | ||||
| 		SELECT `+groupSelect+`, | ||||
| 		model_name, count(1) as request_count, | ||||
| 		sum(quota) as quota, | ||||
| 		sum(prompt_tokens) as prompt_tokens, | ||||
| 		sum(completion_tokens) as completion_tokens | ||||
| 		FROM logs | ||||
| 		WHERE type=2 | ||||
| 		AND user_id= ? | ||||
| 		AND created_at BETWEEN ? AND ? | ||||
| 		GROUP BY day, model_name | ||||
| 		ORDER BY day, model_name | ||||
| 	`, userId, start, end).Scan(&LogStatistics).Error | ||||
|  | ||||
| 	return LogStatistics, err | ||||
| } | ||||
|   | ||||
| @@ -7,6 +7,9 @@ import ( | ||||
| 	"gorm.io/driver/sqlite" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -16,9 +19,9 @@ var DB *gorm.DB | ||||
|  | ||||
| func createRootAccountIfNeed() error { | ||||
| 	var user User | ||||
| 	//if user.Status != common.UserStatusEnabled { | ||||
| 	//if user.Status != util.UserStatusEnabled { | ||||
| 	if err := DB.First(&user).Error; err != nil { | ||||
| 		common.SysLog("no user exists, create a root user for you: username is root, password is 123456") | ||||
| 		logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") | ||||
| 		hashedPassword, err := common.Password2Hash("123456") | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| @@ -29,7 +32,7 @@ func createRootAccountIfNeed() error { | ||||
| 			Role:        common.RoleRootUser, | ||||
| 			Status:      common.UserStatusEnabled, | ||||
| 			DisplayName: "Root User", | ||||
| 			AccessToken: common.GetUUID(), | ||||
| 			AccessToken: helper.GetUUID(), | ||||
| 			Quota:       100000000, | ||||
| 		} | ||||
| 		DB.Create(&rootUser) | ||||
| @@ -42,7 +45,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 		dsn := os.Getenv("SQL_DSN") | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			common.SysLog("using PostgreSQL as database") | ||||
| 			logger.SysLog("using PostgreSQL as database") | ||||
| 			common.UsingPostgreSQL = true | ||||
| 			return gorm.Open(postgres.New(postgres.Config{ | ||||
| 				DSN:                  dsn, | ||||
| @@ -52,13 +55,13 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 			}) | ||||
| 		} | ||||
| 		// Use MySQL | ||||
| 		common.SysLog("using MySQL as database") | ||||
| 		logger.SysLog("using MySQL as database") | ||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 			PrepareStmt: true, // precompile SQL | ||||
| 		}) | ||||
| 	} | ||||
| 	// Use SQLite | ||||
| 	common.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	logger.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	common.UsingSQLite = true | ||||
| 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | ||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | ||||
| @@ -69,7 +72,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| func InitDB() (err error) { | ||||
| 	db, err := chooseDB() | ||||
| 	if err == nil { | ||||
| 		if common.DebugEnabled { | ||||
| 		if config.DebugEnabled { | ||||
| 			db = db.Debug() | ||||
| 		} | ||||
| 		DB = db | ||||
| @@ -77,14 +80,14 @@ func InitDB() (err error) { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) | ||||
| 		sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) | ||||
|  | ||||
| 		if !common.IsMasterNode { | ||||
| 		if !config.IsMasterNode { | ||||
| 			return nil | ||||
| 		} | ||||
| 		common.SysLog("database migration started") | ||||
| 		logger.SysLog("database migration started") | ||||
| 		err = db.AutoMigrate(&Channel{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| @@ -113,11 +116,11 @@ func InitDB() (err error) { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		common.SysLog("database migrated") | ||||
| 		logger.SysLog("database migrated") | ||||
| 		err = createRootAccountIfNeed() | ||||
| 		return err | ||||
| 	} else { | ||||
| 		common.FatalLog(err) | ||||
| 		logger.FatalLog(err) | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|   | ||||
							
								
								
									
										212
									
								
								model/option.go
									
									
									
									
									
								
							
							
						
						
									
										212
									
								
								model/option.go
									
									
									
									
									
								
							| @@ -2,6 +2,8 @@ package model | ||||
|  | ||||
| import ( | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -20,59 +22,56 @@ func AllOption() ([]*Option, error) { | ||||
| } | ||||
|  | ||||
| func InitOptionMap() { | ||||
| 	common.OptionMapRWMutex.Lock() | ||||
| 	common.OptionMap = make(map[string]string) | ||||
| 	common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) | ||||
| 	common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) | ||||
| 	common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) | ||||
| 	common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) | ||||
| 	common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) | ||||
| 	common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) | ||||
| 	common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) | ||||
| 	common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) | ||||
| 	common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) | ||||
| 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | ||||
| 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | ||||
| 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | ||||
| 	common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) | ||||
| 	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) | ||||
| 	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) | ||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||
| 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | ||||
| 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | ||||
| 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) | ||||
| 	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") | ||||
| 	common.OptionMap["SMTPServer"] = "" | ||||
| 	common.OptionMap["SMTPFrom"] = "" | ||||
| 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | ||||
| 	common.OptionMap["SMTPAccount"] = "" | ||||
| 	common.OptionMap["SMTPToken"] = "" | ||||
| 	common.OptionMap["Notice"] = "" | ||||
| 	common.OptionMap["About"] = "" | ||||
| 	common.OptionMap["HomePageContent"] = "" | ||||
| 	common.OptionMap["Footer"] = common.Footer | ||||
| 	common.OptionMap["SystemName"] = common.SystemName | ||||
| 	common.OptionMap["Logo"] = common.Logo | ||||
| 	common.OptionMap["ServerAddress"] = "" | ||||
| 	common.OptionMap["GitHubClientId"] = "" | ||||
| 	common.OptionMap["GitHubClientSecret"] = "" | ||||
| 	common.OptionMap["WeChatServerAddress"] = "" | ||||
| 	common.OptionMap["WeChatServerToken"] = "" | ||||
| 	common.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||
| 	common.OptionMap["TurnstileSiteKey"] = "" | ||||
| 	common.OptionMap["TurnstileSecretKey"] = "" | ||||
| 	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) | ||||
| 	common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) | ||||
| 	common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) | ||||
| 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) | ||||
| 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) | ||||
| 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | ||||
| 	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | ||||
| 	common.OptionMap["TopUpLink"] = common.TopUpLink | ||||
| 	common.OptionMap["ChatLink"] = common.ChatLink | ||||
| 	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) | ||||
| 	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) | ||||
| 	common.OptionMapRWMutex.Unlock() | ||||
| 	config.OptionMapRWMutex.Lock() | ||||
| 	config.OptionMap = make(map[string]string) | ||||
| 	config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) | ||||
| 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) | ||||
| 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) | ||||
| 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) | ||||
| 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) | ||||
| 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) | ||||
| 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) | ||||
| 	config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) | ||||
| 	config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) | ||||
| 	config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) | ||||
| 	config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) | ||||
| 	config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) | ||||
| 	config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) | ||||
| 	config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) | ||||
| 	config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) | ||||
| 	config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") | ||||
| 	config.OptionMap["SMTPServer"] = "" | ||||
| 	config.OptionMap["SMTPFrom"] = "" | ||||
| 	config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) | ||||
| 	config.OptionMap["SMTPAccount"] = "" | ||||
| 	config.OptionMap["SMTPToken"] = "" | ||||
| 	config.OptionMap["Notice"] = "" | ||||
| 	config.OptionMap["About"] = "" | ||||
| 	config.OptionMap["HomePageContent"] = "" | ||||
| 	config.OptionMap["Footer"] = config.Footer | ||||
| 	config.OptionMap["SystemName"] = config.SystemName | ||||
| 	config.OptionMap["Logo"] = config.Logo | ||||
| 	config.OptionMap["ServerAddress"] = "" | ||||
| 	config.OptionMap["GitHubClientId"] = "" | ||||
| 	config.OptionMap["GitHubClientSecret"] = "" | ||||
| 	config.OptionMap["WeChatServerAddress"] = "" | ||||
| 	config.OptionMap["WeChatServerToken"] = "" | ||||
| 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||
| 	config.OptionMap["TurnstileSiteKey"] = "" | ||||
| 	config.OptionMap["TurnstileSecretKey"] = "" | ||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) | ||||
| 	config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) | ||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) | ||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) | ||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) | ||||
| 	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | ||||
| 	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | ||||
| 	config.OptionMap["TopUpLink"] = config.TopUpLink | ||||
| 	config.OptionMap["ChatLink"] = config.ChatLink | ||||
| 	config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) | ||||
| 	config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) | ||||
| 	config.OptionMap["Theme"] = config.Theme | ||||
| 	config.OptionMapRWMutex.Unlock() | ||||
| 	loadOptionsFromDatabase() | ||||
| } | ||||
|  | ||||
| @@ -81,7 +80,7 @@ func loadOptionsFromDatabase() { | ||||
| 	for _, option := range options { | ||||
| 		err := updateOptionMap(option.Key, option.Value) | ||||
| 		if err != nil { | ||||
| 			common.SysError("failed to update option map: " + err.Error()) | ||||
| 			logger.SysError("failed to update option map: " + err.Error()) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -89,7 +88,7 @@ func loadOptionsFromDatabase() { | ||||
| func SyncOptions(frequency int) { | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Second) | ||||
| 		common.SysLog("syncing options from database") | ||||
| 		logger.SysLog("syncing options from database") | ||||
| 		loadOptionsFromDatabase() | ||||
| 	} | ||||
| } | ||||
| @@ -111,115 +110,104 @@ func UpdateOption(key string, value string) error { | ||||
| } | ||||
|  | ||||
| func updateOptionMap(key string, value string) (err error) { | ||||
| 	common.OptionMapRWMutex.Lock() | ||||
| 	defer common.OptionMapRWMutex.Unlock() | ||||
| 	common.OptionMap[key] = value | ||||
| 	if strings.HasSuffix(key, "Permission") { | ||||
| 		intValue, _ := strconv.Atoi(value) | ||||
| 		switch key { | ||||
| 		case "FileUploadPermission": | ||||
| 			common.FileUploadPermission = intValue | ||||
| 		case "FileDownloadPermission": | ||||
| 			common.FileDownloadPermission = intValue | ||||
| 		case "ImageUploadPermission": | ||||
| 			common.ImageUploadPermission = intValue | ||||
| 		case "ImageDownloadPermission": | ||||
| 			common.ImageDownloadPermission = intValue | ||||
| 		} | ||||
| 	} | ||||
| 	config.OptionMapRWMutex.Lock() | ||||
| 	defer config.OptionMapRWMutex.Unlock() | ||||
| 	config.OptionMap[key] = value | ||||
| 	if strings.HasSuffix(key, "Enabled") { | ||||
| 		boolValue := value == "true" | ||||
| 		switch key { | ||||
| 		case "PasswordRegisterEnabled": | ||||
| 			common.PasswordRegisterEnabled = boolValue | ||||
| 			config.PasswordRegisterEnabled = boolValue | ||||
| 		case "PasswordLoginEnabled": | ||||
| 			common.PasswordLoginEnabled = boolValue | ||||
| 			config.PasswordLoginEnabled = boolValue | ||||
| 		case "EmailVerificationEnabled": | ||||
| 			common.EmailVerificationEnabled = boolValue | ||||
| 			config.EmailVerificationEnabled = boolValue | ||||
| 		case "GitHubOAuthEnabled": | ||||
| 			common.GitHubOAuthEnabled = boolValue | ||||
| 			config.GitHubOAuthEnabled = boolValue | ||||
| 		case "WeChatAuthEnabled": | ||||
| 			common.WeChatAuthEnabled = boolValue | ||||
| 			config.WeChatAuthEnabled = boolValue | ||||
| 		case "TurnstileCheckEnabled": | ||||
| 			common.TurnstileCheckEnabled = boolValue | ||||
| 			config.TurnstileCheckEnabled = boolValue | ||||
| 		case "RegisterEnabled": | ||||
| 			common.RegisterEnabled = boolValue | ||||
| 			config.RegisterEnabled = boolValue | ||||
| 		case "EmailDomainRestrictionEnabled": | ||||
| 			common.EmailDomainRestrictionEnabled = boolValue | ||||
| 			config.EmailDomainRestrictionEnabled = boolValue | ||||
| 		case "AutomaticDisableChannelEnabled": | ||||
| 			common.AutomaticDisableChannelEnabled = boolValue | ||||
| 			config.AutomaticDisableChannelEnabled = boolValue | ||||
| 		case "AutomaticEnableChannelEnabled": | ||||
| 			common.AutomaticEnableChannelEnabled = boolValue | ||||
| 			config.AutomaticEnableChannelEnabled = boolValue | ||||
| 		case "ApproximateTokenEnabled": | ||||
| 			common.ApproximateTokenEnabled = boolValue | ||||
| 			config.ApproximateTokenEnabled = boolValue | ||||
| 		case "LogConsumeEnabled": | ||||
| 			common.LogConsumeEnabled = boolValue | ||||
| 			config.LogConsumeEnabled = boolValue | ||||
| 		case "DisplayInCurrencyEnabled": | ||||
| 			common.DisplayInCurrencyEnabled = boolValue | ||||
| 			config.DisplayInCurrencyEnabled = boolValue | ||||
| 		case "DisplayTokenStatEnabled": | ||||
| 			common.DisplayTokenStatEnabled = boolValue | ||||
| 			config.DisplayTokenStatEnabled = boolValue | ||||
| 		} | ||||
| 	} | ||||
| 	switch key { | ||||
| 	case "EmailDomainWhitelist": | ||||
| 		common.EmailDomainWhitelist = strings.Split(value, ",") | ||||
| 		config.EmailDomainWhitelist = strings.Split(value, ",") | ||||
| 	case "SMTPServer": | ||||
| 		common.SMTPServer = value | ||||
| 		config.SMTPServer = value | ||||
| 	case "SMTPPort": | ||||
| 		intValue, _ := strconv.Atoi(value) | ||||
| 		common.SMTPPort = intValue | ||||
| 		config.SMTPPort = intValue | ||||
| 	case "SMTPAccount": | ||||
| 		common.SMTPAccount = value | ||||
| 		config.SMTPAccount = value | ||||
| 	case "SMTPFrom": | ||||
| 		common.SMTPFrom = value | ||||
| 		config.SMTPFrom = value | ||||
| 	case "SMTPToken": | ||||
| 		common.SMTPToken = value | ||||
| 		config.SMTPToken = value | ||||
| 	case "ServerAddress": | ||||
| 		common.ServerAddress = value | ||||
| 		config.ServerAddress = value | ||||
| 	case "GitHubClientId": | ||||
| 		common.GitHubClientId = value | ||||
| 		config.GitHubClientId = value | ||||
| 	case "GitHubClientSecret": | ||||
| 		common.GitHubClientSecret = value | ||||
| 		config.GitHubClientSecret = value | ||||
| 	case "Footer": | ||||
| 		common.Footer = value | ||||
| 		config.Footer = value | ||||
| 	case "SystemName": | ||||
| 		common.SystemName = value | ||||
| 		config.SystemName = value | ||||
| 	case "Logo": | ||||
| 		common.Logo = value | ||||
| 		config.Logo = value | ||||
| 	case "WeChatServerAddress": | ||||
| 		common.WeChatServerAddress = value | ||||
| 		config.WeChatServerAddress = value | ||||
| 	case "WeChatServerToken": | ||||
| 		common.WeChatServerToken = value | ||||
| 		config.WeChatServerToken = value | ||||
| 	case "WeChatAccountQRCodeImageURL": | ||||
| 		common.WeChatAccountQRCodeImageURL = value | ||||
| 		config.WeChatAccountQRCodeImageURL = value | ||||
| 	case "TurnstileSiteKey": | ||||
| 		common.TurnstileSiteKey = value | ||||
| 		config.TurnstileSiteKey = value | ||||
| 	case "TurnstileSecretKey": | ||||
| 		common.TurnstileSecretKey = value | ||||
| 		config.TurnstileSecretKey = value | ||||
| 	case "QuotaForNewUser": | ||||
| 		common.QuotaForNewUser, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForNewUser, _ = strconv.Atoi(value) | ||||
| 	case "QuotaForInviter": | ||||
| 		common.QuotaForInviter, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForInviter, _ = strconv.Atoi(value) | ||||
| 	case "QuotaForInvitee": | ||||
| 		common.QuotaForInvitee, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForInvitee, _ = strconv.Atoi(value) | ||||
| 	case "QuotaRemindThreshold": | ||||
| 		common.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||
| 		config.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||
| 	case "PreConsumedQuota": | ||||
| 		common.PreConsumedQuota, _ = strconv.Atoi(value) | ||||
| 		config.PreConsumedQuota, _ = strconv.Atoi(value) | ||||
| 	case "RetryTimes": | ||||
| 		common.RetryTimes, _ = strconv.Atoi(value) | ||||
| 		config.RetryTimes, _ = strconv.Atoi(value) | ||||
| 	case "ModelRatio": | ||||
| 		err = common.UpdateModelRatioByJSONString(value) | ||||
| 	case "GroupRatio": | ||||
| 		err = common.UpdateGroupRatioByJSONString(value) | ||||
| 	case "TopUpLink": | ||||
| 		common.TopUpLink = value | ||||
| 		config.TopUpLink = value | ||||
| 	case "ChatLink": | ||||
| 		common.ChatLink = value | ||||
| 		config.ChatLink = value | ||||
| 	case "ChannelDisableThreshold": | ||||
| 		common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) | ||||
| 		config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) | ||||
| 	case "QuotaPerUnit": | ||||
| 		common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) | ||||
| 		config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) | ||||
| 	case "Theme": | ||||
| 		config.Theme = value | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| type Redemption struct { | ||||
| @@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		redemption.RedeemedTime = common.GetTimestamp() | ||||
| 		redemption.RedeemedTime = helper.GetTimestamp() | ||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | ||||
| 		err = tx.Save(redemption).Error | ||||
| 		return err | ||||
|   | ||||
| @@ -5,6 +5,9 @@ import ( | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| type Token struct { | ||||
| @@ -38,39 +41,43 @@ func ValidateUserToken(key string) (token *Token, err error) { | ||||
| 		return nil, errors.New("未提供令牌") | ||||
| 	} | ||||
| 	token, err = CacheGetTokenByKey(key) | ||||
| 	if err == nil { | ||||
| 		if token.Status == common.TokenStatusExhausted { | ||||
| 			return nil, errors.New("该令牌额度已用尽") | ||||
| 		} else if token.Status == common.TokenStatusExpired { | ||||
| 			return nil, errors.New("该令牌已过期") | ||||
| 	if err != nil { | ||||
| 		logger.SysError("CacheGetTokenByKey failed: " + err.Error()) | ||||
| 		if errors.Is(err, gorm.ErrRecordNotFound) { | ||||
| 			return nil, errors.New("无效的令牌") | ||||
| 		} | ||||
| 		if token.Status != common.TokenStatusEnabled { | ||||
| 			return nil, errors.New("该令牌状态不可用") | ||||
| 		} | ||||
| 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | ||||
| 			if !common.RedisEnabled { | ||||
| 				token.Status = common.TokenStatusExpired | ||||
| 				err := token.SelectUpdate() | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to update token status" + err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 			return nil, errors.New("该令牌已过期") | ||||
| 		} | ||||
| 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||
| 			if !common.RedisEnabled { | ||||
| 				// in this case, we can make sure the token is exhausted | ||||
| 				token.Status = common.TokenStatusExhausted | ||||
| 				err := token.SelectUpdate() | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to update token status" + err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 			return nil, errors.New("该令牌额度已用尽") | ||||
| 		} | ||||
| 		return token, nil | ||||
| 		return nil, errors.New("令牌验证失败") | ||||
| 	} | ||||
| 	return nil, errors.New("无效的令牌") | ||||
| 	if token.Status == common.TokenStatusExhausted { | ||||
| 		return nil, errors.New("该令牌额度已用尽") | ||||
| 	} else if token.Status == common.TokenStatusExpired { | ||||
| 		return nil, errors.New("该令牌已过期") | ||||
| 	} | ||||
| 	if token.Status != common.TokenStatusEnabled { | ||||
| 		return nil, errors.New("该令牌状态不可用") | ||||
| 	} | ||||
| 	if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { | ||||
| 		if !common.RedisEnabled { | ||||
| 			token.Status = common.TokenStatusExpired | ||||
| 			err := token.SelectUpdate() | ||||
| 			if err != nil { | ||||
| 				logger.SysError("failed to update token status" + err.Error()) | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, errors.New("该令牌已过期") | ||||
| 	} | ||||
| 	if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||
| 		if !common.RedisEnabled { | ||||
| 			// in this case, we can make sure the token is exhausted | ||||
| 			token.Status = common.TokenStatusExhausted | ||||
| 			err := token.SelectUpdate() | ||||
| 			if err != nil { | ||||
| 				logger.SysError("failed to update token status" + err.Error()) | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, errors.New("该令牌额度已用尽") | ||||
| 	} | ||||
| 	return token, nil | ||||
| } | ||||
|  | ||||
| func GetTokenByIds(id int, userId int) (*Token, error) { | ||||
| @@ -134,7 +141,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -146,7 +153,7 @@ func increaseTokenQuota(id int, quota int) (err error) { | ||||
| 		map[string]interface{}{ | ||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||
| 			"used_quota":    gorm.Expr("used_quota - ?", quota), | ||||
| 			"accessed_time": common.GetTimestamp(), | ||||
| 			"accessed_time": helper.GetTimestamp(), | ||||
| 		}, | ||||
| 	).Error | ||||
| 	return err | ||||
| @@ -156,7 +163,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -168,7 +175,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { | ||||
| 		map[string]interface{}{ | ||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||
| 			"accessed_time": common.GetTimestamp(), | ||||
| 			"accessed_time": helper.GetTimestamp(), | ||||
| 		}, | ||||
| 	).Error | ||||
| 	return err | ||||
| @@ -192,24 +199,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| 	if userQuota < quota { | ||||
| 		return errors.New("用户额度不足") | ||||
| 	} | ||||
| 	quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold | ||||
| 	quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold | ||||
| 	noMoreQuota := userQuota-quota <= 0 | ||||
| 	if quotaTooLow || noMoreQuota { | ||||
| 		go func() { | ||||
| 			email, err := GetUserEmail(token.UserId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("failed to fetch user email: " + err.Error()) | ||||
| 				logger.SysError("failed to fetch user email: " + err.Error()) | ||||
| 			} | ||||
| 			prompt := "您的额度即将用尽" | ||||
| 			if noMoreQuota { | ||||
| 				prompt = "您的额度已用尽" | ||||
| 			} | ||||
| 			if email != "" { | ||||
| 				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) | ||||
| 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | ||||
| 				err = common.SendEmail(prompt, email, | ||||
| 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to send email" + err.Error()) | ||||
| 					logger.SysError("failed to send email" + err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 		}() | ||||
|   | ||||
| @@ -5,6 +5,9 @@ import ( | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| @@ -15,7 +18,7 @@ type User struct { | ||||
| 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | ||||
| 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | ||||
| 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | ||||
| 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common | ||||
| 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, util | ||||
| 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | ||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| @@ -42,7 +45,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { | ||||
| } | ||||
|  | ||||
| func SearchUsers(keyword string) (users []*User, err error) { | ||||
| 	err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error | ||||
| 	if !common.UsingPostgreSQL { | ||||
| 		err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error | ||||
| 	} else { | ||||
| 		err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error | ||||
| 	} | ||||
| 	return users, err | ||||
| } | ||||
|  | ||||
| @@ -85,24 +92,24 @@ func (user *User) Insert(inviterId int) error { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	user.Quota = common.QuotaForNewUser | ||||
| 	user.AccessToken = common.GetUUID() | ||||
| 	user.AffCode = common.GetRandomString(4) | ||||
| 	user.Quota = config.QuotaForNewUser | ||||
| 	user.AccessToken = helper.GetUUID() | ||||
| 	user.AffCode = helper.GetRandomString(4) | ||||
| 	result := DB.Create(user) | ||||
| 	if result.Error != nil { | ||||
| 		return result.Error | ||||
| 	} | ||||
| 	if common.QuotaForNewUser > 0 { | ||||
| 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) | ||||
| 	if config.QuotaForNewUser > 0 { | ||||
| 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | ||||
| 	} | ||||
| 	if inviterId != 0 { | ||||
| 		if common.QuotaForInvitee > 0 { | ||||
| 			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) | ||||
| 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) | ||||
| 		if config.QuotaForInvitee > 0 { | ||||
| 			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) | ||||
| 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | ||||
| 		} | ||||
| 		if common.QuotaForInviter > 0 { | ||||
| 			_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) | ||||
| 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) | ||||
| 		if config.QuotaForInviter > 0 { | ||||
| 			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) | ||||
| 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| @@ -137,7 +144,15 @@ func (user *User) ValidateAndFill() (err error) { | ||||
| 	if user.Username == "" || password == "" { | ||||
| 		return errors.New("用户名或密码为空") | ||||
| 	} | ||||
| 	DB.Where(User{Username: user.Username}).First(user) | ||||
| 	err = DB.Where("username = ?", user.Username).First(user).Error | ||||
| 	if err != nil { | ||||
| 		// we must make sure check username firstly | ||||
| 		// consider this case: a malicious user set his username as other's email | ||||
| 		err := DB.Where("email = ?", user.Username).First(user).Error | ||||
| 		if err != nil { | ||||
| 			return errors.New("用户名或密码错误,或用户已被封禁") | ||||
| 		} | ||||
| 	} | ||||
| 	okay := common.ValidatePasswordAndHash(password, user.Password) | ||||
| 	if !okay || user.Status != common.UserStatusEnabled { | ||||
| 		return errors.New("用户名或密码错误,或用户已被封禁") | ||||
| @@ -220,7 +235,7 @@ func IsAdmin(userId int) bool { | ||||
| 	var user User | ||||
| 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("no such user " + err.Error()) | ||||
| 		logger.SysError("no such user " + err.Error()) | ||||
| 		return false | ||||
| 	} | ||||
| 	return user.Role >= common.RoleAdminUser | ||||
| @@ -279,7 +294,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -295,7 +310,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -313,7 +328,7 @@ func GetRootUserEmail() (email string) { | ||||
| } | ||||
|  | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||
| 		return | ||||
| @@ -329,7 +344,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| 		}, | ||||
| 	).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user used quota and request count: " + err.Error()) | ||||
| 		logger.SysError("failed to update user used quota and request count: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -340,14 +355,14 @@ func updateUserUsedQuota(id int, quota int) { | ||||
| 		}, | ||||
| 	).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user used quota: " + err.Error()) | ||||
| 		logger.SysError("failed to update user used quota: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserRequestCount(id int, count int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user request count: " + err.Error()) | ||||
| 		logger.SysError("failed to update user request count: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,7 +1,8 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| @@ -28,7 +29,7 @@ func init() { | ||||
| func InitBatchUpdater() { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) | ||||
| 			time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) | ||||
| 			batchUpdate() | ||||
| 		} | ||||
| 	}() | ||||
| @@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) { | ||||
| } | ||||
|  | ||||
| func batchUpdate() { | ||||
| 	common.SysLog("batch update started") | ||||
| 	logger.SysLog("batch update started") | ||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||
| 		batchUpdateLocks[i].Lock() | ||||
| 		store := batchUpdateStores[i] | ||||
| @@ -57,12 +58,12 @@ func batchUpdate() { | ||||
| 			case BatchUpdateTypeUserQuota: | ||||
| 				err := increaseUserQuota(key, value) | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to batch update user quota: " + err.Error()) | ||||
| 					logger.SysError("failed to batch update user quota: " + err.Error()) | ||||
| 				} | ||||
| 			case BatchUpdateTypeTokenQuota: | ||||
| 				err := increaseTokenQuota(key, value) | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||
| 					logger.SysError("failed to batch update token quota: " + err.Error()) | ||||
| 				} | ||||
| 			case BatchUpdateTypeUsedQuota: | ||||
| 				updateUserUsedQuota(key, value) | ||||
| @@ -73,5 +74,5 @@ func batchUpdate() { | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	common.SysLog("batch update finished") | ||||
| 	logger.SysLog("batch update finished") | ||||
| } | ||||
|   | ||||
| @@ -1,3 +1,9 @@ | ||||
| [//]: # (请按照以下格式关联 issue) | ||||
| [//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) | ||||
| [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) | ||||
| [//]: # (开发者交流群:910657413) | ||||
| [//]: # (请在提交 PR 之前删除上面的注释) | ||||
|  | ||||
| close #issue_number | ||||
|  | ||||
| 我已确认该 PR 已自测通过,相关截图如下: | ||||
							
								
								
									
										22
									
								
								relay/channel/aiproxy/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/aiproxy/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package aiproxy | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package aiproxy | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| @@ -8,56 +8,29 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | ||||
| 
 | ||||
| type AIProxyLibraryRequest struct { | ||||
| 	Model     string `json:"model"` | ||||
| 	Query     string `json:"query"` | ||||
| 	LibraryId string `json:"libraryId"` | ||||
| 	Stream    bool   `json:"stream"` | ||||
| } | ||||
| 
 | ||||
| type AIProxyLibraryError struct { | ||||
| 	ErrCode int    `json:"errCode"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
| 
 | ||||
| type AIProxyLibraryDocument struct { | ||||
| 	Title string `json:"title"` | ||||
| 	URL   string `json:"url"` | ||||
| } | ||||
| 
 | ||||
| type AIProxyLibraryResponse struct { | ||||
| 	Success   bool                     `json:"success"` | ||||
| 	Answer    string                   `json:"answer"` | ||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||
| 	AIProxyLibraryError | ||||
| } | ||||
| 
 | ||||
| type AIProxyLibraryStreamResponse struct { | ||||
| 	Content   string                   `json:"content"` | ||||
| 	Finish    bool                     `json:"finish"` | ||||
| 	Model     string                   `json:"model"` | ||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||
| } | ||||
| 
 | ||||
| func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { | ||||
| func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest { | ||||
| 	query := "" | ||||
| 	if len(request.Messages) != 0 { | ||||
| 		query = request.Messages[len(request.Messages)-1].StringContent() | ||||
| 	} | ||||
| 	return &AIProxyLibraryRequest{ | ||||
| 	return &LibraryRequest{ | ||||
| 		Model:  request.Model, | ||||
| 		Stream: request.Stream, | ||||
| 		Query:  query, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | ||||
| func aiProxyDocuments2Markdown(documents []LibraryDocument) string { | ||||
| 	if len(documents) == 0 { | ||||
| 		return "" | ||||
| 	} | ||||
| @@ -68,52 +41,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | ||||
| 	return content | ||||
| } | ||||
| 
 | ||||
| func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { | ||||
| func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse { | ||||
| 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 		Message: openai.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: content, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      common.GetUUID(), | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      helper.GetUUID(), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	return &ChatCompletionsStreamResponse{ | ||||
| 		Id:      common.GetUUID(), | ||||
| 	choice.FinishReason = &constant.StopFinishReason | ||||
| 	return &openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      helper.GetUUID(), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = response.Content | ||||
| 	return &ChatCompletionsStreamResponse{ | ||||
| 		Id:      common.GetUUID(), | ||||
| 	return &openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      helper.GetUUID(), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   response.Model, | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var usage openai.Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| @@ -143,15 +116,15 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	var documents []AIProxyLibraryDocument | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var documents []LibraryDocument | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var AIProxyLibraryResponse AIProxyLibraryStreamResponse | ||||
| 			var AIProxyLibraryResponse LibraryStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if len(AIProxyLibraryResponse.Documents) != 0 { | ||||
| @@ -160,7 +133,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr | ||||
| 			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| @@ -169,7 +142,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr | ||||
| 			response := documentsAIProxyLibrary(documents) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| @@ -179,28 +152,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
| 
 | ||||
| func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var AIProxyLibraryResponse AIProxyLibraryResponse | ||||
| func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var AIProxyLibraryResponse LibraryResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if AIProxyLibraryResponse.ErrCode != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: AIProxyLibraryResponse.Message, | ||||
| 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), | ||||
| 				Code:    AIProxyLibraryResponse.ErrCode, | ||||
| @@ -211,7 +184,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit | ||||
| 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
							
								
								
									
										32
									
								
								relay/channel/aiproxy/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								relay/channel/aiproxy/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| package aiproxy | ||||
|  | ||||
| type LibraryRequest struct { | ||||
| 	Model     string `json:"model"` | ||||
| 	Query     string `json:"query"` | ||||
| 	LibraryId string `json:"libraryId"` | ||||
| 	Stream    bool   `json:"stream"` | ||||
| } | ||||
|  | ||||
| type LibraryError struct { | ||||
| 	ErrCode int    `json:"errCode"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type LibraryDocument struct { | ||||
| 	Title string `json:"title"` | ||||
| 	URL   string `json:"url"` | ||||
| } | ||||
|  | ||||
| type LibraryResponse struct { | ||||
| 	Success   bool              `json:"success"` | ||||
| 	Answer    string            `json:"answer"` | ||||
| 	Documents []LibraryDocument `json:"documents"` | ||||
| 	LibraryError | ||||
| } | ||||
|  | ||||
| type LibraryStreamResponse struct { | ||||
| 	Content   string            `json:"content"` | ||||
| 	Finish    bool              `json:"finish"` | ||||
| 	Model     string            `json:"model"` | ||||
| 	Documents []LibraryDocument `json:"documents"` | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/ali/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/ali/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
							
								
								
									
										255
									
								
								relay/channel/ali/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										255
									
								
								relay/channel/ali/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,255 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
|  | ||||
| const EnableSearchModelSuffix = "-internet" | ||||
|  | ||||
| func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		messages = append(messages, Message{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    strings.ToLower(message.Role), | ||||
| 		}) | ||||
| 	} | ||||
| 	enableSearch := false | ||||
| 	aliModel := request.Model | ||||
| 	if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { | ||||
| 		enableSearch = true | ||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Model: aliModel, | ||||
| 		Input: Input{ | ||||
| 			Messages: messages, | ||||
| 		}, | ||||
| 		Parameters: Parameters{ | ||||
| 			EnableSearch:      enableSearch, | ||||
| 			IncrementalOutput: request.Stream, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| 			Texts: request.ParseInput(), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var aliResponse EmbeddingResponse | ||||
| 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Usage:  openai.Usage{TotalTokens: response.Usage.TotalTokens}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range response.Output.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     item.TextIndex, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: openai.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 		Usage: openai.Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	if aliResponse.Output.FinishReason != "null" { | ||||
| 		finishReason := aliResponse.Output.FinishReason | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "qwen", | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var usage openai.Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	//lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if aliResponse.Usage.OutputTokens != 0 { | ||||
| 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||
| 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			//lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var aliResponse ChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||
| 	fullTextResponse.Model = "qwen" | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
							
								
								
									
										71
									
								
								relay/channel/ali/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								relay/channel/ali/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| package ali | ||||
|  | ||||
| type Message struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| } | ||||
|  | ||||
| type Input struct { | ||||
| 	//Prompt   string       `json:"prompt"` | ||||
| 	Messages []Message `json:"messages"` | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64 `json:"top_p,omitempty"` | ||||
| 	TopK              int     `json:"top_k,omitempty"` | ||||
| 	Seed              uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model      string     `json:"model"` | ||||
| 	Input      Input      `json:"input"` | ||||
| 	Parameters Parameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input struct { | ||||
| 		Texts []string `json:"texts"` | ||||
| 	} `json:"input"` | ||||
| 	Parameters *struct { | ||||
| 		TextType string `json:"text_type,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type Embedding struct { | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	TextIndex int       `json:"text_index"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Output struct { | ||||
| 		Embeddings []Embedding `json:"embeddings"` | ||||
| 	} `json:"output"` | ||||
| 	Usage Usage `json:"usage"` | ||||
| 	Error | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type Output struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Output Output `json:"output"` | ||||
| 	Usage  Usage  `json:"usage"` | ||||
| 	Error | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/anthropic/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/anthropic/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package anthropic | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package anthropic | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| @@ -8,37 +8,12 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| type ClaudeMetadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
| 
 | ||||
| type ClaudeRequest struct { | ||||
| 	Model             string   `json:"model"` | ||||
| 	Prompt            string   `json:"prompt"` | ||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||
| 	Temperature       float64  `json:"temperature,omitempty"` | ||||
| 	TopP              float64  `json:"top_p,omitempty"` | ||||
| 	TopK              int      `json:"top_k,omitempty"` | ||||
| 	//ClaudeMetadata    `json:"metadata,omitempty"` | ||||
| 	Stream bool `json:"stream,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type ClaudeError struct { | ||||
| 	Type    string `json:"type"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
| 
 | ||||
| type ClaudeResponse struct { | ||||
| 	Completion string      `json:"completion"` | ||||
| 	StopReason string      `json:"stop_reason"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Error      ClaudeError `json:"error"` | ||||
| } | ||||
| 
 | ||||
| func stopReasonClaude2OpenAI(reason string) string { | ||||
| 	switch reason { | ||||
| 	case "stop_sequence": | ||||
| @@ -50,8 +25,8 @@ func stopReasonClaude2OpenAI(reason string) string { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 	claudeRequest := ClaudeRequest{ | ||||
| func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request { | ||||
| 	claudeRequest := Request{ | ||||
| 		Model:             textRequest.Model, | ||||
| 		Prompt:            "", | ||||
| 		MaxTokensToSample: textRequest.MaxTokens, | ||||
| @@ -80,43 +55,43 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 	return &claudeRequest | ||||
| } | ||||
| 
 | ||||
| func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = claudeResponse.Completion | ||||
| 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | ||||
| 	if finishReason != "null" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	var response openai.ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = claudeResponse.Model | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
| 
 | ||||
| func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 		Message: openai.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| @@ -143,16 +118,16 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var claudeResponse ClaudeResponse | ||||
| 			var claudeResponse Response | ||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += claudeResponse.Completion | ||||
| @@ -161,7 +136,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| @@ -173,28 +148,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
| 
 | ||||
| func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var claudeResponse ClaudeResponse | ||||
| 	var claudeResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if claudeResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: claudeResponse.Error.Message, | ||||
| 				Type:    claudeResponse.Error.Type, | ||||
| 				Param:   "", | ||||
| @@ -204,8 +179,9 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||
| 	completionTokens := countTokenText(claudeResponse.Completion, model) | ||||
| 	usage := Usage{ | ||||
| 	fullTextResponse.Model = model | ||||
| 	completionTokens := openai.CountTokenText(claudeResponse.Completion, model) | ||||
| 	usage := openai.Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| @@ -213,7 +189,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
							
								
								
									
										29
									
								
								relay/channel/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								relay/channel/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package anthropic | ||||
|  | ||||
| type Metadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	Model             string   `json:"model"` | ||||
| 	Prompt            string   `json:"prompt"` | ||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||
| 	Temperature       float64  `json:"temperature,omitempty"` | ||||
| 	TopP              float64  `json:"top_p,omitempty"` | ||||
| 	TopK              int      `json:"top_k,omitempty"` | ||||
| 	//Metadata    `json:"metadata,omitempty"` | ||||
| 	Stream bool `json:"stream,omitempty"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Type    string `json:"type"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Completion string `json:"completion"` | ||||
| 	StopReason string `json:"stop_reason"` | ||||
| 	Model      string `json:"model"` | ||||
| 	Error      Error  `json:"error"` | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/baidu/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/baidu/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package baidu | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| @@ -9,6 +9,10 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"one-api/relay/util" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @@ -16,148 +20,104 @@ import ( | ||||
| 
 | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||||
| 
 | ||||
| type BaiduTokenResponse struct { | ||||
| type TokenResponse struct { | ||||
| 	ExpiresIn   int    `json:"expires_in"` | ||||
| 	AccessToken string `json:"access_token"` | ||||
| } | ||||
| 
 | ||||
| type BaiduMessage struct { | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
| 
 | ||||
| type BaiduChatRequest struct { | ||||
| 	Messages []BaiduMessage `json:"messages"` | ||||
| 	Stream   bool           `json:"stream"` | ||||
| 	UserId   string         `json:"user_id,omitempty"` | ||||
| type ChatRequest struct { | ||||
| 	Messages []Message `json:"messages"` | ||||
| 	Stream   bool      `json:"stream"` | ||||
| 	UserId   string    `json:"user_id,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type BaiduError struct { | ||||
| type Error struct { | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	ErrorMsg  string `json:"error_msg"` | ||||
| } | ||||
| 
 | ||||
| type BaiduChatResponse struct { | ||||
| 	Id               string `json:"id"` | ||||
| 	Object           string `json:"object"` | ||||
| 	Created          int64  `json:"created"` | ||||
| 	Result           string `json:"result"` | ||||
| 	IsTruncated      bool   `json:"is_truncated"` | ||||
| 	NeedClearHistory bool   `json:"need_clear_history"` | ||||
| 	Usage            Usage  `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
| 
 | ||||
| type BaiduChatStreamResponse struct { | ||||
| 	BaiduChatResponse | ||||
| 	SentenceId int  `json:"sentence_id"` | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
| 
 | ||||
| type BaiduEmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
| 
 | ||||
| type BaiduEmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
| 
 | ||||
| type BaiduEmbeddingResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Data    []BaiduEmbeddingData `json:"data"` | ||||
| 	Usage   Usage                `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
| 
 | ||||
| type BaiduAccessToken struct { | ||||
| 	AccessToken      string    `json:"access_token"` | ||||
| 	Error            string    `json:"error,omitempty"` | ||||
| 	ErrorDescription string    `json:"error_description,omitempty"` | ||||
| 	ExpiresIn        int64     `json:"expires_in,omitempty"` | ||||
| 	ExpiresAt        time.Time `json:"-"` | ||||
| } | ||||
| 
 | ||||
| var baiduTokenStore sync.Map | ||||
| 
 | ||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||
| func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &BaiduChatRequest{ | ||||
| 	return &ChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 		Message: openai.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Result, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.Id, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: response.Created, | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = baiduResponse.Result | ||||
| 	if baiduResponse.IsEnd { | ||||
| 		choice.FinishReason = &stopFinishReason | ||||
| 		choice.FinishReason = &constant.StopFinishReason | ||||
| 	} | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      baiduResponse.Id, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: baiduResponse.Created, | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
| 
 | ||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||
| 	return &BaiduEmbeddingRequest{ | ||||
| func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Input: request.ParseInput(), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "baidu-embedding", | ||||
| 		Usage:  response.Usage, | ||||
| 	} | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| @@ -166,8 +126,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
| 
 | ||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var usage openai.Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| @@ -194,14 +154,14 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var baiduResponse BaiduChatStreamResponse | ||||
| 			var baiduResponse ChatStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if baiduResponse.Usage.TotalTokens != 0 { | ||||
| @@ -212,7 +172,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 			response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| @@ -224,28 +184,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
| 
 | ||||
| func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduChatResponse | ||||
| func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var baiduResponse ChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| @@ -255,9 +215,10 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse) | ||||
| 	fullTextResponse.Model = "ernie-bot" | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| @@ -265,23 +226,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| 
 | ||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduEmbeddingResponse | ||||
| func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var baiduResponse EmbeddingResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| @@ -293,7 +254,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit | ||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| @@ -301,10 +262,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| 
 | ||||
| func getBaiduAccessToken(apiKey string) (string, error) { | ||||
| func GetAccessToken(apiKey string) (string, error) { | ||||
| 	if val, ok := baiduTokenStore.Load(apiKey); ok { | ||||
| 		var accessToken BaiduAccessToken | ||||
| 		if accessToken, ok = val.(BaiduAccessToken); ok { | ||||
| 		var accessToken AccessToken | ||||
| 		if accessToken, ok = val.(AccessToken); ok { | ||||
| 			// soon this will expire | ||||
| 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { | ||||
| 				go func() { | ||||
| @@ -319,12 +280,12 @@ func getBaiduAccessToken(apiKey string) (string, error) { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if accessToken == nil { | ||||
| 		return "", errors.New("getBaiduAccessToken return a nil token") | ||||
| 		return "", errors.New("GetAccessToken return a nil token") | ||||
| 	} | ||||
| 	return (*accessToken).AccessToken, nil | ||||
| } | ||||
| 
 | ||||
| func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | ||||
| func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) { | ||||
| 	parts := strings.Split(apiKey, "|") | ||||
| 	if len(parts) != 2 { | ||||
| 		return nil, errors.New("invalid baidu apikey") | ||||
| @@ -336,13 +297,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | ||||
| 	} | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 	req.Header.Add("Accept", "application/json") | ||||
| 	res, err := impatientHTTPClient.Do(req) | ||||
| 	res, err := util.ImpatientHTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 
 | ||||
| 	var accessToken BaiduAccessToken | ||||
| 	var accessToken AccessToken | ||||
| 	err = json.NewDecoder(res.Body).Decode(&accessToken) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
							
								
								
									
										50
									
								
								relay/channel/baidu/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								relay/channel/baidu/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Id               string       `json:"id"` | ||||
| 	Object           string       `json:"object"` | ||||
| 	Created          int64        `json:"created"` | ||||
| 	Result           string       `json:"result"` | ||||
| 	IsTruncated      bool         `json:"is_truncated"` | ||||
| 	NeedClearHistory bool         `json:"need_clear_history"` | ||||
| 	Usage            openai.Usage `json:"usage"` | ||||
| 	Error | ||||
| } | ||||
|  | ||||
| type ChatStreamResponse struct { | ||||
| 	ChatResponse | ||||
| 	SentenceId int  `json:"sentence_id"` | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
|  | ||||
| type EmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Id      string          `json:"id"` | ||||
| 	Object  string          `json:"object"` | ||||
| 	Created int64           `json:"created"` | ||||
| 	Data    []EmbeddingData `json:"data"` | ||||
| 	Usage   openai.Usage    `json:"usage"` | ||||
| 	Error | ||||
| } | ||||
|  | ||||
| type AccessToken struct { | ||||
| 	AccessToken      string    `json:"access_token"` | ||||
| 	Error            string    `json:"error,omitempty"` | ||||
| 	ErrorDescription string    `json:"error_description,omitempty"` | ||||
| 	ExpiresIn        int64     `json:"expires_in,omitempty"` | ||||
| 	ExpiresAt        time.Time `json:"-"` | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/google/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/google/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package google | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package google | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| @@ -7,73 +7,45 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/image" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| 
 | ||||
| type GeminiChatRequest struct { | ||||
| 	Contents         []GeminiChatContent        `json:"contents"` | ||||
| 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"` | ||||
| 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` | ||||
| 	Tools            []GeminiChatTools          `json:"tools,omitempty"` | ||||
| } | ||||
| // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn | ||||
| 
 | ||||
| type GeminiInlineData struct { | ||||
| 	MimeType string `json:"mimeType"` | ||||
| 	Data     string `json:"data"` | ||||
| } | ||||
| 
 | ||||
| type GeminiPart struct { | ||||
| 	Text       string            `json:"text,omitempty"` | ||||
| 	InlineData *GeminiInlineData `json:"inlineData,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type GeminiChatContent struct { | ||||
| 	Role  string       `json:"role,omitempty"` | ||||
| 	Parts []GeminiPart `json:"parts"` | ||||
| } | ||||
| 
 | ||||
| type GeminiChatSafetySettings struct { | ||||
| 	Category  string `json:"category"` | ||||
| 	Threshold string `json:"threshold"` | ||||
| } | ||||
| 
 | ||||
| type GeminiChatTools struct { | ||||
| 	FunctionDeclarations any `json:"functionDeclarations,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type GeminiChatGenerationConfig struct { | ||||
| 	Temperature     float64  `json:"temperature,omitempty"` | ||||
| 	TopP            float64  `json:"topP,omitempty"` | ||||
| 	TopK            float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount  int      `json:"candidateCount,omitempty"` | ||||
| 	StopSequences   []string `json:"stopSequences,omitempty"` | ||||
| } | ||||
| const ( | ||||
| 	GeminiVisionMaxImageNum = 16 | ||||
| ) | ||||
| 
 | ||||
| // Setting safety to the lowest possible values since Gemini is already powerless enough | ||||
| func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { | ||||
| func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest { | ||||
| 	geminiRequest := GeminiChatRequest{ | ||||
| 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), | ||||
| 		//SafetySettings: []GeminiChatSafetySettings{ | ||||
| 		//	{ | ||||
| 		//		Category:  "HARM_CATEGORY_HARASSMENT", | ||||
| 		//		Threshold: "BLOCK_ONLY_HIGH", | ||||
| 		//	}, | ||||
| 		//	{ | ||||
| 		//		Category:  "HARM_CATEGORY_HATE_SPEECH", | ||||
| 		//		Threshold: "BLOCK_ONLY_HIGH", | ||||
| 		//	}, | ||||
| 		//	{ | ||||
| 		//		Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT", | ||||
| 		//		Threshold: "BLOCK_ONLY_HIGH", | ||||
| 		//	}, | ||||
| 		//	{ | ||||
| 		//		Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | ||||
| 		//		Threshold: "BLOCK_ONLY_HIGH", | ||||
| 		//	}, | ||||
| 		//}, | ||||
| 		SafetySettings: []GeminiChatSafetySettings{ | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_HARASSMENT", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_HATE_SPEECH", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 		}, | ||||
| 		GenerationConfig: GeminiChatGenerationConfig{ | ||||
| 			Temperature:     textRequest.Temperature, | ||||
| 			TopP:            textRequest.TopP, | ||||
| @@ -97,6 +69,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { | ||||
| 				}, | ||||
| 			}, | ||||
| 		} | ||||
| 		openaiContent := message.ParseContent() | ||||
| 		var parts []GeminiPart | ||||
| 		imageNum := 0 | ||||
| 		for _, part := range openaiContent { | ||||
| 			if part.Type == openai.ContentTypeText { | ||||
| 				parts = append(parts, GeminiPart{ | ||||
| 					Text: part.Text, | ||||
| 				}) | ||||
| 			} else if part.Type == openai.ContentTypeImageURL { | ||||
| 				imageNum += 1 | ||||
| 				if imageNum > GeminiVisionMaxImageNum { | ||||
| 					continue | ||||
| 				} | ||||
| 				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) | ||||
| 				parts = append(parts, GeminiPart{ | ||||
| 					InlineData: &GeminiInlineData{ | ||||
| 						MimeType: mimeType, | ||||
| 						Data:     data, | ||||
| 					}, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 		content.Parts = parts | ||||
| 
 | ||||
| 		// there's no assistant role in gemini and API shall vomit if Role is not user or model | ||||
| 		if content.Role == "assistant" { | ||||
| 			content.Role = "model" | ||||
| @@ -156,21 +152,21 @@ type GeminiChatPromptFeedback struct { | ||||
| 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` | ||||
| } | ||||
| 
 | ||||
| func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||
| func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), | ||||
| 	} | ||||
| 	for i, candidate := range response.Candidates { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 		choice := openai.TextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 			Message: openai.Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 			FinishReason: stopFinishReason, | ||||
| 			FinishReason: constant.StopFinishReason, | ||||
| 		} | ||||
| 		if len(candidate.Content.Parts) > 0 { | ||||
| 			choice.Message.Content = candidate.Content.Parts[0].Text | ||||
| @@ -180,18 +176,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = geminiResponse.GetResponseText() | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	choice.FinishReason = &constant.StopFinishReason | ||||
| 	var response openai.ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = "gemini" | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
| 
 | ||||
| func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| @@ -221,7 +217,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| @@ -233,18 +229,18 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW | ||||
| 			var dummy dummyStruct | ||||
| 			err := json.Unmarshal([]byte(data), &dummy) | ||||
| 			responseText += dummy.Content | ||||
| 			var choice ChatCompletionsStreamResponseChoice | ||||
| 			var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 			choice.Delta.Content = dummy.Content | ||||
| 			response := ChatCompletionsStreamResponse{ | ||||
| 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||
| 			response := openai.ChatCompletionsStreamResponse{ | ||||
| 				Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 				Object:  "chat.completion.chunk", | ||||
| 				Created: common.GetTimestamp(), | ||||
| 				Created: helper.GetTimestamp(), | ||||
| 				Model:   "gemini-pro", | ||||
| 				Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 				Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 			} | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| @@ -256,28 +252,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
| 
 | ||||
| func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var geminiResponse GeminiChatResponse | ||||
| 	err = json.Unmarshal(responseBody, &geminiResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if len(geminiResponse.Candidates) == 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: "No candidates returned", | ||||
| 				Type:    "server_error", | ||||
| 				Param:   "", | ||||
| @@ -287,8 +283,9 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) | ||||
| 	completionTokens := countTokenText(geminiResponse.GetResponseText(), model) | ||||
| 	usage := Usage{ | ||||
| 	fullTextResponse.Model = model | ||||
| 	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model) | ||||
| 	usage := openai.Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| @@ -296,7 +293,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
							
								
								
									
										80
									
								
								relay/channel/google/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								relay/channel/google/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | ||||
| package google | ||||
|  | ||||
| import ( | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type GeminiChatRequest struct { | ||||
| 	Contents         []GeminiChatContent        `json:"contents"` | ||||
| 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"` | ||||
| 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` | ||||
| 	Tools            []GeminiChatTools          `json:"tools,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeminiInlineData struct { | ||||
| 	MimeType string `json:"mimeType"` | ||||
| 	Data     string `json:"data"` | ||||
| } | ||||
|  | ||||
| type GeminiPart struct { | ||||
| 	Text       string            `json:"text,omitempty"` | ||||
| 	InlineData *GeminiInlineData `json:"inlineData,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeminiChatContent struct { | ||||
| 	Role  string       `json:"role,omitempty"` | ||||
| 	Parts []GeminiPart `json:"parts"` | ||||
| } | ||||
|  | ||||
| type GeminiChatSafetySettings struct { | ||||
| 	Category  string `json:"category"` | ||||
| 	Threshold string `json:"threshold"` | ||||
| } | ||||
|  | ||||
| type GeminiChatTools struct { | ||||
| 	FunctionDeclarations any `json:"functionDeclarations,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeminiChatGenerationConfig struct { | ||||
| 	Temperature     float64  `json:"temperature,omitempty"` | ||||
| 	TopP            float64  `json:"topP,omitempty"` | ||||
| 	TopK            float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount  int      `json:"candidateCount,omitempty"` | ||||
| 	StopSequences   []string `json:"stopSequences,omitempty"` | ||||
| } | ||||
|  | ||||
| type PaLMChatMessage struct { | ||||
| 	Author  string `json:"author"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type PaLMFilter struct { | ||||
| 	Reason  string `json:"reason"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type PaLMPrompt struct { | ||||
| 	Messages []PaLMChatMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type PaLMChatRequest struct { | ||||
| 	Prompt         PaLMPrompt `json:"prompt"` | ||||
| 	Temperature    float64    `json:"temperature,omitempty"` | ||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64    `json:"topP,omitempty"` | ||||
| 	TopK           int        `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
| type PaLMError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  string `json:"status"` | ||||
| } | ||||
|  | ||||
| type PaLMChatResponse struct { | ||||
| 	Candidates []PaLMChatMessage `json:"candidates"` | ||||
| 	Messages   []openai.Message  `json:"messages"` | ||||
| 	Filters    []PaLMFilter      `json:"filters"` | ||||
| 	Error      PaLMError         `json:"error"` | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package google | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| @@ -7,47 +7,16 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| ) | ||||
| 
 | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||
| 
 | ||||
| type PaLMChatMessage struct { | ||||
| 	Author  string `json:"author"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
| 
 | ||||
| type PaLMFilter struct { | ||||
| 	Reason  string `json:"reason"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
| 
 | ||||
| type PaLMPrompt struct { | ||||
| 	Messages []PaLMChatMessage `json:"messages"` | ||||
| } | ||||
| 
 | ||||
| type PaLMChatRequest struct { | ||||
| 	Prompt         PaLMPrompt `json:"prompt"` | ||||
| 	Temperature    float64    `json:"temperature,omitempty"` | ||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64    `json:"topP,omitempty"` | ||||
| 	TopK           int        `json:"topK,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type PaLMError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  string `json:"status"` | ||||
| } | ||||
| 
 | ||||
| type PaLMChatResponse struct { | ||||
| 	Candidates []PaLMChatMessage `json:"candidates"` | ||||
| 	Messages   []Message         `json:"messages"` | ||||
| 	Filters    []PaLMFilter      `json:"filters"` | ||||
| 	Error      PaLMError         `json:"error"` | ||||
| } | ||||
| 
 | ||||
| func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	palmRequest := PaLMChatRequest{ | ||||
| 		Prompt: PaLMPrompt{ | ||||
| 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | ||||
| @@ -71,14 +40,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	return &palmRequest | ||||
| } | ||||
| 
 | ||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | ||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), | ||||
| 	} | ||||
| 	for i, candidate := range response.Candidates { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 		choice := openai.TextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 			Message: openai.Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: candidate.Content, | ||||
| 			}, | ||||
| @@ -89,42 +58,42 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	if len(palmResponse.Candidates) > 0 { | ||||
| 		choice.Delta.Content = palmResponse.Candidates[0].Content | ||||
| 	} | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	choice.FinishReason = &constant.StopFinishReason | ||||
| 	var response openai.ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = "palm2" | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
| 
 | ||||
| func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error reading stream response: " + err.Error()) | ||||
| 			logger.SysError("error reading stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			common.SysError("error closing stream response: " + err.Error()) | ||||
| 			logger.SysError("error closing stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		var palmResponse PaLMChatResponse | ||||
| 		err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| @@ -136,14 +105,14 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta | ||||
| 		} | ||||
| 		jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 			logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		dataChan <- string(jsonResponse) | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| @@ -156,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
| 
 | ||||
| func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var palmResponse PaLMChatResponse | ||||
| 	err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: palmResponse.Error.Message, | ||||
| 				Type:    palmResponse.Error.Status, | ||||
| 				Param:   "", | ||||
| @@ -187,8 +156,9 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | ||||
| 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) | ||||
| 	usage := Usage{ | ||||
| 	fullTextResponse.Model = model | ||||
| 	completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model) | ||||
| 	usage := openai.Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| @@ -196,7 +166,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
							
								
								
									
										15
									
								
								relay/channel/interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								relay/channel/interface.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package channel | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor interface { | ||||
| 	GetRequestURL() string | ||||
| 	Auth(c *gin.Context) error | ||||
| 	ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) | ||||
| 	DoRequest(request *openai.GeneralOpenAIRequest) error | ||||
| 	DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) | ||||
| } | ||||
							
								
								
									
										21
									
								
								relay/channel/openai/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								relay/channel/openai/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| package openai | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
							
								
								
									
										6
									
								
								relay/channel/openai/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								relay/channel/openai/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package openai | ||||
|  | ||||
| const ( | ||||
| 	ContentTypeText     = "text" | ||||
| 	ContentTypeImageURL = "image_url" | ||||
| ) | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package openai | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| @@ -8,10 +8,12 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/constant" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| @@ -41,21 +43,21 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				switch relayMode { | ||||
| 				case RelayModeChatCompletions: | ||||
| 				case constant.RelayModeChatCompletions: | ||||
| 					var streamResponse ChatCompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue // just ignore the error | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| 					} | ||||
| 				case RelayModeCompletions: | ||||
| 				case constant.RelayModeCompletions: | ||||
| 					var streamResponse CompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| @@ -66,7 +68,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| @@ -83,29 +85,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
| 
 | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse SlimTextResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if textResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: textResponse.Error, | ||||
| 			StatusCode:  resp.StatusCode, | ||||
| 		return &ErrorWithStatusCode{ | ||||
| 			Error:      textResponse.Error, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	// Reset response body | ||||
| @@ -113,7 +115,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model | ||||
| 
 | ||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 	// So the httpClient will be confused by the response. | ||||
| 	// So the HTTPClient will be confused by the response. | ||||
| 	// For example, Postman will report error, and we cannot check the response at all. | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| @@ -121,17 +123,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 
 | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += countTokenText(choice.Message.StringContent(), model) | ||||
| 			completionTokens += CountTokenText(choice.Message.StringContent(), model) | ||||
| 		} | ||||
| 		textResponse.Usage = Usage{ | ||||
| 			PromptTokens:     promptTokens, | ||||
							
								
								
									
										288
									
								
								relay/channel/openai/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										288
									
								
								relay/channel/openai/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,288 @@ | ||||
| package openai | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string  `json:"role"` | ||||
| 	Content any     `json:"content"` | ||||
| 	Name    *string `json:"name,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageURL struct { | ||||
| 	Url    string `json:"url,omitempty"` | ||||
| 	Detail string `json:"detail,omitempty"` | ||||
| } | ||||
|  | ||||
| type TextContent struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageContent struct { | ||||
| 	Type     string    `json:"type,omitempty"` | ||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||
| } | ||||
|  | ||||
| type OpenAIMessageContent struct { | ||||
| 	Type     string    `json:"type,omitempty"` | ||||
| 	Text     string    `json:"text"` | ||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) IsStringContent() bool { | ||||
| 	_, ok := m.Content.(string) | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func (m Message) StringContent() string { | ||||
| 	content, ok := m.Content.(string) | ||||
| 	if ok { | ||||
| 		return content | ||||
| 	} | ||||
| 	contentList, ok := m.Content.([]any) | ||||
| 	if ok { | ||||
| 		var contentStr string | ||||
| 		for _, contentItem := range contentList { | ||||
| 			contentMap, ok := contentItem.(map[string]any) | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
| 			if contentMap["type"] == ContentTypeText { | ||||
| 				if subStr, ok := contentMap["text"].(string); ok { | ||||
| 					contentStr += subStr | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return contentStr | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m Message) ParseContent() []OpenAIMessageContent { | ||||
| 	var contentList []OpenAIMessageContent | ||||
| 	content, ok := m.Content.(string) | ||||
| 	if ok { | ||||
| 		contentList = append(contentList, OpenAIMessageContent{ | ||||
| 			Type: ContentTypeText, | ||||
| 			Text: content, | ||||
| 		}) | ||||
| 		return contentList | ||||
| 	} | ||||
| 	anyList, ok := m.Content.([]any) | ||||
| 	if ok { | ||||
| 		for _, contentItem := range anyList { | ||||
| 			contentMap, ok := contentItem.(map[string]any) | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
| 			switch contentMap["type"] { | ||||
| 			case ContentTypeText: | ||||
| 				if subStr, ok := contentMap["text"].(string); ok { | ||||
| 					contentList = append(contentList, OpenAIMessageContent{ | ||||
| 						Type: ContentTypeText, | ||||
| 						Text: subStr, | ||||
| 					}) | ||||
| 				} | ||||
| 			case ContentTypeImageURL: | ||||
| 				if subObj, ok := contentMap["image_url"].(map[string]any); ok { | ||||
| 					contentList = append(contentList, OpenAIMessageContent{ | ||||
| 						Type: ContentTypeImageURL, | ||||
| 						ImageURL: &ImageURL{ | ||||
| 							Url: subObj["url"].(string), | ||||
| 						}, | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return contentList | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type ResponseFormat struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	Messages         []Message       `json:"messages,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed             float64         `json:"seed,omitempty"` | ||||
| 	Tools            any             `json:"tools,omitempty"` | ||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
| 	if r.Input == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	var input []string | ||||
| 	switch r.Input.(type) { | ||||
| 	case string: | ||||
| 		input = []string{r.Input.(string)} | ||||
| 	case []any: | ||||
| 		input = make([]string, 0, len(r.Input.([]any))) | ||||
| 		for _, item := range r.Input.([]any) { | ||||
| 			if str, ok := item.(string); ok { | ||||
| 				input = append(input, str) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return input | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| } | ||||
|  | ||||
| type TextRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	Prompt    string    `json:"prompt"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create | ||||
| type ImageRequest struct { | ||||
| 	Model          string `json:"model"` | ||||
| 	Prompt         string `json:"prompt" binding:"required"` | ||||
| 	N              int    `json:"n,omitempty"` | ||||
| 	Size           string `json:"size,omitempty"` | ||||
| 	Quality        string `json:"quality,omitempty"` | ||||
| 	ResponseFormat string `json:"response_format,omitempty"` | ||||
| 	Style          string `json:"style,omitempty"` | ||||
| 	User           string `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| type WhisperJSONResponse struct { | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type WhisperVerboseJSONResponse struct { | ||||
| 	Task     string    `json:"task,omitempty"` | ||||
| 	Language string    `json:"language,omitempty"` | ||||
| 	Duration float64   `json:"duration,omitempty"` | ||||
| 	Text     string    `json:"text,omitempty"` | ||||
| 	Segments []Segment `json:"segments,omitempty"` | ||||
| } | ||||
|  | ||||
| type Segment struct { | ||||
| 	Id               int     `json:"id"` | ||||
| 	Seek             int     `json:"seek"` | ||||
| 	Start            float64 `json:"start"` | ||||
| 	End              float64 `json:"end"` | ||||
| 	Text             string  `json:"text"` | ||||
| 	Tokens           []int   `json:"tokens"` | ||||
| 	Temperature      float64 `json:"temperature"` | ||||
| 	AvgLogprob       float64 `json:"avg_logprob"` | ||||
| 	CompressionRatio float64 `json:"compression_ratio"` | ||||
| 	NoSpeechProb     float64 `json:"no_speech_prob"` | ||||
| } | ||||
|  | ||||
| type TextToSpeechRequest struct { | ||||
| 	Model          string  `json:"model" binding:"required"` | ||||
| 	Input          string  `json:"input" binding:"required"` | ||||
| 	Voice          string  `json:"voice" binding:"required"` | ||||
| 	Speed          float64 `json:"speed"` | ||||
| 	ResponseFormat string  `json:"response_format"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| 	TotalTokens      int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type UsageOrResponseText struct { | ||||
| 	*Usage | ||||
| 	ResponseText string | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Message string `json:"message"` | ||||
| 	Type    string `json:"type"` | ||||
| 	Param   string `json:"param"` | ||||
| 	Code    any    `json:"code"` | ||||
| } | ||||
|  | ||||
| type ErrorWithStatusCode struct { | ||||
| 	Error | ||||
| 	StatusCode int `json:"status_code"` | ||||
| } | ||||
|  | ||||
| type SlimTextResponse struct { | ||||
| 	Choices []TextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| 	Error   Error `json:"error"` | ||||
| } | ||||
|  | ||||
| type TextResponseChoice struct { | ||||
| 	Index        int `json:"index"` | ||||
| 	Message      `json:"message"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type TextResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Model   string               `json:"model,omitempty"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Choices []TextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponseItem struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Index     int       `json:"index"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Object string                  `json:"object"` | ||||
| 	Data   []EmbeddingResponseItem `json:"data"` | ||||
| 	Model  string                  `json:"model"` | ||||
| 	Usage  `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ImageResponse struct { | ||||
| 	Created int `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		Url string `json:"url"` | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponseChoice struct { | ||||
| 	Delta struct { | ||||
| 		Content string `json:"content"` | ||||
| 	} `json:"delta"` | ||||
| 	FinishReason *string `json:"finish_reason,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponse struct { | ||||
| 	Id      string                                `json:"id"` | ||||
| 	Object  string                                `json:"object"` | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
| 	Choices []struct { | ||||
| 		Text         string `json:"text"` | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 	} `json:"choices"` | ||||
| } | ||||
| @@ -1,39 +1,31 @@ | ||||
| package controller | ||||
| package openai | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/image" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| 	"math" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/image" | ||||
| 	"one-api/common/logger" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| var stopFinishReason = "stop" | ||||
| 
 | ||||
| // tokenEncoderMap won't grow after initialization | ||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||
| var defaultTokenEncoder *tiktoken.Tiktoken | ||||
| 
 | ||||
| func InitTokenEncoders() { | ||||
| 	common.SysLog("initializing token encoders") | ||||
| 	logger.SysLog("initializing token encoders") | ||||
| 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||
| 	if err != nil { | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||
| 		logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	defaultTokenEncoder = gpt35TokenEncoder | ||||
| 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||
| 	if err != nil { | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||
| 		logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	for model, _ := range common.ModelRatio { | ||||
| 		if strings.HasPrefix(model, "gpt-3.5") { | ||||
| @@ -44,7 +36,7 @@ func InitTokenEncoders() { | ||||
| 			tokenEncoderMap[model] = nil | ||||
| 		} | ||||
| 	} | ||||
| 	common.SysLog("token encoders initialized") | ||||
| 	logger.SysLog("token encoders initialized") | ||||
| } | ||||
| 
 | ||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| @@ -55,7 +47,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| 	if ok { | ||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||
| 		if err != nil { | ||||
| 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||
| 			logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||
| 			tokenEncoder = defaultTokenEncoder | ||||
| 		} | ||||
| 		tokenEncoderMap[model] = tokenEncoder | ||||
| @@ -65,13 +57,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| } | ||||
| 
 | ||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||
| 	if common.ApproximateTokenEnabled { | ||||
| 	if config.ApproximateTokenEnabled { | ||||
| 		return int(float64(len(text)) * 0.38) | ||||
| 	} | ||||
| 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||
| } | ||||
| 
 | ||||
| func countTokenMessages(messages []Message, model string) int { | ||||
| func CountTokenMessages(messages []Message, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	// Reference: | ||||
| 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
| @@ -109,7 +101,7 @@ func countTokenMessages(messages []Message, model string) int { | ||||
| 						} | ||||
| 						imageTokens, err := countImageTokens(url, detail) | ||||
| 						if err != nil { | ||||
| 							common.SysError("error counting image tokens: " + err.Error()) | ||||
| 							logger.SysError("error counting image tokens: " + err.Error()) | ||||
| 						} else { | ||||
| 							tokenNum += imageTokens | ||||
| 						} | ||||
| @@ -195,191 +187,21 @@ func countImageTokens(url string, detail string) (_ int, err error) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func countTokenInput(input any, model string) int { | ||||
| func CountTokenInput(input any, model string) int { | ||||
| 	switch v := input.(type) { | ||||
| 	case string: | ||||
| 		return countTokenText(v, model) | ||||
| 		return CountTokenText(v, model) | ||||
| 	case []string: | ||||
| 		text := "" | ||||
| 		for _, s := range v { | ||||
| 			text += s | ||||
| 		} | ||||
| 		return countTokenText(text, model) | ||||
| 		return CountTokenText(text, model) | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
| 
 | ||||
| func countTokenText(text string, model string) int { | ||||
| func CountTokenText(text string, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	return getTokenNum(tokenEncoder, text) | ||||
| } | ||||
| 
 | ||||
| func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { | ||||
| 	openAIError := OpenAIError{ | ||||
| 		Message: err.Error(), | ||||
| 		Type:    "one_api_error", | ||||
| 		Code:    code, | ||||
| 	} | ||||
| 	return &OpenAIErrorWithStatusCode{ | ||||
| 		OpenAIError: openAIError, | ||||
| 		StatusCode:  statusCode, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func shouldDisableChannel(err *OpenAIError, statusCode int) bool { | ||||
| 	if !common.AutomaticDisableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if statusCode == http.StatusUnauthorized { | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func shouldEnableChannel(err error, openAIErr *OpenAIError) bool { | ||||
| 	if !common.AutomaticEnableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if openAIErr != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| func setEventStreamHeaders(c *gin.Context) { | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| } | ||||
| 
 | ||||
| type GeneralErrorResponse struct { | ||||
| 	Error    OpenAIError `json:"error"` | ||||
| 	Message  string      `json:"message"` | ||||
| 	Msg      string      `json:"msg"` | ||||
| 	Err      string      `json:"err"` | ||||
| 	ErrorMsg string      `json:"error_msg"` | ||||
| 	Header   struct { | ||||
| 		Message string `json:"message"` | ||||
| 	} `json:"header"` | ||||
| 	Response struct { | ||||
| 		Error struct { | ||||
| 			Message string `json:"message"` | ||||
| 		} `json:"error"` | ||||
| 	} `json:"response"` | ||||
| } | ||||
| 
 | ||||
| func (e GeneralErrorResponse) ToMessage() string { | ||||
| 	if e.Error.Message != "" { | ||||
| 		return e.Error.Message | ||||
| 	} | ||||
| 	if e.Message != "" { | ||||
| 		return e.Message | ||||
| 	} | ||||
| 	if e.Msg != "" { | ||||
| 		return e.Msg | ||||
| 	} | ||||
| 	if e.Err != "" { | ||||
| 		return e.Err | ||||
| 	} | ||||
| 	if e.ErrorMsg != "" { | ||||
| 		return e.ErrorMsg | ||||
| 	} | ||||
| 	if e.Header.Message != "" { | ||||
| 		return e.Header.Message | ||||
| 	} | ||||
| 	if e.Response.Error.Message != "" { | ||||
| 		return e.Response.Error.Message | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { | ||||
| 	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ | ||||
| 		StatusCode: resp.StatusCode, | ||||
| 		OpenAIError: OpenAIError{ | ||||
| 			Message: "", | ||||
| 			Type:    "upstream_error", | ||||
| 			Code:    "bad_response_status_code", | ||||
| 			Param:   strconv.Itoa(resp.StatusCode), | ||||
| 		}, | ||||
| 	} | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	var errResponse GeneralErrorResponse | ||||
| 	err = json.Unmarshal(responseBody, &errResponse) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if errResponse.Error.Message != "" { | ||||
| 		// OpenAI format error, so we override the default one | ||||
| 		openAIErrorWithStatusCode.OpenAIError = errResponse.Error | ||||
| 	} else { | ||||
| 		openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() | ||||
| 	} | ||||
| 	if openAIErrorWithStatusCode.OpenAIError.Message == "" { | ||||
| 		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func getFullRequestURL(baseURL string, requestURL string, channelType int) string { | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 
 | ||||
| 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||
| 		switch channelType { | ||||
| 		case common.ChannelTypeOpenAI: | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) | ||||
| 		} | ||||
| 	} | ||||
| 	return fullRequestURL | ||||
| } | ||||
| 
 | ||||
| func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||
| 	// quotaDelta is remaining quota to be consumed | ||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 	if err != nil { | ||||
| 		common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	err = model.CacheUpdateUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		common.SysError("error update user quota cache: " + err.Error()) | ||||
| 	} | ||||
| 	// totalQuota is total quota consumed | ||||
| 	if totalQuota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||
| 	} | ||||
| 	if totalQuota <= 0 { | ||||
| 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func GetAPIVersion(c *gin.Context) string { | ||||
| 	query := c.Request.URL.Query() | ||||
| 	apiVersion := query.Get("api-version") | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = c.GetString("api_version") | ||||
| 	} | ||||
| 	return apiVersion | ||||
| } | ||||
							
								
								
									
										13
									
								
								relay/channel/openai/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								relay/channel/openai/util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| package openai | ||||
|  | ||||
| func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode { | ||||
| 	Error := Error{ | ||||
| 		Message: err.Error(), | ||||
| 		Type:    "one_api_error", | ||||
| 		Code:    code, | ||||
| 	} | ||||
| 	return &ErrorWithStatusCode{ | ||||
| 		Error:      Error, | ||||
| 		StatusCode: statusCode, | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/tencent/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/tencent/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package tencent | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
							
								
								
									
										234
									
								
								relay/channel/tencent/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								relay/channel/tencent/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,234 @@ | ||||
| package tencent | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/product/1729/97732 | ||||
|  | ||||
| func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} | ||||
| 		messages = append(messages, Message{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    message.Role, | ||||
| 		}) | ||||
| 	} | ||||
| 	stream := 0 | ||||
| 	if request.Stream { | ||||
| 		stream = 1 | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Timestamp:   helper.GetTimestamp(), | ||||
| 		Expired:     helper.GetTimestamp() + 24*60*60, | ||||
| 		QueryID:     helper.GetUUID(), | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Stream:      stream, | ||||
| 		Messages:    messages, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	if len(response.Choices) > 0 { | ||||
| 		choice := openai.TextResponseChoice{ | ||||
| 			Index: 0, | ||||
| 			Message: openai.Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: response.Choices[0].Messages.Content, | ||||
| 			}, | ||||
| 			FinishReason: response.Choices[0].FinishReason, | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "tencent-hunyuan", | ||||
| 	} | ||||
| 	if len(TencentResponse.Choices) > 0 { | ||||
| 		var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | ||||
| 		if TencentResponse.Choices[0].FinishReason == "stop" { | ||||
| 			choice.FinishReason = &constant.StopFinishReason | ||||
| 		} | ||||
| 		response.Choices = append(response.Choices, choice) | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||
| 	var responseText string | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var TencentResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||
| 			if len(response.Choices) != 0 { | ||||
| 				responseText += response.Choices[0].Delta.Content | ||||
| 			} | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var TencentResponse ChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if TencentResponse.Error.Code != 0 { | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: TencentResponse.Error.Message, | ||||
| 				Code:    TencentResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseTencent2OpenAI(&TencentResponse) | ||||
| 	fullTextResponse.Model = "hunyuan" | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) { | ||||
| 	parts := strings.Split(config, "|") | ||||
| 	if len(parts) != 3 { | ||||
| 		err = errors.New("invalid tencent config") | ||||
| 		return | ||||
| 	} | ||||
| 	appId, err = strconv.ParseInt(parts[0], 10, 64) | ||||
| 	secretId = parts[1] | ||||
| 	secretKey = parts[2] | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetSign(req ChatRequest, secretKey string) string { | ||||
| 	params := make([]string, 0) | ||||
| 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||
| 	params = append(params, "secret_id="+req.SecretId) | ||||
| 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) | ||||
| 	params = append(params, "query_id="+req.QueryID) | ||||
| 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) | ||||
| 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) | ||||
| 	params = append(params, "stream="+strconv.Itoa(req.Stream)) | ||||
| 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) | ||||
|  | ||||
| 	var messageStr string | ||||
| 	for _, msg := range req.Messages { | ||||
| 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) | ||||
| 	} | ||||
| 	messageStr = strings.TrimSuffix(messageStr, ",") | ||||
| 	params = append(params, "messages=["+messageStr+"]") | ||||
|  | ||||
| 	sort.Sort(sort.StringSlice(params)) | ||||
| 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") | ||||
| 	mac := hmac.New(sha1.New, []byte(secretKey)) | ||||
| 	signURL := url | ||||
| 	mac.Write([]byte(signURL)) | ||||
| 	sign := mac.Sum([]byte(nil)) | ||||
| 	return base64.StdEncoding.EncodeToString(sign) | ||||
| } | ||||
							
								
								
									
										63
									
								
								relay/channel/tencent/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								relay/channel/tencent/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| package tencent | ||||
|  | ||||
| import ( | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||
| 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||
| 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||
| 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||
| 	Timestamp int64 `json:"timestamp"` | ||||
| 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||
| 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||
| 	Expired int64  `json:"expired"` | ||||
| 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||
| 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||
| 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||
| 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||
| 	Temperature float64 `json:"temperature"` | ||||
| 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||
| 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||
| 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||
| 	TopP float64 `json:"top_p"` | ||||
| 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||
| 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||
| 	Stream int `json:"stream"` | ||||
| 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||
| 	// 输入 content 总数最大支持 3000 token。 | ||||
| 	Messages []Message `json:"messages"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type ResponseChoices struct { | ||||
| 	FinishReason string  `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||
| 	Messages     Message `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| 	Delta        Message `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Choices []ResponseChoices `json:"choices,omitempty"` // 结果 | ||||
| 	Created string            `json:"created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"id,omitempty"`      // 会话 id | ||||
| 	Usage   openai.Usage      `json:"usage,omitempty"`   // token 数量 | ||||
| 	Error   Error             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"note,omitempty"`    // 注释 | ||||
| 	ReqID   string            `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/xunfei/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/xunfei/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package xunfei | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package xunfei | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/hmac" | ||||
| @@ -12,6 +12,10 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| @@ -19,82 +23,26 @@ import ( | ||||
| // https://console.xfyun.cn/services/cbm | ||||
| // https://www.xfyun.cn/doc/spark/Web.html | ||||
| 
 | ||||
| type XunfeiMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
| 
 | ||||
| type XunfeiChatRequest struct { | ||||
| 	Header struct { | ||||
| 		AppId string `json:"app_id"` | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
| 		Message struct { | ||||
| 			Text []XunfeiMessage `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
| 
 | ||||
| type XunfeiChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| } | ||||
| 
 | ||||
| type XunfeiChatResponse struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int                          `json:"status"` | ||||
| 			Seq    int                          `json:"seq"` | ||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 		Usage struct { | ||||
| 			//Text struct { | ||||
| 			//	QuestionTokens   string `json:"question_tokens"` | ||||
| 			//	PromptTokens     string `json:"prompt_tokens"` | ||||
| 			//	CompletionTokens string `json:"completion_tokens"` | ||||
| 			//	TotalTokens      string `json:"total_tokens"` | ||||
| 			//} `json:"text"` | ||||
| 			Text Usage `json:"text"` | ||||
| 		} `json:"usage"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
| 
 | ||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { | ||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||
| func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	xunfeiRequest := XunfeiChatRequest{} | ||||
| 	xunfeiRequest := ChatRequest{} | ||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||
| 	xunfeiRequest.Parameter.Chat.Domain = domain | ||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||
| @@ -104,49 +52,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
| 
 | ||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||
| func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 		response.Payload.Choices.Text = []ChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 		Message: openai.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 		}, | ||||
| 		FinishReason: stopFinishReason, | ||||
| 		FinishReason: constant.StopFinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 		Usage:   response.Payload.Usage.Text, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 		xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	if xunfeiResponse.Payload.Choices.Status == 2 { | ||||
| 		choice.FinishReason = &stopFinishReason | ||||
| 		choice.FinishReason = &constant.StopFinishReason | ||||
| 	} | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "SparkDesk", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
| @@ -177,14 +125,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||
| 	return callUrl | ||||
| } | ||||
| 
 | ||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	setEventStreamHeaders(c) | ||||
| 	var usage Usage | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var usage openai.Usage | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case xunfeiResponse := <-dataChan: | ||||
| @@ -194,7 +142,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId | ||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| @@ -207,15 +155,15 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId | ||||
| 	return nil, &usage | ||||
| } | ||||
| 
 | ||||
| func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var usage Usage | ||||
| 	var usage openai.Usage | ||||
| 	var content string | ||||
| 	var xunfeiResponse XunfeiChatResponse | ||||
| 	var xunfeiResponse ChatResponse | ||||
| 	stop := false | ||||
| 	for !stop { | ||||
| 		select { | ||||
| @@ -231,7 +179,7 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | ||||
| 		} | ||||
| 	} | ||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 		xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| @@ -242,14 +190,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | ||||
| 	response := responseXunfei2OpenAI(&xunfeiResponse) | ||||
| 	jsonResponse, err := json.Marshal(response) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	_, _ = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
| 
 | ||||
| func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { | ||||
| func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| @@ -263,26 +211,26 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	dataChan := make(chan XunfeiChatResponse) | ||||
| 	dataChan := make(chan ChatResponse) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := conn.ReadMessage() | ||||
| 			if err != nil { | ||||
| 				common.SysError("error reading stream response: " + err.Error()) | ||||
| 				logger.SysError("error reading stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			var response XunfeiChatResponse | ||||
| 			var response ChatResponse | ||||
| 			err = json.Unmarshal(msg, &response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			dataChan <- response | ||||
| 			if response.Payload.Choices.Status == 2 { | ||||
| 				err := conn.Close() | ||||
| 				if err != nil { | ||||
| 					common.SysError("error closing websocket connection: " + err.Error()) | ||||
| 					logger.SysError("error closing websocket connection: " + err.Error()) | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| @@ -301,7 +249,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, | ||||
| 	} | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = "v1.1" | ||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||
| 		logger.SysLog("api_version not found, use default: " + apiVersion) | ||||
| 	} | ||||
| 	domain := "general" | ||||
| 	if apiVersion != "v1.1" { | ||||
							
								
								
									
										61
									
								
								relay/channel/xunfei/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								relay/channel/xunfei/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| package xunfei | ||||
|  | ||||
| import ( | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Header struct { | ||||
| 		AppId string `json:"app_id"` | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
| 		Message struct { | ||||
| 			Text []Message `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type ChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int                    `json:"status"` | ||||
| 			Seq    int                    `json:"seq"` | ||||
| 			Text   []ChatResponseTextItem `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 		Usage struct { | ||||
| 			//Text struct { | ||||
| 			//	QuestionTokens   string `json:"question_tokens"` | ||||
| 			//	PromptTokens     string `json:"prompt_tokens"` | ||||
| 			//	CompletionTokens string `json:"completion_tokens"` | ||||
| 			//	TotalTokens      string `json:"total_tokens"` | ||||
| 			//} `json:"text"` | ||||
| 			Text openai.Usage `json:"text"` | ||||
| 		} `json:"usage"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
							
								
								
									
										22
									
								
								relay/channel/zhipu/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								relay/channel/zhipu/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package zhipu | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/relay/channel/openai" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Auth(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package zhipu | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| @@ -8,6 +8,10 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @@ -18,53 +22,13 @@ import ( | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | ||||
| 
 | ||||
| type ZhipuMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
| 
 | ||||
| type ZhipuRequest struct { | ||||
| 	Prompt      []ZhipuMessage `json:"prompt"` | ||||
| 	Temperature float64        `json:"temperature,omitempty"` | ||||
| 	TopP        float64        `json:"top_p,omitempty"` | ||||
| 	RequestId   string         `json:"request_id,omitempty"` | ||||
| 	Incremental bool           `json:"incremental,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type ZhipuResponseData struct { | ||||
| 	TaskId     string         `json:"task_id"` | ||||
| 	RequestId  string         `json:"request_id"` | ||||
| 	TaskStatus string         `json:"task_status"` | ||||
| 	Choices    []ZhipuMessage `json:"choices"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
| 
 | ||||
| type ZhipuResponse struct { | ||||
| 	Code    int               `json:"code"` | ||||
| 	Msg     string            `json:"msg"` | ||||
| 	Success bool              `json:"success"` | ||||
| 	Data    ZhipuResponseData `json:"data"` | ||||
| } | ||||
| 
 | ||||
| type ZhipuStreamMetaResponse struct { | ||||
| 	RequestId  string `json:"request_id"` | ||||
| 	TaskId     string `json:"task_id"` | ||||
| 	TaskStatus string `json:"task_status"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
| 
 | ||||
| type zhipuTokenData struct { | ||||
| 	Token      string | ||||
| 	ExpiryTime time.Time | ||||
| } | ||||
| 
 | ||||
| var zhipuTokens sync.Map | ||||
| var expSeconds int64 = 24 * 3600 | ||||
| 
 | ||||
| func getZhipuToken(apikey string) string { | ||||
| func GetToken(apikey string) string { | ||||
| 	data, ok := zhipuTokens.Load(apikey) | ||||
| 	if ok { | ||||
| 		tokenData := data.(zhipuTokenData) | ||||
| 		tokenData := data.(tokenData) | ||||
| 		if time.Now().Before(tokenData.ExpiryTime) { | ||||
| 			return tokenData.Token | ||||
| 		} | ||||
| @@ -72,7 +36,7 @@ func getZhipuToken(apikey string) string { | ||||
| 
 | ||||
| 	split := strings.Split(apikey, ".") | ||||
| 	if len(split) != 2 { | ||||
| 		common.SysError("invalid zhipu key: " + apikey) | ||||
| 		logger.SysError("invalid zhipu key: " + apikey) | ||||
| 		return "" | ||||
| 	} | ||||
| 
 | ||||
| @@ -100,7 +64,7 @@ func getZhipuToken(apikey string) string { | ||||
| 		return "" | ||||
| 	} | ||||
| 
 | ||||
| 	zhipuTokens.Store(apikey, zhipuTokenData{ | ||||
| 	zhipuTokens.Store(apikey, tokenData{ | ||||
| 		Token:      tokenString, | ||||
| 		ExpiryTime: expiryTime, | ||||
| 	}) | ||||
| @@ -108,26 +72,26 @@ func getZhipuToken(apikey string) string { | ||||
| 	return tokenString | ||||
| } | ||||
| 
 | ||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||
| func ConvertRequest(request openai.GeneralOpenAIRequest) *Request { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ZhipuRequest{ | ||||
| 	return &Request{ | ||||
| 		Prompt:      messages, | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| @@ -135,18 +99,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| func responseZhipu2OpenAI(response *Response) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.Data.TaskId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), | ||||
| 		Usage:   response.Data.Usage, | ||||
| 	} | ||||
| 	for i, choice := range response.Data.Choices { | ||||
| 		openaiChoice := OpenAITextResponseChoice{ | ||||
| 		openaiChoice := openai.TextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 			Message: openai.Message{ | ||||
| 				Role:    choice.Role, | ||||
| 				Content: strings.Trim(choice.Content, "\""), | ||||
| 			}, | ||||
| @@ -160,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | ||||
| 	return &fullTextResponse | ||||
| } | ||||
| 
 | ||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = zhipuResponse | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
| 
 | ||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = "" | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 	choice.FinishReason = &constant.StopFinishReason | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      zhipuResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response, &zhipuResponse.Usage | ||||
| } | ||||
| 
 | ||||
| func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage *Usage | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var usage *openai.Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| @@ -224,29 +188,29 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			response := streamResponseZhipu2OpenAI(data) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case data := <-metaChan: | ||||
| 			var zhipuResponse ZhipuStreamMetaResponse | ||||
| 			var zhipuResponse StreamMetaResponse | ||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage = zhipuUsage | ||||
| @@ -259,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, usage | ||||
| } | ||||
| 
 | ||||
| func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var zhipuResponse ZhipuResponse | ||||
| func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||
| 	var zhipuResponse Response | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if !zhipuResponse.Success { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 		return &openai.ErrorWithStatusCode{ | ||||
| 			Error: openai.Error{ | ||||
| 				Message: zhipuResponse.Msg, | ||||
| 				Type:    "zhipu_error", | ||||
| 				Param:   "", | ||||
| @@ -290,9 +254,10 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) | ||||
| 	fullTextResponse.Model = "chatglm" | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
							
								
								
									
										46
									
								
								relay/channel/zhipu/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								relay/channel/zhipu/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | ||||
| package zhipu | ||||
|  | ||||
| import ( | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	Prompt      []Message `json:"prompt"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	RequestId   string    `json:"request_id,omitempty"` | ||||
| 	Incremental bool      `json:"incremental,omitempty"` | ||||
| } | ||||
|  | ||||
| type ResponseData struct { | ||||
| 	TaskId       string    `json:"task_id"` | ||||
| 	RequestId    string    `json:"request_id"` | ||||
| 	TaskStatus   string    `json:"task_status"` | ||||
| 	Choices      []Message `json:"choices"` | ||||
| 	openai.Usage `json:"usage"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Code    int          `json:"code"` | ||||
| 	Msg     string       `json:"msg"` | ||||
| 	Success bool         `json:"success"` | ||||
| 	Data    ResponseData `json:"data"` | ||||
| } | ||||
|  | ||||
| type StreamMetaResponse struct { | ||||
| 	RequestId    string `json:"request_id"` | ||||
| 	TaskId       string `json:"task_id"` | ||||
| 	TaskStatus   string `json:"task_status"` | ||||
| 	openai.Usage `json:"usage"` | ||||
| } | ||||
|  | ||||
| type tokenData struct { | ||||
| 	Token      string | ||||
| 	ExpiryTime time.Time | ||||
| } | ||||
							
								
								
									
										69
									
								
								relay/constant/api_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								relay/constant/api_type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| package constant | ||||
|  | ||||
| import ( | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	APITypeOpenAI = iota | ||||
| 	APITypeClaude | ||||
| 	APITypePaLM | ||||
| 	APITypeBaidu | ||||
| 	APITypeZhipu | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| 	APITypeAIProxyLibrary | ||||
| 	APITypeTencent | ||||
| 	APITypeGemini | ||||
| ) | ||||
|  | ||||
| func ChannelType2APIType(channelType int) int { | ||||
| 	apiType := APITypeOpenAI | ||||
| 	switch channelType { | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		apiType = APITypeClaude | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		apiType = APITypeBaidu | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		apiType = APITypePaLM | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		apiType = APITypeZhipu | ||||
| 	case common.ChannelTypeAli: | ||||
| 		apiType = APITypeAli | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		apiType = APITypeXunfei | ||||
| 	case common.ChannelTypeAIProxyLibrary: | ||||
| 		apiType = APITypeAIProxyLibrary | ||||
| 	case common.ChannelTypeTencent: | ||||
| 		apiType = APITypeTencent | ||||
| 	case common.ChannelTypeGemini: | ||||
| 		apiType = APITypeGemini | ||||
| 	} | ||||
| 	return apiType | ||||
| } | ||||
|  | ||||
| //func GetAdaptor(apiType int) channel.Adaptor { | ||||
| //	switch apiType { | ||||
| //	case APITypeOpenAI: | ||||
| //		return &openai.Adaptor{} | ||||
| //	case APITypeClaude: | ||||
| //		return &anthropic.Adaptor{} | ||||
| //	case APITypePaLM: | ||||
| //		return &google.Adaptor{} | ||||
| //	case APITypeZhipu: | ||||
| //		return &baidu.Adaptor{} | ||||
| //	case APITypeBaidu: | ||||
| //		return &baidu.Adaptor{} | ||||
| //	case APITypeAli: | ||||
| //		return &ali.Adaptor{} | ||||
| //	case APITypeXunfei: | ||||
| //		return &xunfei.Adaptor{} | ||||
| //	case APITypeAIProxyLibrary: | ||||
| //		return &aiproxy.Adaptor{} | ||||
| //	case APITypeTencent: | ||||
| //		return &tencent.Adaptor{} | ||||
| //	case APITypeGemini: | ||||
| //		return &google.Adaptor{} | ||||
| //	} | ||||
| //	return nil | ||||
| //} | ||||
							
								
								
									
										3
									
								
								relay/constant/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								relay/constant/common.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| package constant | ||||
|  | ||||
| var StopFinishReason = "stop" | ||||
							
								
								
									
										42
									
								
								relay/constant/relay_mode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								relay/constant/relay_mode.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| package constant | ||||
|  | ||||
| import "strings" | ||||
|  | ||||
| const ( | ||||
| 	RelayModeUnknown = iota | ||||
| 	RelayModeChatCompletions | ||||
| 	RelayModeCompletions | ||||
| 	RelayModeEmbeddings | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeEdits | ||||
| 	RelayModeAudioSpeech | ||||
| 	RelayModeAudioTranscription | ||||
| 	RelayModeAudioTranslation | ||||
| ) | ||||
|  | ||||
| func Path2RelayMode(path string) int { | ||||
| 	relayMode := RelayModeUnknown | ||||
| 	if strings.HasPrefix(path, "/v1/chat/completions") { | ||||
| 		relayMode = RelayModeChatCompletions | ||||
| 	} else if strings.HasPrefix(path, "/v1/completions") { | ||||
| 		relayMode = RelayModeCompletions | ||||
| 	} else if strings.HasPrefix(path, "/v1/embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasSuffix(path, "embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasPrefix(path, "/v1/moderations") { | ||||
| 		relayMode = RelayModeModerations | ||||
| 	} else if strings.HasPrefix(path, "/v1/images/generations") { | ||||
| 		relayMode = RelayModeImagesGenerations | ||||
| 	} else if strings.HasPrefix(path, "/v1/edits") { | ||||
| 		relayMode = RelayModeEdits | ||||
| 	} else if strings.HasPrefix(path, "/v1/audio/speech") { | ||||
| 		relayMode = RelayModeAudioSpeech | ||||
| 	} else if strings.HasPrefix(path, "/v1/audio/transcriptions") { | ||||
| 		relayMode = RelayModeAudioTranscription | ||||
| 	} else if strings.HasPrefix(path, "/v1/audio/translations") { | ||||
| 		relayMode = RelayModeAudioTranslation | ||||
| 	} | ||||
| 	return relayMode | ||||
| } | ||||
| @@ -11,11 +11,16 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"one-api/relay/util" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { | ||||
| 	audioModel := "whisper-1" | ||||
| 
 | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| @@ -25,18 +30,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	group := c.GetString("group") | ||||
| 	tokenName := c.GetString("token_name") | ||||
| 
 | ||||
| 	var ttsRequest TextToSpeechRequest | ||||
| 	if relayMode == RelayModeAudioSpeech { | ||||
| 	var ttsRequest openai.TextToSpeechRequest | ||||
| 	if relayMode == constant.RelayModeAudioSpeech { | ||||
| 		// Read JSON | ||||
| 		err := common.UnmarshalBodyReusable(c, &ttsRequest) | ||||
| 		// Check if JSON is valid | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "invalid_json", http.StatusBadRequest) | ||||
| 			return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest) | ||||
| 		} | ||||
| 		audioModel = ttsRequest.Model | ||||
| 		// Check if text is too long 4096 | ||||
| 		if len(ttsRequest.Input) > 4096 { | ||||
| 			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) | ||||
| 			return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| @@ -46,24 +51,24 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	var quota int | ||||
| 	var preConsumedQuota int | ||||
| 	switch relayMode { | ||||
| 	case RelayModeAudioSpeech: | ||||
| 	case constant.RelayModeAudioSpeech: | ||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | ||||
| 		quota = preConsumedQuota | ||||
| 	default: | ||||
| 		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) | ||||
| 		preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) | ||||
| 	} | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	// Check if user quota is enough | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota > 100*preConsumedQuota { | ||||
| 		// in this case, we do not pre-consume quota | ||||
| @@ -73,7 +78,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	if preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| @@ -83,7 +88,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[audioModel] != "" { | ||||
| 			audioModel = modelMap[audioModel] | ||||
| @@ -96,27 +101,27 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 
 | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 		apiVersion := GetAPIVersion(c) | ||||
| 		apiVersion := util.GetAzureAPIVersion(c) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | ||||
| 	} | ||||
| 
 | ||||
| 	requestBody := &bytes.Buffer{} | ||||
| 	_, err = io.Copy(requestBody, c.Request.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) | ||||
| 	responseFormat := c.DefaultPostForm("response_format", "json") | ||||
| 
 | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 	if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| @@ -128,34 +133,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 
 | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	resp, err := util.HTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	if relayMode != RelayModeAudioSpeech { | ||||
| 	if relayMode != constant.RelayModeAudioSpeech { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 
 | ||||
| 		var openAIErr TextResponse | ||||
| 		var openAIErr openai.SlimTextResponse | ||||
| 		if err = json.Unmarshal(responseBody, &openAIErr); err == nil { | ||||
| 			if openAIErr.Error.Message != "" { | ||||
| 				return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) | ||||
| 				return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| @@ -172,12 +177,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 		case "vtt": | ||||
| 			text, err = getTextFromVTT(responseBody) | ||||
| 		default: | ||||
| 			return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		quota = countTokenText(text, audioModel) | ||||
| 		quota = openai.CountTokenText(text, audioModel) | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| @@ -188,16 +193,16 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 					// negative means add quota back for token & user | ||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||
| 					if err != nil { | ||||
| 						common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) | ||||
| 						logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) | ||||
| 					} | ||||
| 				}() | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		return relayErrorHandler(resp) | ||||
| 		return util.RelayErrorHandler(resp) | ||||
| 	} | ||||
| 	quotaDelta := quota - preConsumedQuota | ||||
| 	defer func(ctx context.Context) { | ||||
| 		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | ||||
| 		go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | ||||
| 	}(c.Request.Context()) | ||||
| 
 | ||||
| 	for k, v := range resp.Header { | ||||
| @@ -207,11 +212,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 
 | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -221,7 +226,7 @@ func getTextFromVTT(body []byte) (string, error) { | ||||
| } | ||||
| 
 | ||||
| func getTextFromVerboseJSON(body []byte) (string, error) { | ||||
| 	var whisperResponse WhisperVerboseJSONResponse | ||||
| 	var whisperResponse openai.WhisperVerboseJSONResponse | ||||
| 	if err := json.Unmarshal(body, &whisperResponse); err != nil { | ||||
| 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) | ||||
| 	} | ||||
| @@ -254,7 +259,7 @@ func getTextFromText(body []byte) (string, error) { | ||||
| } | ||||
| 
 | ||||
| func getTextFromJSON(body []byte) (string, error) { | ||||
| 	var whisperResponse WhisperJSONResponse | ||||
| 	var whisperResponse openai.WhisperJSONResponse | ||||
| 	if err := json.Unmarshal(body, &whisperResponse); err != nil { | ||||
| 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) | ||||
| 	} | ||||
| @@ -9,7 +9,10 @@ import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/util" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -25,7 +28,7 @@ func isWithinRange(element string, value int) bool { | ||||
| 	return value >= min && value <= max | ||||
| } | ||||
| 
 | ||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { | ||||
| 	imageModel := "dall-e-2" | ||||
| 	imageSize := "1024x1024" | ||||
| 
 | ||||
| @@ -35,10 +38,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
| 
 | ||||
| 	var imageRequest ImageRequest | ||||
| 	var imageRequest openai.ImageRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
| 
 | ||||
| 	if imageRequest.N == 0 { | ||||
| @@ -67,24 +70,24 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
| 
 | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 
 | ||||
| 	// Check prompt length | ||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | ||||
| 		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
| 
 | ||||
| 	// Number of generated images validation | ||||
| 	if isWithinRange(imageModel, imageRequest.N) == false { | ||||
| 		// channel not azure | ||||
| 		if channelType != common.ChannelTypeAzure { | ||||
| 			return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| @@ -95,7 +98,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[imageModel] != "" { | ||||
| 			imageModel = modelMap[imageModel] | ||||
| @@ -107,10 +110,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||
| 		apiVersion := GetAPIVersion(c) | ||||
| 		apiVersion := util.GetAzureAPIVersion(c) | ||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | ||||
| 	} | ||||
| @@ -119,7 +122,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| @@ -134,12 +137,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||
| 
 | ||||
| 	if userQuota-quota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 
 | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	token := c.Request.Header.Get("Authorization") | ||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | ||||
| @@ -152,29 +155,32 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 
 | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	resp, err := util.HTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var textResponse ImageResponse | ||||
| 	var textResponse openai.ImageResponse | ||||
| 
 | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			return | ||||
| 		} | ||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			logger.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 		} | ||||
| 		err = model.CacheUpdateUserQuota(userId) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString("token_name") | ||||
| @@ -189,15 +195,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| @@ -209,11 +215,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 
 | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										173
									
								
								relay/controller/text.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								relay/controller/text.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,173 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/config" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/constant" | ||||
| 	"one-api/relay/util" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { | ||||
| 	ctx := c.Request.Context() | ||||
| 	meta := util.GetRelayMeta(c) | ||||
| 	var textRequest openai.GeneralOpenAIRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if relayMode == constant.RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| 	} | ||||
| 	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { | ||||
| 		textRequest.Model = c.Param("model") | ||||
| 	} | ||||
| 	err = util.ValidateTextRequest(&textRequest, relayMode) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) | ||||
| 	} | ||||
| 	var isModelMapped bool | ||||
| 	textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) | ||||
| 	apiType := constant.ChannelType2APIType(meta.ChannelType) | ||||
| 	fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest) | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error())) | ||||
| 		return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var promptTokens int | ||||
| 	var completionTokens int | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeChatCompletions: | ||||
| 		promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model) | ||||
| 	case constant.RelayModeCompletions: | ||||
| 		promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model) | ||||
| 	case constant.RelayModeModerations: | ||||
| 		promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model) | ||||
| 	} | ||||
| 	preConsumedTokens := config.PreConsumedQuota | ||||
| 	if textRequest.MaxTokens != 0 { | ||||
| 		preConsumedTokens = promptTokens + textRequest.MaxTokens | ||||
| 	} | ||||
| 	modelRatio := common.GetModelRatio(textRequest.Model) | ||||
| 	groupRatio := common.GetGroupRatio(meta.Group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||
| 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota > 100*preConsumedQuota { | ||||
| 		// in this case, we do not pre-consume quota | ||||
| 		// because the user has enough quota | ||||
| 		preConsumedQuota = 0 | ||||
| 		logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota)) | ||||
| 	} | ||||
| 	if preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 	} | ||||
| 	requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var req *http.Request | ||||
| 	var resp *http.Response | ||||
| 	isStream := textRequest.Stream | ||||
|  | ||||
| 	if apiType != constant.APITypeXunfei { // cause xunfei use websocket | ||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		SetupRequestHeaders(c, req, apiType, meta, isStream) | ||||
| 		resp, err = util.HTTPClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = req.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = c.Request.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
|  | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | ||||
| 			return util.RelayErrorHandler(resp) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var respErr *openai.ErrorWithStatusCode | ||||
| 	var usage *openai.Usage | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		// Why we use defer here? Because if error happened, we will have to return the pre-consumed quota. | ||||
| 		if respErr != nil { | ||||
| 			logger.Errorf(ctx, "respErr is not nil: %+v", respErr) | ||||
| 			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | ||||
| 			return | ||||
| 		} | ||||
| 		if usage == nil { | ||||
| 			logger.Error(ctx, "usage is nil, which is unexpected") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		go func() { | ||||
| 			quota := 0 | ||||
| 			completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 			promptTokens = usage.PromptTokens | ||||
| 			completionTokens = usage.CompletionTokens | ||||
| 			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
| 			if ratio != 0 && quota <= 0 { | ||||
| 				quota = 1 | ||||
| 			} | ||||
| 			totalTokens := promptTokens + completionTokens | ||||
| 			if totalTokens == 0 { | ||||
| 				// in this case, must be some error happened | ||||
| 				// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 				quota = 0 | ||||
| 			} | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				logger.Error(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(meta.UserId) | ||||
| 			if err != nil { | ||||
| 				logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 				model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| 			} | ||||
| 		}() | ||||
| 	}(ctx) | ||||
| 	usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens) | ||||
| 	if respErr != nil { | ||||
| 		return respErr | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										337
									
								
								relay/controller/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										337
									
								
								relay/controller/util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,337 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/common/helper" | ||||
| 	"one-api/relay/channel/aiproxy" | ||||
| 	"one-api/relay/channel/ali" | ||||
| 	"one-api/relay/channel/anthropic" | ||||
| 	"one-api/relay/channel/baidu" | ||||
| 	"one-api/relay/channel/google" | ||||
| 	"one-api/relay/channel/openai" | ||||
| 	"one-api/relay/channel/tencent" | ||||
| 	"one-api/relay/channel/xunfei" | ||||
| 	"one-api/relay/channel/zhipu" | ||||
| 	"one-api/relay/constant" | ||||
| 	"one-api/relay/util" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) { | ||||
| 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) | ||||
| 	switch apiType { | ||||
| 	case constant.APITypeOpenAI: | ||||
| 		if meta.ChannelType == common.ChannelTypeAzure { | ||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 			requestURL := strings.Split(requestURL, "?")[0] | ||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | ||||
| 			task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 			model_ := textRequest.Model | ||||
| 			model_ = strings.Replace(model_, ".", "", -1) | ||||
| 			// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
|  | ||||
| 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||
| 			fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) | ||||
| 		} | ||||
| 	case constant.APITypeClaude: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL) | ||||
| 	case constant.APITypeBaidu: | ||||
| 		switch textRequest.Model { | ||||
| 		case "ERNIE-Bot": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "ERNIE-Bot-4": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 		} | ||||
| 		var accessToken string | ||||
| 		var err error | ||||
| 		if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil { | ||||
| 			return "", fmt.Errorf("failed to get baidu access token: %w", err) | ||||
| 		} | ||||
| 		fullRequestURL += "?access_token=" + accessToken | ||||
| 	case constant.APITypePaLM: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL) | ||||
| 	case constant.APITypeGemini: | ||||
| 		version := helper.AssignOrDefault(meta.APIVersion, "v1") | ||||
| 		action := "generateContent" | ||||
| 		if textRequest.Stream { | ||||
| 			action = "streamGenerateContent" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action) | ||||
| 	case constant.APITypeZhipu: | ||||
| 		method := "invoke" | ||||
| 		if textRequest.Stream { | ||||
| 			method = "sse-invoke" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||
| 	case constant.APITypeAli: | ||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||
| 		if relayMode == constant.RelayModeEmbeddings { | ||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||
| 		} | ||||
| 	case constant.APITypeTencent: | ||||
| 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" | ||||
| 	case constant.APITypeAIProxyLibrary: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL) | ||||
| 	} | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|  | ||||
| func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) { | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(textRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
| 	switch apiType { | ||||
| 	case constant.APITypeClaude: | ||||
| 		claudeRequest := anthropic.ConvertRequest(textRequest) | ||||
| 		jsonStr, err := json.Marshal(claudeRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case constant.APITypeBaidu: | ||||
| 		var jsonData []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case constant.RelayModeEmbeddings: | ||||
| 			baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||
| 		default: | ||||
| 			baiduRequest := baidu.ConvertRequest(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	case constant.APITypePaLM: | ||||
| 		palmRequest := google.ConvertPaLMRequest(textRequest) | ||||
| 		jsonStr, err := json.Marshal(palmRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case constant.APITypeGemini: | ||||
| 		geminiChatRequest := google.ConvertGeminiRequest(textRequest) | ||||
| 		jsonStr, err := json.Marshal(geminiChatRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case constant.APITypeZhipu: | ||||
| 		zhipuRequest := zhipu.ConvertRequest(textRequest) | ||||
| 		jsonStr, err := json.Marshal(zhipuRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case constant.APITypeAli: | ||||
| 		var jsonStr []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case constant.RelayModeEmbeddings: | ||||
| 			aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest) | ||||
| 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||
| 		default: | ||||
| 			aliRequest := ali.ConvertRequest(textRequest) | ||||
| 			jsonStr, err = json.Marshal(aliRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case constant.APITypeTencent: | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		appId, secretId, secretKey, err := tencent.ParseConfig(apiKey) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tencentRequest := tencent.ConvertRequest(textRequest) | ||||
| 		tencentRequest.AppId = appId | ||||
| 		tencentRequest.SecretId = secretId | ||||
| 		jsonStr, err := json.Marshal(tencentRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sign := tencent.GetSign(*tencentRequest, secretKey) | ||||
| 		c.Request.Header.Set("Authorization", sign) | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case constant.APITypeAIProxyLibrary: | ||||
| 		aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) | ||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||
| 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
| 	return requestBody, nil | ||||
| } | ||||
|  | ||||
| func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) { | ||||
| 	SetupAuthHeaders(c, req, apiType, meta, isStream) | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 	if isStream && c.Request.Header.Get("Accept") == "" { | ||||
| 		req.Header.Set("Accept", "text/event-stream") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) { | ||||
| 	apiKey := meta.APIKey | ||||
| 	switch apiType { | ||||
| 	case constant.APITypeOpenAI: | ||||
| 		if meta.ChannelType == common.ChannelTypeAzure { | ||||
| 			req.Header.Set("api-key", apiKey) | ||||
| 		} else { | ||||
| 			req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 			if meta.ChannelType == common.ChannelTypeOpenRouter { | ||||
| 				req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") | ||||
| 				req.Header.Set("X-Title", "One API") | ||||
| 			} | ||||
| 		} | ||||
| 	case constant.APITypeClaude: | ||||
| 		req.Header.Set("x-api-key", apiKey) | ||||
| 		anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 		if anthropicVersion == "" { | ||||
| 			anthropicVersion = "2023-06-01" | ||||
| 		} | ||||
| 		req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 	case constant.APITypeZhipu: | ||||
| 		token := zhipu.GetToken(apiKey) | ||||
| 		req.Header.Set("Authorization", token) | ||||
| 	case constant.APITypeAli: | ||||
| 		req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 		if isStream { | ||||
| 			req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 		} | ||||
| 		if c.GetString("plugin") != "" { | ||||
| 			req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) | ||||
| 		} | ||||
| 	case constant.APITypeTencent: | ||||
| 		req.Header.Set("Authorization", apiKey) | ||||
| 	case constant.APITypePaLM: | ||||
| 		req.Header.Set("x-goog-api-key", apiKey) | ||||
| 	case constant.APITypeGemini: | ||||
| 		req.Header.Set("x-goog-api-key", apiKey) | ||||
| 	default: | ||||
| 		req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) { | ||||
| 	var responseText string | ||||
| 	switch apiType { | ||||
| 	case constant.APITypeOpenAI: | ||||
| 		if isStream { | ||||
| 			err, responseText = openai.StreamHandler(c, resp, relayMode) | ||||
| 		} else { | ||||
| 			err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model) | ||||
| 		} | ||||
| 	case constant.APITypeClaude: | ||||
| 		if isStream { | ||||
| 			err, responseText = anthropic.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model) | ||||
| 		} | ||||
| 	case constant.APITypeBaidu: | ||||
| 		if isStream { | ||||
| 			err, usage = baidu.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			switch relayMode { | ||||
| 			case constant.RelayModeEmbeddings: | ||||
| 				err, usage = baidu.EmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = baidu.Handler(c, resp) | ||||
| 			} | ||||
| 		} | ||||
| 	case constant.APITypePaLM: | ||||
| 		if isStream { // PaLM2 API does not support stream | ||||
| 			err, responseText = google.PaLMStreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 		} | ||||
| 	case constant.APITypeGemini: | ||||
| 		if isStream { | ||||
| 			err, responseText = google.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 		} | ||||
| 	case constant.APITypeZhipu: | ||||
| 		if isStream { | ||||
| 			err, usage = zhipu.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = zhipu.Handler(c, resp) | ||||
| 		} | ||||
| 	case constant.APITypeAli: | ||||
| 		if isStream { | ||||
| 			err, usage = ali.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			switch relayMode { | ||||
| 			case constant.RelayModeEmbeddings: | ||||
| 				err, usage = ali.EmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = ali.Handler(c, resp) | ||||
| 			} | ||||
| 		} | ||||
| 	case constant.APITypeXunfei: | ||||
| 		auth := c.Request.Header.Get("Authorization") | ||||
| 		auth = strings.TrimPrefix(auth, "Bearer ") | ||||
| 		splits := strings.Split(auth, "|") | ||||
| 		if len(splits) != 3 { | ||||
| 			return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||
| 		} | ||||
| 		if isStream { | ||||
| 			err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2]) | ||||
| 		} else { | ||||
| 			err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2]) | ||||
| 		} | ||||
| 	case constant.APITypeAIProxyLibrary: | ||||
| 		if isStream { | ||||
| 			err, usage = aiproxy.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = aiproxy.Handler(c, resp) | ||||
| 		} | ||||
| 	case constant.APITypeTencent: | ||||
| 		if isStream { | ||||
| 			err, responseText = tencent.StreamHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = tencent.Handler(c, resp) | ||||
| 		} | ||||
| 	default: | ||||
| 		return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if usage == nil && responseText != "" { | ||||
| 		usage = &openai.Usage{} | ||||
| 		usage.PromptTokens = promptTokens | ||||
| 		usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) | ||||
| 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 	} | ||||
| 	return usage, nil | ||||
| } | ||||
							
								
								
									
										19
									
								
								relay/util/billing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								relay/util/billing.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package util | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"one-api/common/logger" | ||||
| 	"one-api/model" | ||||
| ) | ||||
|  | ||||
| func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { | ||||
| 	if preConsumedQuota != 0 { | ||||
| 		go func(ctx context.Context) { | ||||
| 			// return pre-consumed quota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||
| 			if err != nil { | ||||
| 				logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 			} | ||||
| 		}(ctx) | ||||
| 	} | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user