Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-10-31 05:43:42 +08:00)

Compare commits: v0.6.0 ... v0.6.2-alp (24 commits)
| SHA1 |
|---|
| 12440874b0 |
| 6ebc99460e |
| 27ad8bfb98 |
| 8388aa537f |
| 2346bf70af |
| f05b403ca5 |
| b33616df44 |
| cf16f44970 |
| bf2e26a48f |
| 4fb22ad4ce |
| 95cfb8e8c9 |
| c6ace985c2 |
| 10a926b8f3 |
| 2df877a352 |
| 9d8967f7d3 |
| b35f3523d3 |
| 82e916b5ff |
| de18d6fe16 |
| 1d0b7fb5ae |
| f9490bb72e |
| 76467285e8 |
| df1fd9aa81 |
| 614c2e0442 |
| eac6a0b9aa |
**.github/workflows/linux-release.yml** (2 changes, vendored)

```diff
@@ -38,7 +38,7 @@ jobs:
       - name: Build Backend (amd64)
         run: |
           go mod download
-          go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
+          go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api

       - name: Build Backend (arm64)
         run: |
```
**.github/workflows/macos-release.yml** (2 changes, vendored)

```diff
@@ -38,7 +38,7 @@ jobs:
       - name: Build Backend
         run: |
           go mod download
-          go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
+          go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos
       - name: Release
         uses: softprops/action-gh-release@v1
         if: startsWith(github.ref, 'refs/tags/')
```
**.github/workflows/windows-release.yml** (2 changes, vendored)

```diff
@@ -41,7 +41,7 @@ jobs:
       - name: Build Backend
         run: |
           go mod download
-          go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
+          go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe
       - name: Release
         uses: softprops/action-gh-release@v1
         if: startsWith(github.ref, 'refs/tags/')
```
**README.md**

```diff
@@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
   + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
   + [x] [Anthropic Claude 系列模型](https://anthropic.com)
   + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+  + [x] [Mistral 系列模型](https://mistral.ai/)
   + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
@@ -74,8 +75,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
   + [x] [360 智脑](https://ai.360.cn)
   + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
   + [x] [Moonshot AI](https://platform.moonshot.cn/)
+  + [x] [百川大模型](https://platform.baichuan-ai.com)
   + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
-  + [ ] [MINIMAX](https://api.minimax.chat/) (WIP)
+  + [x] [MINIMAX](https://api.minimax.chat/)
+  + [x] [Groq](https://wow.groq.com/)
 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
 3. 支持通过**负载均衡**的方式访问多个渠道。
 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -372,6 +375,9 @@ graph LR
 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
+19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
+20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
+21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。

 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
```
**common/blacklist/main.go** (new file, 29 lines)

```diff
@@ -0,0 +1,29 @@
+package blacklist
+
+import (
+	"fmt"
+	"sync"
+)
+
+var blackList sync.Map
+
+func init() {
+	blackList = sync.Map{}
+}
+
+func userId2Key(id int) string {
+	return fmt.Sprintf("userid_%d", id)
+}
+
+func BanUser(id int) {
+	blackList.Store(userId2Key(id), true)
+}
+
+func UnbanUser(id int) {
+	blackList.Delete(userId2Key(id))
+}
+
+func IsUserBanned(id int) bool {
+	_, ok := blackList.Load(userId2Key(id))
+	return ok
+}
```
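The new package keeps banned user IDs in an in-process `sync.Map`, so the auth middlewares can check bans without a database query. A minimal usage sketch of the API added above (the import path is taken from this diff; the `main` wrapper is only for illustration):

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/blacklist"
)

func main() {
	blacklist.BanUser(42)
	fmt.Println(blacklist.IsUserBanned(42)) // true

	blacklist.UnbanUser(42)
	fmt.Println(blacklist.IsUserBanned(42)) // false
}
```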
**common/config/config.go**

```diff
@@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{
 }

 var DebugEnabled = os.Getenv("DEBUG") == "true"
+var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
 var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"

 var LogConsumeEnabled = true
@@ -125,3 +126,9 @@ var (
 )

 var RateLimitKeyExpirationDuration = 20 * time.Minute
+
+var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
+var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
+var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
+var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
+var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
```
**common/constants.go**

```diff
@@ -15,6 +15,7 @@ const (
 const (
 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
 	UserStatusDisabled = 2 // also don't use 0
+	UserStatusDeleted  = 3
 )

 const (
@@ -38,32 +39,38 @@ const (
 )

 const (
-	ChannelTypeUnknown        = 0
-	ChannelTypeOpenAI         = 1
-	ChannelTypeAPI2D          = 2
-	ChannelTypeAzure          = 3
-	ChannelTypeCloseAI        = 4
-	ChannelTypeOpenAISB       = 5
-	ChannelTypeOpenAIMax      = 6
-	ChannelTypeOhMyGPT        = 7
-	ChannelTypeCustom         = 8
-	ChannelTypeAILS           = 9
-	ChannelTypeAIProxy        = 10
-	ChannelTypePaLM           = 11
-	ChannelTypeAPI2GPT        = 12
-	ChannelTypeAIGC2D         = 13
-	ChannelTypeAnthropic      = 14
-	ChannelTypeBaidu          = 15
-	ChannelTypeZhipu          = 16
-	ChannelTypeAli            = 17
-	ChannelTypeXunfei         = 18
-	ChannelType360            = 19
-	ChannelTypeOpenRouter     = 20
-	ChannelTypeAIProxyLibrary = 21
-	ChannelTypeFastGPT        = 22
-	ChannelTypeTencent        = 23
-	ChannelTypeGemini         = 24
-	ChannelTypeMoonshot       = 25
+	ChannelTypeUnknown = iota
+	ChannelTypeOpenAI
+	ChannelTypeAPI2D
+	ChannelTypeAzure
+	ChannelTypeCloseAI
+	ChannelTypeOpenAISB
+	ChannelTypeOpenAIMax
+	ChannelTypeOhMyGPT
+	ChannelTypeCustom
+	ChannelTypeAILS
+	ChannelTypeAIProxy
+	ChannelTypePaLM
+	ChannelTypeAPI2GPT
+	ChannelTypeAIGC2D
+	ChannelTypeAnthropic
+	ChannelTypeBaidu
+	ChannelTypeZhipu
+	ChannelTypeAli
+	ChannelTypeXunfei
+	ChannelType360
+	ChannelTypeOpenRouter
+	ChannelTypeAIProxyLibrary
+	ChannelTypeFastGPT
+	ChannelTypeTencent
+	ChannelTypeGemini
+	ChannelTypeMoonshot
+	ChannelTypeBaichuan
+	ChannelTypeMinimax
+	ChannelTypeMistral
+	ChannelTypeGroq
+
+	ChannelTypeDummy
 )

 var ChannelBaseURLs = []string{
@@ -93,6 +100,10 @@ var ChannelBaseURLs = []string{
 	"https://hunyuan.cloud.tencent.com",         // 23
 	"https://generativelanguage.googleapis.com", // 24
 	"https://api.moonshot.cn",                   // 25
+	"https://api.baichuan-ai.com",               // 26
+	"https://api.minimax.chat",                  // 27
+	"https://api.mistral.ai",                    // 28
+	"https://api.groq.com/openai",               // 29
 }

 const (
```
**common/helper/helper.go**

```diff
@@ -195,6 +195,13 @@ func Max(a int, b int) int {
 	}
 }

+func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	return os.Getenv(env) == "true"
+}
+
 func GetOrDefaultEnvInt(env string, defaultValue int) int {
 	if env == "" || os.Getenv(env) == "" {
 		return defaultValue
@@ -207,6 +214,18 @@ func GetOrDefaultEnvInt(env string, defaultValue int) int {
 	return num
 }

+func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	num, err := strconv.ParseFloat(os.Getenv(env), 64)
+	if err != nil {
+		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
+		return defaultValue
+	}
+	return num
+}
+
 func GetOrDefaultEnvString(env string, defaultValue string) string {
 	if env == "" || os.Getenv(env) == "" {
 		return defaultValue
```
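A short sketch of how these `GetOrDefaultEnv*` helpers behave (the env var names are the ones documented in the README diff above; the expected results follow directly from the code in this hunk):

```go
package main

import (
	"fmt"
	"os"

	"github.com/songquanpeng/one-api/common/helper"
)

func main() {
	// Unset variable: the default is returned.
	fmt.Println(helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)) // false

	// Set variable: the parsed value is returned.
	os.Setenv("METRIC_SUCCESS_RATE_THRESHOLD", "0.9")
	fmt.Println(helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)) // 0.9

	// Unparseable value: a SysError is logged and the default is returned.
	os.Setenv("METRIC_QUEUE_SIZE", "not-a-number")
	fmt.Println(helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)) // 10
}
```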
**common/logger/logger.go**

```diff
@@ -13,6 +13,7 @@ import (
 )

 const (
+	loggerDEBUG = "DEBUG"
 	loggerINFO  = "INFO"
 	loggerWarn  = "WARN"
 	loggerError = "ERR"
@@ -55,6 +56,10 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }

+func Debug(ctx context.Context, msg string) {
+	logHelper(ctx, loggerDEBUG, msg)
+}
+
 func Info(ctx context.Context, msg string) {
 	logHelper(ctx, loggerINFO, msg)
 }
@@ -67,6 +72,10 @@ func Error(ctx context.Context, msg string) {
 	logHelper(ctx, loggerError, msg)
 }

+func Debugf(ctx context.Context, format string, a ...any) {
+	Debug(ctx, fmt.Sprintf(format, a...))
+}
+
 func Infof(ctx context.Context, format string, a ...any) {
 	Info(ctx, fmt.Sprintf(format, a...))
 }
```
**common/model-ratio.go**

```diff
@@ -7,29 +7,6 @@ import (
 	"time"
 )

-var DalleSizeRatios = map[string]map[string]float64{
-	"dall-e-2": {
-		"256x256":   1,
-		"512x512":   1.125,
-		"1024x1024": 1.25,
-	},
-	"dall-e-3": {
-		"1024x1024": 1,
-		"1024x1792": 2,
-		"1792x1024": 2,
-	},
-}
-
-var DalleGenerationImageAmounts = map[string][2]int{
-	"dall-e-2": {1, 10},
-	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
-}
-
-var DalleImagePromptLengthLimitations = map[string]int{
-	"dall-e-2": 1000,
-	"dall-e-3": 4000,
-}
-
 const (
 	USD2RMB = 7
 	USD     = 500 // $0.002 = 1 -> $1 = 500
@@ -40,7 +17,6 @@ const (
 // https://platform.openai.com/docs/models/model-endpoint-compatibility
 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
 // https://openai.com/pricing
-// TODO: when a new api is enabled, check the pricing here
 // 1 === $0.002 / 1K tokens
 // 1 === ¥0.014 / 1k tokens
 var ModelRatio = map[string]float64{
@@ -87,21 +63,28 @@ var ModelRatio = map[string]float64{
 	"text-search-ada-doc-001": 10,
 	"text-moderation-stable":  0.1,
 	"text-moderation-latest":  0.1,
-	"dall-e-2":                8,     // $0.016 - $0.020 / image
-	"dall-e-3":                20,    // $0.040 - $0.120 / image
-	"claude-instant-1":        0.815, // $1.63 / 1M tokens
-	"claude-2":                5.51,  // $11.02 / 1M tokens
-	"claude-2.0":              5.51,  // $11.02 / 1M tokens
-	"claude-2.1":              5.51,  // $11.02 / 1M tokens
+	"dall-e-2":                8,  // $0.016 - $0.020 / image
+	"dall-e-3":                20, // $0.040 - $0.120 / image
+	// https://www.anthropic.com/api#pricing
+	"claude-instant-1.2":       0.8 / 1000 * USD,
+	"claude-2.0":               8.0 / 1000 * USD,
+	"claude-2.1":               8.0 / 1000 * USD,
+	"claude-3-haiku-20240229":  0.25 / 1000 * USD,
+	"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
+	"claude-3-opus-20240229":   15.0 / 1000 * USD,
 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
-	"ERNIE-Bot":                 0.8572,     // ¥0.012 / 1k tokens
-	"ERNIE-Bot-turbo":           0.5715,     // ¥0.008 / 1k tokens
-	"ERNIE-Bot-4":               0.12 * RMB, // ¥0.12 / 1k tokens
-	"ERNIE-Bot-8k":              0.024 * RMB,
-	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
-	"PaLM-2":                    1,
-	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
-	"gemini-pro-vision":         1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"ERNIE-Bot":         0.8572,     // ¥0.012 / 1k tokens
+	"ERNIE-Bot-turbo":   0.5715,     // ¥0.008 / 1k tokens
+	"ERNIE-Bot-4":       0.12 * RMB, // ¥0.12 / 1k tokens
+	"ERNIE-Bot-8k":      0.024 * RMB,
+	"Embedding-V1":      0.1429, // ¥0.002 / 1k tokens
+	"PaLM-2":            1,
+	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	// https://open.bigmodel.cn/pricing
+	"glm-4":                     0.1 * RMB,
+	"glm-4v":                    0.1 * RMB,
+	"glm-3-turbo":               0.005 * RMB,
 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
@@ -127,6 +110,42 @@ var ModelRatio = map[string]float64{
 	"moonshot-v1-8k":   0.012 * RMB,
 	"moonshot-v1-32k":  0.024 * RMB,
 	"moonshot-v1-128k": 0.06 * RMB,
+	// https://platform.baichuan-ai.com/price
+	"Baichuan2-Turbo":      0.008 * RMB,
+	"Baichuan2-Turbo-192k": 0.016 * RMB,
+	"Baichuan2-53B":        0.02 * RMB,
+	// https://api.minimax.chat/document/price
+	"abab6-chat":    0.1 * RMB,
+	"abab5.5-chat":  0.015 * RMB,
+	"abab5.5s-chat": 0.005 * RMB,
+	// https://docs.mistral.ai/platform/pricing/
+	"open-mistral-7b":       0.25 / 1000 * USD,
+	"open-mixtral-8x7b":     0.7 / 1000 * USD,
+	"mistral-small-latest":  2.0 / 1000 * USD,
+	"mistral-medium-latest": 2.7 / 1000 * USD,
+	"mistral-large-latest":  8.0 / 1000 * USD,
+	"mistral-embed":         0.1 / 1000 * USD,
+	// https://wow.groq.com/
+	"llama2-70b-4096":    0.7 / 1000 * USD,
+	"llama2-7b-2048":     0.1 / 1000 * USD,
+	"mixtral-8x7b-32768": 0.27 / 1000 * USD,
+	"gemma-7b-it":        0.1 / 1000 * USD,
 }

+var CompletionRatio = map[string]float64{}
+
+var DefaultModelRatio map[string]float64
+var DefaultCompletionRatio map[string]float64
+
+func init() {
+	DefaultModelRatio = make(map[string]float64)
+	for k, v := range ModelRatio {
+		DefaultModelRatio[k] = v
+	}
+	DefaultCompletionRatio = make(map[string]float64)
+	for k, v := range CompletionRatio {
+		DefaultCompletionRatio[k] = v
+	}
+}
+
 func ModelRatio2JSONString() string {
@@ -147,6 +166,9 @@ func GetModelRatio(name string) float64 {
 		name = strings.TrimSuffix(name, "-internet")
 	}
 	ratio, ok := ModelRatio[name]
+	if !ok {
+		ratio, ok = DefaultModelRatio[name]
+	}
 	if !ok {
 		logger.SysError("model ratio not found: " + name)
 		return 30
@@ -154,8 +176,6 @@ func GetModelRatio(name string) float64 {
 	return ratio
 }

-var CompletionRatio = map[string]float64{}
-
 func CompletionRatio2JSONString() string {
 	jsonBytes, err := json.Marshal(CompletionRatio)
 	if err != nil {
@@ -173,6 +193,9 @@ func GetCompletionRatio(name string) float64 {
 	if ratio, ok := CompletionRatio[name]; ok {
 		return ratio
 	}
+	if ratio, ok := DefaultCompletionRatio[name]; ok {
+		return ratio
+	}
 	if strings.HasPrefix(name, "gpt-3.5") {
 		if strings.HasSuffix(name, "0125") {
 			// https://openai.com/blog/new-embedding-models-and-api-updates
@@ -191,7 +214,7 @@ func GetCompletionRatio(name string) float64 {
 				return 2
 			}
 		}
-		return 1.333333
+		return 4.0 / 3.0
 	}
 	if strings.HasPrefix(name, "gpt-4") {
 		if strings.HasSuffix(name, "preview") {
@@ -199,11 +222,18 @@ func GetCompletionRatio(name string) float64 {
 		}
 		return 2
 	}
-	if strings.HasPrefix(name, "claude-instant-1") {
-		return 3.38
+	if strings.HasPrefix(name, "claude-3") {
+		return 5
 	}
-	if strings.HasPrefix(name, "claude-2") {
-		return 2.965517
+	if strings.HasPrefix(name, "claude-") {
+		return 3
 	}
+	if strings.HasPrefix(name, "mistral-") {
+		return 3
+	}
+	switch name {
+	case "llama2-70b-4096":
+		return 0.8 / 0.7
+	}
 	return 1
 }
```
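Since one quota unit corresponds to $0.002 per 1K tokens (`USD = 500` in the constants above), the new USD-denominated entries can be checked by hand; a standalone sketch of the arithmetic, not part of the diff:

```go
package main

import "fmt"

// From the constants above: 1 quota unit == $0.002 / 1K tokens, so $1 == 500 units.
const USD = 500

func main() {
	// claude-3-haiku-20240229: $0.25 per 1M tokens == $0.00025 per 1K tokens,
	// which is 0.00025 / 0.002 = 0.125 quota units.
	fmt.Println(0.25 / 1000 * USD) // 0.125

	// mistral-large-latest: $8 per 1M tokens -> 4 quota units.
	fmt.Println(8.0 / 1000 * USD) // 4
}
```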
**common/random.go** (new file, 8 lines)

```diff
@@ -0,0 +1,8 @@
+package common
+
+import "math/rand"
+
+// RandRange returns a random number between min and max (max is not included)
+func RandRange(min, max int) int {
+	return min + rand.Intn(max-min)
+}
```
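The half-open interval matters for the retry logic later in this change set, where `common.RandRange(endIdx, len(channels))` must never return `len(channels)` itself; for example:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common"
)

func main() {
	// Prints 2, 3, or 4 -- never 5, because max is excluded.
	fmt.Println(common.RandRange(2, 5))
}
```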
**controller/channel-billing.go**

```diff
@@ -8,6 +8,7 @@ import (
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
+	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"
@@ -313,7 +314,7 @@ func updateAllChannelsBalance() error {
 		} else {
 			// err is nil & balance <= 0 means quota is used up
 			if balance <= 0 {
-				disableChannel(channel.Id, channel.Name, "余额不足")
+				monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
 			}
 		}
 		time.Sleep(config.RequestInterval)
```
**controller/channel-test.go**

```diff
@@ -8,7 +8,9 @@ import (
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/middleware"
 	"github.com/songquanpeng/one-api/model"
+	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/helper"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
@@ -18,6 +20,7 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"strconv"
+	"strings"
 	"sync"
 	"time"

@@ -51,6 +54,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	c.Request.Header.Set("Content-Type", "application/json")
 	c.Set("channel", channel.Type)
 	c.Set("base_url", channel.GetBaseURL())
+	middleware.SetupContextForSelectedChannel(c, channel, "")
 	meta := util.GetRelayMeta(c)
 	apiType := constant.ChannelType2APIType(channel.Type)
 	adaptor := helper.GetAdaptor(apiType)
@@ -59,6 +63,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	}
 	adaptor.Init(meta)
 	modelName := adaptor.GetModelList()[0]
+	if !strings.Contains(channel.Models, modelName) {
+		modelNames := strings.Split(channel.Models, ",")
+		if len(modelNames) > 0 {
+			modelName = modelNames[0]
+		}
+	}
 	request := buildTestRequest()
 	request.Model = modelName
 	meta.OriginModelName, meta.ActualModelName = modelName, modelName
@@ -139,32 +149,6 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false

-func notifyRootUser(subject string, content string) {
-	if config.RootUserEmail == "" {
-		config.RootUserEmail = model.GetRootUserEmail()
-	}
-	err := common.SendEmail(subject, config.RootUserEmail, content)
-	if err != nil {
-		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-	}
-}
-
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-	notifyRootUser(subject, content)
-}
-
-// enable & notify
-func enableChannel(channelId int, channelName string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	notifyRootUser(subject, content)
-}
-
 func testAllChannels(notify bool) error {
 	if config.RootUserEmail == "" {
 		config.RootUserEmail = model.GetRootUserEmail()
@@ -193,13 +177,13 @@ func testAllChannels(notify bool) error {
 			milliseconds := tok.Sub(tik).Milliseconds()
 			if isChannelEnabled && milliseconds > disableThreshold {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				disableChannel(channel.Id, channel.Name, err.Error())
+				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
-				disableChannel(channel.Id, channel.Name, err.Error())
+				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
-				enableChannel(channel.Id, channel.Name)
+				monitor.EnableChannel(channel.Id, channel.Name)
 			}
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(config.RequestInterval)
```
**controller/model.go**

```diff
@@ -3,11 +3,13 @@ package controller
 import (
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel/ai360"
-	"github.com/songquanpeng/one-api/relay/channel/moonshot"
+	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/helper"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
 	"net/http"
 )

 // https://platform.openai.com/docs/api-reference/models/list
@@ -39,6 +41,7 @@ type OpenAIModels struct {

 var openAIModels []OpenAIModels
 var openAIModelsMap map[string]OpenAIModels
+var channelId2Models map[int][]string

 func init() {
 	var permission []OpenAIModelPermission
@@ -76,32 +79,44 @@ func init() {
 			})
 		}
 	}
-	for _, modelName := range ai360.ModelList {
-		openAIModels = append(openAIModels, OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    "360",
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
-		})
-	}
-	for _, modelName := range moonshot.ModelList {
-		openAIModels = append(openAIModels, OpenAIModels{
-			Id:         modelName,
-			Object:     "model",
-			Created:    1626777600,
-			OwnedBy:    "moonshot",
-			Permission: permission,
-			Root:       modelName,
-			Parent:     nil,
-		})
+	for _, channelType := range openai.CompatibleChannels {
+		if channelType == common.ChannelTypeAzure {
+			continue
+		}
+		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
+		for _, modelName := range channelModelList {
+			openAIModels = append(openAIModels, OpenAIModels{
+				Id:         modelName,
+				Object:     "model",
+				Created:    1626777600,
+				OwnedBy:    channelName,
+				Permission: permission,
+				Root:       modelName,
+				Parent:     nil,
+			})
+		}
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {
 		openAIModelsMap[model.Id] = model
 	}
+	channelId2Models = make(map[int][]string)
+	for i := 1; i < common.ChannelTypeDummy; i++ {
+		adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
+		meta := &util.RelayMeta{
+			ChannelType: i,
+		}
+		adaptor.Init(meta)
+		channelId2Models[i] = adaptor.GetModelList()
+	}
 }

+func DashboardListModels(c *gin.Context) {
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    channelId2Models,
+	})
+}
+
 func ListModels(c *gin.Context) {
```
**controller/relay.go**

```diff
@@ -11,6 +11,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/middleware"
 	dbmodel "github.com/songquanpeng/one-api/model"
+	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/controller"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -41,11 +42,16 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
 func Relay(c *gin.Context) {
 	ctx := c.Request.Context()
 	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
-	bizErr := relay(c, relayMode)
-	if bizErr == nil {
-		return
+	if config.DebugEnabled {
+		requestBody, _ := common.GetRequestBody(c)
+		logger.Debugf(ctx, "request body: %s", string(requestBody))
 	}
 	channelId := c.GetInt("channel_id")
+	bizErr := relay(c, relayMode)
+	if bizErr == nil {
+		monitor.Emit(channelId, true)
+		return
+	}
 	lastFailedChannelId := channelId
 	channelName := c.GetString("channel_name")
 	group := c.GetString("group")
@@ -58,7 +64,7 @@ func Relay(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := retryTimes; i > 0; i-- {
-		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
+		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
 		if err != nil {
 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
 			break
@@ -113,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
 	// https://platform.openai.com/docs/guides/error-codes/api-errors
 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
-		disableChannel(channelId, channelName, err.Message)
+		monitor.DisableChannel(channelId, channelName, err.Message)
+	} else {
+		monitor.Emit(channelId, false)
 	}
 }
```
**main.go** (3 changes)

```diff
@@ -83,6 +83,9 @@ func main() {
 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
 		model.InitBatchUpdater()
 	}
+	if config.EnableMetric {
+		logger.SysLog("metric enabled, will disable channel if too much request failed")
+	}
 	openai.InitTokenEncoders()

 	// Initialize HTTP server
```
**middleware/auth.go**

```diff
@@ -4,6 +4,7 @@ import (
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/blacklist"
 	"github.com/songquanpeng/one-api/model"
 	"net/http"
 	"strings"
@@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) {
 			return
 		}
 	}
-	if status.(int) == common.UserStatusDisabled {
+	if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": "用户已被封禁",
 		})
+		session := sessions.Default(c)
+		session.Clear()
+		_ = session.Save()
 		c.Abort()
 		return
 	}
@@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) {
 			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 			return
 		}
-		if !userEnabled {
+		if !userEnabled || blacklist.IsUserBanned(token.UserId) {
 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 		}
```
**middleware/distributor.go**

```diff
@@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) {
 				}
 			}
 			requestModel = modelRequest.Model
-			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
 			if err != nil {
 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 				if channel != nil {
```
**model/cache.go**

```diff
@@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) {
 	}
 }

-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 	if !config.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model)
 	}
@@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error
 		}
 	}
 	idx := rand.Intn(endIdx)
+	if ignoreFirstPriority {
+		if endIdx < len(channels) { // which means there are more than one priority
+			idx = common.RandRange(endIdx, len(channels))
+		}
+	}
 	return channels[idx], nil
 }
```
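Assuming the cached channel list is sorted so that the highest-priority group occupies `[0, endIdx)` (which the surrounding code implies), the retry path now deliberately falls through to lower-priority channels. A self-contained sketch of that selection logic (the names and values are illustrative):

```go
package main

import (
	"fmt"
	"math/rand"
)

// pick mirrors the index selection above: first attempts draw from the
// highest-priority group [0, endIdx); retries draw from the rest, if any.
func pick(channels []string, endIdx int, ignoreFirstPriority bool) string {
	idx := rand.Intn(endIdx)
	if ignoreFirstPriority && endIdx < len(channels) { // more than one priority group
		idx = endIdx + rand.Intn(len(channels)-endIdx) // i.e. common.RandRange(endIdx, len(channels))
	}
	return channels[idx]
}

func main() {
	channels := []string{"primary-a", "primary-b", "backup"}
	fmt.Println(pick(channels, 2, false)) // one of the primaries
	fmt.Println(pick(channels, 2, true))  // "backup"
}
```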
**model/main.go**

```diff
@@ -72,7 +72,7 @@ func chooseDB() (*gorm.DB, error) {
 func InitDB() (err error) {
 	db, err := chooseDB()
 	if err == nil {
-		if config.DebugEnabled {
+		if config.DebugSQLEnabled {
 			db = db.Debug()
 		}
 		DB = db
```
**model/user.go**

```diff
@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/blacklist"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
@@ -40,7 +41,7 @@ func GetMaxUserId() int {
 }

 func GetAllUsers(startIdx int, num int) (users []*User, err error) {
-	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
+	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error
 	return users, err
 }

@@ -123,6 +124,11 @@ func (user *User) Update(updatePassword bool) error {
 			return err
 		}
 	}
+	if user.Status == common.UserStatusDisabled {
+		blacklist.BanUser(user.Id)
+	} else if user.Status == common.UserStatusEnabled {
+		blacklist.UnbanUser(user.Id)
+	}
 	err = DB.Model(user).Updates(user).Error
 	return err
 }
@@ -131,7 +137,10 @@ func (user *User) Delete() error {
 	if user.Id == 0 {
 		return errors.New("id 为空!")
 	}
-	err := DB.Delete(user).Error
+	blacklist.BanUser(user.Id)
+	user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
+	user.Status = common.UserStatusDeleted
+	err := DB.Model(user).Updates(user).Error
 	return err
 }
```
**monitor/channel.go** (new file, 46 lines)

```diff
@@ -0,0 +1,46 @@
+package monitor
+
+import (
+	"fmt"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/model"
+)
+
+func notifyRootUser(subject string, content string) {
+	if config.RootUserEmail == "" {
+		config.RootUserEmail = model.GetRootUserEmail()
+	}
+	err := common.SendEmail(subject, config.RootUserEmail, content)
+	if err != nil {
+		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+	}
+}
+
+// DisableChannel disable & notify
+func DisableChannel(channelId int, channelName string, reason string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
+	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
+	notifyRootUser(subject, content)
+}
+
+func MetricDisableChannel(channelId int, successRate float64) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
+	subject := fmt.Sprintf("通道 #%d 已被禁用", channelId)
+	content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
+		config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
+	notifyRootUser(subject, content)
+}
+
+// EnableChannel enable & notify
+func EnableChannel(channelId int, channelName string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
+	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+	notifyRootUser(subject, content)
+}
```
**monitor/metric.go** (new file, 79 lines)

```diff
@@ -0,0 +1,79 @@
+package monitor
+
+import (
+	"github.com/songquanpeng/one-api/common/config"
+)
+
+var store = make(map[int][]bool)
+var metricSuccessChan = make(chan int, config.MetricSuccessChanSize)
+var metricFailChan = make(chan int, config.MetricFailChanSize)
+
+func consumeSuccess(channelId int) {
+	if len(store[channelId]) > config.MetricQueueSize {
+		store[channelId] = store[channelId][1:]
+	}
+	store[channelId] = append(store[channelId], true)
+}
+
+func consumeFail(channelId int) (bool, float64) {
+	if len(store[channelId]) > config.MetricQueueSize {
+		store[channelId] = store[channelId][1:]
+	}
+	store[channelId] = append(store[channelId], false)
+	successCount := 0
+	for _, success := range store[channelId] {
+		if success {
+			successCount++
+		}
+	}
+	successRate := float64(successCount) / float64(len(store[channelId]))
+	if len(store[channelId]) < config.MetricQueueSize {
+		return false, successRate
+	}
+	if successRate < config.MetricSuccessRateThreshold {
+		store[channelId] = make([]bool, 0)
+		return true, successRate
+	}
+	return false, successRate
+}
+
+func metricSuccessConsumer() {
+	for {
+		select {
+		case channelId := <-metricSuccessChan:
+			consumeSuccess(channelId)
+		}
+	}
+}
+
+func metricFailConsumer() {
+	for {
+		select {
+		case channelId := <-metricFailChan:
+			disable, successRate := consumeFail(channelId)
+			if disable {
+				go MetricDisableChannel(channelId, successRate)
+			}
+		}
+	}
+}
+
+func init() {
+	if config.EnableMetric {
+		go metricSuccessConsumer()
+		go metricFailConsumer()
+	}
+}
+
+func Emit(channelId int, success bool) {
+	if !config.EnableMetric {
+		return
+	}
+	go func() {
+		if success {
+			metricSuccessChan <- channelId
+		} else {
+			metricFailChan <- channelId
+		}
+	}()
+}
```
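The monitor keeps a fixed-size window of recent outcomes per channel and only disables once the window is full and the success rate drops below the threshold. A standalone sketch of that check, using the default `METRIC_QUEUE_SIZE = 10` and threshold `0.8` (the window values are illustrative):

```go
package main

import "fmt"

// shouldDisable mirrors the sliding-window check in consumeFail above.
func shouldDisable(window []bool, queueSize int, threshold float64) (bool, float64) {
	successCount := 0
	for _, ok := range window {
		if ok {
			successCount++
		}
	}
	rate := float64(successCount) / float64(len(window))
	if len(window) < queueSize { // not enough samples yet
		return false, rate
	}
	return rate < threshold, rate
}

func main() {
	window := []bool{true, true, false, false, false, true, true, false, true, true}
	disable, rate := shouldDisable(window, 10, 0.8)
	fmt.Printf("success rate %.0f%%, disable: %v\n", rate*100, disable) // 60%, true
}
```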
**relay/channel/ali/main.go**

```diff
@@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		enableSearch = true
 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 	}
+	if request.TopP >= 1 {
+		request.TopP = 0.9999
+	}
 	return &ChatRequest{
 		Model: aliModel,
 		Input: Input{
@@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			EnableSearch:      enableSearch,
 			IncrementalOutput: request.Stream,
 			Seed:              uint64(request.Seed),
+			MaxTokens:         request.MaxTokens,
+			Temperature:       request.Temperature,
+			TopP:              request.TopP,
 		},
 	}
 }
```
**relay/channel/ali/model.go**

```diff
@@ -16,6 +16,8 @@ type Parameters struct {
 	Seed              uint64  `json:"seed,omitempty"`
 	EnableSearch      bool    `json:"enable_search,omitempty"`
 	IncrementalOutput bool    `json:"incremental_output,omitempty"`
+	MaxTokens         int     `json:"max_tokens,omitempty"`
+	Temperature       float64 `json:"temperature,omitempty"`
 }

 type ChatRequest struct {
```
**relay/channel/anthropic/adaptor.go**

```diff
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/relay/channel"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
@@ -20,7 +19,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 }

 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-	return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
+	return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
@@ -31,6 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
 		anthropicVersion = "2023-06-01"
 	}
 	req.Header.Set("anthropic-version", anthropicVersion)
+	req.Header.Set("anthropic-beta", "messages-2023-12-15")
 	return nil
 }

@@ -47,9 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
-		var responseText string
-		err, responseText = StreamHandler(c, resp)
-		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+		err, usage = StreamHandler(c, resp)
 	} else {
 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 	}
```
**relay/channel/anthropic/constants.go**

```diff
@@ -1,5 +1,8 @@
 package anthropic

 var ModelList = []string{
-	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+	"claude-instant-1.2", "claude-2.0", "claude-2.1",
+	"claude-3-haiku-20240229",
+	"claude-3-sonnet-20240229",
+	"claude-3-opus-20240229",
 }
```
**relay/channel/anthropic/main.go**

```diff
@@ -7,6 +7,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -15,73 +16,135 @@ import (
 	"strings"
 )

-func stopReasonClaude2OpenAI(reason string) string {
-	switch reason {
+func stopReasonClaude2OpenAI(reason *string) string {
+	if reason == nil {
+		return ""
+	}
+	switch *reason {
+	case "end_turn":
+		return "stop"
 	case "stop_sequence":
 		return "stop"
 	case "max_tokens":
 		return "length"
 	default:
-		return reason
+		return *reason
 	}
 }

 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 	claudeRequest := Request{
-		Model:             textRequest.Model,
-		Prompt:            "",
-		MaxTokensToSample: textRequest.MaxTokens,
-		StopSequences:     nil,
-		Temperature:       textRequest.Temperature,
-		TopP:              textRequest.TopP,
-		Stream:            textRequest.Stream,
+		Model:       textRequest.Model,
+		MaxTokens:   textRequest.MaxTokens,
+		Temperature: textRequest.Temperature,
+		TopP:        textRequest.TopP,
+		Stream:      textRequest.Stream,
 	}
-	if claudeRequest.MaxTokensToSample == 0 {
-		claudeRequest.MaxTokensToSample = 1000000
+	if claudeRequest.MaxTokens == 0 {
+		claudeRequest.MaxTokens = 4096
 	}
+	// legacy model name mapping
+	if claudeRequest.Model == "claude-instant-1" {
+		claudeRequest.Model = "claude-instant-1.1"
+	} else if claudeRequest.Model == "claude-2" {
+		claudeRequest.Model = "claude-2.1"
+	}
-	prompt := ""
 	for _, message := range textRequest.Messages {
-		if message.Role == "user" {
-			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
-		} else if message.Role == "assistant" {
-			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
-		} else if message.Role == "system" {
-			if prompt == "" {
-				prompt = message.StringContent()
-			}
+		if message.Role == "system" && claudeRequest.System == "" {
+			claudeRequest.System = message.StringContent()
+			continue
 		}
+		claudeMessage := Message{
+			Role: message.Role,
+		}
+		var content Content
+		if message.IsStringContent() {
+			content.Type = "text"
+			content.Text = message.StringContent()
+			claudeMessage.Content = append(claudeMessage.Content, content)
+			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
+			continue
+		}
+		var contents []Content
+		openaiContent := message.ParseContent()
+		for _, part := range openaiContent {
+			var content Content
+			if part.Type == model.ContentTypeText {
+				content.Type = "text"
+				content.Text = part.Text
+			} else if part.Type == model.ContentTypeImageURL {
+				content.Type = "image"
+				content.Source = &ImageSource{
+					Type: "base64",
+				}
+				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
+				content.Source.MediaType = mimeType
+				content.Source.Data = data
+			}
+			contents = append(contents, content)
+		}
+		claudeMessage.Content = contents
+		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 	}
-	prompt += "\n\nAssistant:"
-	claudeRequest.Prompt = prompt
 	return &claudeRequest
 }

-func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
+// https://docs.anthropic.com/claude/reference/messages-streaming
+func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
+	var response *Response
+	var responseText string
+	var stopReason string
+	switch claudeResponse.Type {
+	case "message_start":
+		return nil, claudeResponse.Message
+	case "content_block_start":
+		if claudeResponse.ContentBlock != nil {
+			responseText = claudeResponse.ContentBlock.Text
+		}
+	case "content_block_delta":
+		if claudeResponse.Delta != nil {
+			responseText = claudeResponse.Delta.Text
+		}
+	case "message_delta":
+		if claudeResponse.Usage != nil {
+			response = &Response{
+				Usage: *claudeResponse.Usage,
+			}
+		}
+		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
+			stopReason = *claudeResponse.Delta.StopReason
+		}
+	}
 	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = claudeResponse.Completion
-	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
+	choice.Delta.Content = responseText
+	choice.Delta.Role = "assistant"
+	finishReason := stopReasonClaude2OpenAI(&stopReason)
 	if finishReason != "null" {
 		choice.FinishReason = &finishReason
 	}
-	var response openai.ChatCompletionsStreamResponse
-	response.Object = "chat.completion.chunk"
-	response.Model = claudeResponse.Model
-	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
-	return &response
+	var openaiResponse openai.ChatCompletionsStreamResponse
+	openaiResponse.Object = "chat.completion.chunk"
+	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
+	return &openaiResponse, response
 }

 func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
+	var responseText string
+	if len(claudeResponse.Content) > 0 {
+		responseText = claudeResponse.Content[0].Text
+	}
 	choice := openai.TextResponseChoice{
 		Index: 0,
 		Message: model.Message{
 			Role:    "assistant",
-			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
+			Content: responseText,
 			Name:    nil,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 	}
 	fullTextResponse := openai.TextResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
+		Model:   claudeResponse.Model,
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
 		Choices: []openai.TextResponseChoice{choice},
@@ -89,17 +152,15 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 	return &fullTextResponse
 }

-func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
-	responseText := ""
-	responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	createdTime := helper.GetTimestamp()
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
 			return 0, nil, nil
 		}
-		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
-			return i + 4, data[0:i], nil
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
 		}
 		if atEOF {
 			return len(data), data, nil
@@ -111,29 +172,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
-			if !strings.HasPrefix(data, "event: completion") {
+			if len(data) < 6 {
 				continue
 			}
-			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
+			if !strings.HasPrefix(data, "data: ") {
+				continue
+			}
+			data = strings.TrimPrefix(data, "data: ")
 			dataChan <- data
 		}
 		stopChan <- true
 	}()
 	common.SetEventStreamHeaders(c)
+	var usage model.Usage
+	var modelName string
+	var id string
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
 			// some implementations may add \r at the end of data
 			data = strings.TrimSuffix(data, "\r")
-			var claudeResponse Response
+			var claudeResponse StreamResponse
 			err := json.Unmarshal([]byte(data), &claudeResponse)
 			if err != nil {
 				logger.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 			}
-			responseText += claudeResponse.Completion
-			response := streamResponseClaude2OpenAI(&claudeResponse)
-			response.Id = responseId
+			response, meta := streamResponseClaude2OpenAI(&claudeResponse)
+			if meta != nil {
+				usage.PromptTokens += meta.Usage.InputTokens
+				usage.CompletionTokens += meta.Usage.OutputTokens
+				modelName = meta.Model
+				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
+				return true
+			}
+			if response == nil {
+				return true
+			}
+			response.Id = id
+			response.Model = modelName
 			response.Created = createdTime
 			jsonStr, err := json.Marshal(response)
 			if err != nil {
@@ -147,11 +224,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 			return false
 		}
 	})
-	err := resp.Body.Close()
-	if err != nil {
-		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
-	}
-	return nil, responseText
+	_ = resp.Body.Close()
+	return nil, &usage
 }

 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
@@ -181,11 +255,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 	}
 	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
 	fullTextResponse.Model = modelName
-	completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
 	usage := model.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+		PromptTokens:     claudeResponse.Usage.InputTokens,
+		CompletionTokens: claudeResponse.Usage.OutputTokens,
+		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
 	}
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
```
| @@ -1,19 +1,44 @@ | ||||
| package anthropic | ||||
|  | ||||
| // https://docs.anthropic.com/claude/reference/messages_post | ||||
|  | ||||
| type Metadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
|  | ||||
| type ImageSource struct { | ||||
| 	Type      string `json:"type"` | ||||
| 	MediaType string `json:"media_type"` | ||||
| 	Data      string `json:"data"` | ||||
| } | ||||
|  | ||||
| type Content struct { | ||||
| 	Type   string       `json:"type"` | ||||
| 	Text   string       `json:"text,omitempty"` | ||||
| 	Source *ImageSource `json:"source,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string    `json:"role"` | ||||
| 	Content []Content `json:"content"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	Model             string   `json:"model"` | ||||
| 	Prompt            string   `json:"prompt"` | ||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||
| 	Temperature       float64  `json:"temperature,omitempty"` | ||||
| 	TopP              float64  `json:"top_p,omitempty"` | ||||
| 	TopK              int      `json:"top_k,omitempty"` | ||||
| 	Model         string    `json:"model"` | ||||
| 	Messages      []Message `json:"messages"` | ||||
| 	System        string    `json:"system,omitempty"` | ||||
| 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool      `json:"stream,omitempty"` | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	//Metadata    `json:"metadata,omitempty"` | ||||
| 	Stream bool `json:"stream,omitempty"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| @@ -22,8 +47,29 @@ type Error struct { | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Completion string `json:"completion"` | ||||
| 	StopReason string `json:"stop_reason"` | ||||
| 	Model      string `json:"model"` | ||||
| 	Error      Error  `json:"error"` | ||||
| 	Id           string    `json:"id"` | ||||
| 	Type         string    `json:"type"` | ||||
| 	Role         string    `json:"role"` | ||||
| 	Content      []Content `json:"content"` | ||||
| 	Model        string    `json:"model"` | ||||
| 	StopReason   *string   `json:"stop_reason"` | ||||
| 	StopSequence *string   `json:"stop_sequence"` | ||||
| 	Usage        Usage     `json:"usage"` | ||||
| 	Error        Error     `json:"error"` | ||||
| } | ||||
|  | ||||
| type Delta struct { | ||||
| 	Type         string  `json:"type"` | ||||
| 	Text         string  `json:"text"` | ||||
| 	StopReason   *string `json:"stop_reason"` | ||||
| 	StopSequence *string `json:"stop_sequence"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	Type         string    `json:"type"` | ||||
| 	Message      *Response `json:"message"` | ||||
| 	Index        int       `json:"index"` | ||||
| 	ContentBlock *Content  `json:"content_block"` | ||||
| 	Delta        *Delta    `json:"delta"` | ||||
| 	Usage        *Usage    `json:"usage"` | ||||
| } | ||||
|   | ||||
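
For orientation, here is a sketch of the request JSON these new types produce: the legacy prompt/max_tokens_to_sample pair gives way to a messages array of typed content blocks plus a separate system field. The structs below are trimmed copies of the ones above, and the field values are illustrative.

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the Content/Message/Request types above.
type content struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}

type message struct {
	Role    string    `json:"role"`
	Content []content `json:"content"`
}

type request struct {
	Model     string    `json:"model"`
	Messages  []message `json:"messages"`
	System    string    `json:"system,omitempty"`
	MaxTokens int       `json:"max_tokens,omitempty"`
	Stream    bool      `json:"stream,omitempty"`
}

func main() {
	req := request{
		Model:  "claude-2.1", // illustrative model name
		System: "You are a helpful assistant.",
		Messages: []message{
			{Role: "user", Content: []content{{Type: "text", Text: "Hello"}}},
		},
		MaxTokens: 256,
		Stream:    true,
	}
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out))
}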
relay/channel/baichuan/constants.go (new file, 7 lines)
							| @@ -0,0 +1,7 @@ | ||||
| package baichuan | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"Baichuan2-Turbo", | ||||
| 	"Baichuan2-Turbo-192k", | ||||
| 	"Baichuan-Text-Embedding", | ||||
| } | ||||
| @@ -2,6 +2,7 @@ package baidu | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| @@ -9,6 +10,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -20,23 +22,33 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | ||||
| 	var fullRequestURL string | ||||
| 	switch meta.ActualModelName { | ||||
| 	case "ERNIE-Bot-4": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" | ||||
| 	case "ERNIE-Bot-8K": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" | ||||
| 	case "ERNIE-Bot": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 	case "ERNIE-Speed": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" | ||||
| 	case "ERNIE-Bot-turbo": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 	case "BLOOMZ-7B": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 	case "Embedding-V1": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 	suffix := "chat/" | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "Embedding") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	switch meta.ActualModelName { | ||||
| 	case "ERNIE-4.0": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-Bot-4": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-3.5-8K": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-Bot-8K": | ||||
| 		suffix += "ernie_bot_8k" | ||||
| 	case "ERNIE-Bot": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-Speed": | ||||
| 		suffix += "ernie_speed" | ||||
| 	case "ERNIE-Bot-turbo": | ||||
| 		suffix += "eb-instant" | ||||
| 	case "BLOOMZ-7B": | ||||
| 		suffix += "bloomz_7b1" | ||||
| 	case "Embedding-V1": | ||||
| 		suffix += "embedding-v1" | ||||
| 	default: | ||||
| 		suffix += meta.ActualModelName | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) | ||||
| 	var accessToken string | ||||
| 	var err error | ||||
| 	if accessToken, err = GetAccessToken(meta.APIKey); err != nil { | ||||
|   | ||||
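
The per-model URL table above collapses into a suffix lookup against the channel's configured base URL, and unknown models now fall through to chat/<model-name> instead of failing. A simplified, runnable mirror of that logic (the access-token query parameter and error handling are omitted, and the base URL is passed in directly):

package main

import (
	"fmt"
	"strings"
)

// Simplified mirror of the adaptor's suffix logic.
func baiduRequestURL(baseURL, modelName string) string {
	suffix := "chat/"
	if strings.HasPrefix(modelName, "Embedding") {
		suffix = "embeddings/"
	}
	switch modelName {
	case "ERNIE-4.0", "ERNIE-Bot-4":
		suffix += "completions_pro"
	case "ERNIE-3.5-8K", "ERNIE-Bot":
		suffix += "completions"
	case "ERNIE-Bot-8K":
		suffix += "ernie_bot_8k"
	case "ERNIE-Speed":
		suffix += "ernie_speed"
	case "ERNIE-Bot-turbo":
		suffix += "eb-instant"
	case "BLOOMZ-7B":
		suffix += "bloomz_7b1"
	case "Embedding-V1":
		suffix += "embedding-v1"
	default:
		suffix += modelName
	}
	return fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", baseURL, suffix)
}

func main() {
	base := "https://aip.baidubce.com"
	fmt.Println(baiduRequestURL(base, "ERNIE-Speed"))    // .../wenxinworkshop/chat/ernie_speed
	fmt.Println(baiduRequestURL(base, "Embedding-V1"))   // .../wenxinworkshop/embeddings/embedding-v1
	fmt.Println(baiduRequestURL(base, "some-new-model")) // .../wenxinworkshop/chat/some-new-model
}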
| @@ -1,6 +1,6 @@ | ||||
| package gemini | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-pro", | ||||
| 	"gemini-pro-vision", | ||||
| 	"gemini-pro", "gemini-1.0-pro-001", | ||||
| 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", | ||||
| } | ||||
|   | ||||
relay/channel/groq/constants.go (new file, 10 lines)
							| @@ -0,0 +1,10 @@ | ||||
| package groq | ||||
|  | ||||
| // https://console.groq.com/docs/models | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemma-7b-it", | ||||
| 	"llama2-7b-2048", | ||||
| 	"llama2-70b-4096", | ||||
| 	"mixtral-8x7b-32768", | ||||
| } | ||||
relay/channel/minimax/constants.go (new file, 7 lines)
							| @@ -0,0 +1,7 @@ | ||||
| package minimax | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"abab5.5s-chat", | ||||
| 	"abab5.5-chat", | ||||
| 	"abab6-chat", | ||||
| } | ||||
relay/channel/minimax/main.go (new file, 14 lines)
							| @@ -0,0 +1,14 @@ | ||||
| package minimax | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| ) | ||||
|  | ||||
| func GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	if meta.Mode == constant.RelayModeChatCompletions { | ||||
| 		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil | ||||
| 	} | ||||
| 	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) | ||||
| } | ||||
relay/channel/mistral/constants.go (new file, 10 lines)
							| @@ -0,0 +1,10 @@ | ||||
| package mistral | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"open-mistral-7b", | ||||
| 	"open-mixtral-8x7b", | ||||
| 	"mistral-small-latest", | ||||
| 	"mistral-medium-latest", | ||||
| 	"mistral-large-latest", | ||||
| 	"mistral-embed", | ||||
| } | ||||
| @@ -6,8 +6,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| @@ -24,7 +23,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	if meta.ChannelType == common.ChannelTypeAzure { | ||||
| 	switch meta.ChannelType { | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 		requestURL := strings.Split(meta.RequestURLPath, "?")[0] | ||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | ||||
| @@ -38,8 +38,11 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
|  | ||||
| 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||
| 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil | ||||
| 	case common.ChannelTypeMinimax: | ||||
| 		return minimax.GetRequestURL(meta) | ||||
| 	default: | ||||
| 		return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | ||||
| 	} | ||||
| 	return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||
| @@ -70,7 +73,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		var responseText string | ||||
| 		err, responseText = StreamHandler(c, resp, meta.Mode) | ||||
| 		err, responseText, _ = StreamHandler(c, resp, meta.Mode) | ||||
| 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| @@ -79,25 +82,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	switch a.ChannelType { | ||||
| 	case common.ChannelType360: | ||||
| 		return ai360.ModelList | ||||
| 	case common.ChannelTypeMoonshot: | ||||
| 		return moonshot.ModelList | ||||
| 	default: | ||||
| 		return ModelList | ||||
| 	} | ||||
| 	_, modelList := GetCompatibleChannelMeta(a.ChannelType) | ||||
| 	return modelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	switch a.ChannelType { | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		return "azure" | ||||
| 	case common.ChannelType360: | ||||
| 		return "360" | ||||
| 	case common.ChannelTypeMoonshot: | ||||
| 		return "moonshot" | ||||
| 	default: | ||||
| 		return "openai" | ||||
| 	} | ||||
| 	channelName, _ := GetCompatibleChannelMeta(a.ChannelType) | ||||
| 	return channelName | ||||
| } | ||||
|   | ||||
relay/channel/openai/compatible.go (new file, 42 lines)
							| @@ -0,0 +1,42 @@ | ||||
| package openai | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/groq" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| ) | ||||
|  | ||||
| var CompatibleChannels = []int{ | ||||
| 	common.ChannelTypeAzure, | ||||
| 	common.ChannelType360, | ||||
| 	common.ChannelTypeMoonshot, | ||||
| 	common.ChannelTypeBaichuan, | ||||
| 	common.ChannelTypeMinimax, | ||||
| 	common.ChannelTypeMistral, | ||||
| 	common.ChannelTypeGroq, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 	switch channelType { | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		return "azure", ModelList | ||||
| 	case common.ChannelType360: | ||||
| 		return "360", ai360.ModelList | ||||
| 	case common.ChannelTypeMoonshot: | ||||
| 		return "moonshot", moonshot.ModelList | ||||
| 	case common.ChannelTypeBaichuan: | ||||
| 		return "baichuan", baichuan.ModelList | ||||
| 	case common.ChannelTypeMinimax: | ||||
| 		return "minimax", minimax.ModelList | ||||
| 	case common.ChannelTypeMistral: | ||||
| 		return "mistralai", mistral.ModelList | ||||
| 	case common.ChannelTypeGroq: | ||||
| 		return "groq", groq.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
| } | ||||
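
Both Adaptor methods above now reduce to this single lookup, so wiring up another OpenAI-compatible upstream only takes a new entry in CompatibleChannels plus one switch case. A hedged usage sketch; the channel-type constants below are local stand-ins, since the real numeric values live in the common package:

package main

import "fmt"

// Stand-ins for common.ChannelType* constants; the real values are
// defined in the common package and are assumptions here.
const (
	channelTypeOpenAI = 1
	channelTypeGroq   = 29
)

var openaiModels = []string{"gpt-3.5-turbo", "gpt-4"}
var groqModels = []string{"gemma-7b-it", "mixtral-8x7b-32768"}

func getCompatibleChannelMeta(channelType int) (string, []string) {
	switch channelType {
	case channelTypeGroq:
		return "groq", groqModels
	default:
		return "openai", openaiModels
	}
}

func main() {
	name, models := getCompatibleChannelMeta(channelTypeGroq)
	fmt.Println(name, models) // groq [gemma-7b-it mixtral-8x7b-32768]
}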
| @@ -14,7 +14,7 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) { | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| @@ -31,6 +31,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	var usage *model.Usage | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| @@ -54,6 +55,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| 					} | ||||
| 					if streamResponse.Usage != nil { | ||||
| 						usage = streamResponse.Usage | ||||
| 					} | ||||
| 				case constant.RelayModeCompletions: | ||||
| 					var streamResponse CompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| @@ -86,9 +90,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| 	return nil, responseText, usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
|   | ||||
| @@ -132,6 +132,7 @@ type ChatCompletionsStreamResponse struct { | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| 	Usage   *model.Usage                          `json:"usage"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
|   | ||||
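
The pointer type matters here: Usage stays nil unless the upstream actually sends it, which lets the stream handler prefer upstream-reported usage (some OpenAI-compatible backends append a final chunk carrying a usage object) and fall back to local token counting otherwise. A small sketch with a hand-written final chunk:

package main

import (
	"encoding/json"
	"fmt"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type streamChunk struct {
	Id      string `json:"id"`
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
	Usage *usage `json:"usage"`
}

func main() {
	// Illustrative final chunk; real ids and token counts will differ.
	data := `{"id":"chatcmpl-abc","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`
	var chunk streamChunk
	if err := json.Unmarshal([]byte(data), &chunk); err != nil {
		panic(err)
	}
	if chunk.Usage != nil { // mirrors the streamResponse.Usage != nil check above
		fmt.Printf("upstream-reported usage: %+v\n", *chunk.Usage)
	}
}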
| @@ -28,17 +28,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} | ||||
| 		messages = append(messages, Message{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    message.Role, | ||||
| @@ -81,6 +70,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
|  | ||||
| func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "tencent-hunyuan", | ||||
|   | ||||
| @@ -27,21 +27,10 @@ import ( | ||||
| func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 		messages = append(messages, Message{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.StringContent(), | ||||
| 		}) | ||||
| 	} | ||||
| 	xunfeiRequest := ChatRequest{} | ||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||
|   | ||||
| @@ -5,20 +5,35 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	APIVersion string | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetVersionByModelName(modelName string) { | ||||
| 	if strings.HasPrefix(modelName, "glm-") { | ||||
| 		a.APIVersion = "v4" | ||||
| 	} else { | ||||
| 		a.APIVersion = "v3" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	a.SetVersionByModelName(meta.ActualModelName) | ||||
| 	if a.APIVersion == "v4" { | ||||
| 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil | ||||
| 	} | ||||
| 	method := "invoke" | ||||
| 	if meta.IsStream { | ||||
| 		method = "sse-invoke" | ||||
| @@ -37,6 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	if request.TopP >= 1 { | ||||
| 		request.TopP = 0.99 | ||||
| 	} | ||||
| 	a.SetVersionByModelName(request.Model) | ||||
| 	if a.APIVersion == "v4" { | ||||
| 		return request, nil | ||||
| 	} | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
|  | ||||
| @@ -44,7 +66,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | ||||
| 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, _, usage = openai.StreamHandler(c, resp, meta.Mode) | ||||
| 	} else { | ||||
| 		err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if a.APIVersion == "v4" { | ||||
| 		return a.DoResponseV4(c, resp, meta) | ||||
| 	} | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
|   | ||||
| @@ -2,4 +2,5 @@ package zhipu | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | ||||
| 	"glm-4", "glm-4v", "glm-3-turbo", | ||||
| } | ||||
|   | ||||
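
Taken with the adaptor changes above, glm-* models are routed to the OpenAI-compatible v4 endpoint and handled by the shared openai handlers, while the older chatglm_* models keep the v3 invoke/sse-invoke flow. A small sketch of the routing; the v3 path shape is recalled from the surrounding adaptor code rather than shown in this hunk, so treat it as an assumption:

package main

import (
	"fmt"
	"strings"
)

// Mirrors SetVersionByModelName/GetRequestURL from the adaptor, with the
// base URL passed in directly.
func zhipuRequestURL(baseURL, modelName string, stream bool) string {
	if strings.HasPrefix(modelName, "glm-") {
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", baseURL)
	}
	method := "invoke"
	if stream {
		method = "sse-invoke"
	}
	// v3 path shape assumed from the adaptor code not shown in this hunk.
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", baseURL, modelName, method)
}

func main() {
	base := "https://open.bigmodel.cn"
	fmt.Println(zhipuRequestURL(base, "glm-4", true))         // v4 chat/completions
	fmt.Println(zhipuRequestURL(base, "chatglm_turbo", true)) // v3 sse-invoke
}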
| @@ -76,21 +76,10 @@ func GetToken(apikey string) string { | ||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *Request { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 		messages = append(messages, Message{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.StringContent(), | ||||
| 		}) | ||||
| 	} | ||||
| 	return &Request{ | ||||
| 		Prompt:      messages, | ||||
|   | ||||
relay/constant/image.go (new file, 24 lines)
							| @@ -0,0 +1,24 @@ | ||||
| package constant | ||||
|  | ||||
| var DalleSizeRatios = map[string]map[string]float64{ | ||||
| 	"dall-e-2": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1.125, | ||||
| 		"1024x1024": 1.25, | ||||
| 	}, | ||||
| 	"dall-e-3": { | ||||
| 		"1024x1024": 1, | ||||
| 		"1024x1792": 2, | ||||
| 		"1792x1024": 2, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var DalleGenerationImageAmounts = map[string][2]int{ | ||||
| 	"dall-e-2": {1, 10}, | ||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||
| } | ||||
|  | ||||
| var DalleImagePromptLengthLimitations = map[string]int{ | ||||
| 	"dall-e-2": 1000, | ||||
| 	"dall-e-3": 4000, | ||||
| } | ||||
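
These tables feed the quota arithmetic in the image controller: quota = int(modelRatio * groupRatio * sizeRatio * 1000) * n, where the size ratio is further multiplied by 2 (1024x1024) or 1.5 (other sizes) for dall-e-3 requests with quality "hd". A worked example with illustrative ratio values:

package main

import "fmt"

func main() {
	// dall-e-3 at 1792x1024: size ratio 2 from DalleSizeRatios.
	sizeRatio := 2.0
	// "hd" quality on a non-1024x1024 dall-e-3 size multiplies by 1.5.
	sizeRatio *= 1.5
	// modelRatio and groupRatio are illustrative; real values come from config.
	modelRatio, groupRatio := 20.0, 1.0
	n := 1
	quota := int(modelRatio*groupRatio*sizeRatio*1000) * n
	fmt.Println(quota) // 60000
}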
| @@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener | ||||
| 	return textRequest, nil | ||||
| } | ||||
|  | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { | ||||
| 	imageRequest := &openai.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
| 	return imageRequest, nil | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// model validation | ||||
| 	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||
| 	if !hasValidSize { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||
| 		// channel not azure | ||||
| 		if meta.ChannelType != common.ChannelTypeAzure { | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { | ||||
| 	if imageRequest == nil { | ||||
| 		return 0, errors.New("imageRequest is nil") | ||||
| 	} | ||||
| 	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||
| 	if !hasValidSize { | ||||
| 		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) | ||||
| 	} | ||||
| 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||
| 		if imageRequest.Size == "1024x1024" { | ||||
| 			imageCostRatio *= 2 | ||||
| 		} else { | ||||
| 			imageCostRatio *= 1.5 | ||||
| 		} | ||||
| 	} | ||||
| 	return imageCostRatio, nil | ||||
| } | ||||
|  | ||||
| func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeChatCompletions: | ||||
| @@ -113,10 +172,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||
| 	} | ||||
| 	if quota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 		model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 		model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| 	} | ||||
| 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| } | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| @@ -20,120 +21,65 @@ import ( | ||||
| ) | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { | ||||
| 	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { | ||||
| 		return false | ||||
| 	} | ||||
| 	min := common.DalleGenerationImageAmounts[element][0] | ||||
| 	max := common.DalleGenerationImageAmounts[element][1] | ||||
| 	min := constant.DalleGenerationImageAmounts[element][0] | ||||
| 	max := constant.DalleGenerationImageAmounts[element][1] | ||||
|  | ||||
| 	return value >= min && value <= max | ||||
| } | ||||
|  | ||||
| func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
| 	imageModel := "dall-e-2" | ||||
| 	imageSize := "1024x1024" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest openai.ImageRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 	ctx := c.Request.Context() | ||||
| 	meta := util.GetRelayMeta(c) | ||||
| 	imageRequest, err := getImageRequest(c, meta.Mode) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
|  | ||||
| 	// Size validation | ||||
| 	if imageRequest.Size != "" { | ||||
| 		imageSize = imageRequest.Size | ||||
| 	} | ||||
|  | ||||
| 	// Model validation | ||||
| 	if imageRequest.Model != "" { | ||||
| 		imageModel = imageRequest.Model | ||||
| 	} | ||||
|  | ||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] | ||||
|  | ||||
| 	// Check if model is supported | ||||
| 	if hasValidSize { | ||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { | ||||
| 			if imageSize == "1024x1024" { | ||||
| 				imageCostRatio *= 2 | ||||
| 			} else { | ||||
| 				imageCostRatio *= 1.5 | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Check prompt length | ||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageModel, imageRequest.N) { | ||||
| 		// channel not azure | ||||
| 		if channelType != common.ChannelTypeAzure { | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) | ||||
| 		return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[imageModel] != "" { | ||||
| 			imageModel = modelMap[imageModel] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	var isModelMapped bool | ||||
| 	meta.OriginModelName = imageRequest.Model | ||||
| 	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) | ||||
| 	meta.ActualModelName = imageRequest.Model | ||||
|  | ||||
| 	// model validation | ||||
| 	bizErr := validateImageRequest(imageRequest, meta) | ||||
| 	if bizErr != nil { | ||||
| 		return bizErr | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
|  | ||||
| 	imageCostRatio, err := getImageCostRatio(imageRequest) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) | ||||
| 	if meta.ChannelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||
| 		apiVersion := util.GetAzureAPIVersion(c) | ||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) | ||||
| 	} | ||||
|  | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | ||||
| 	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
|  | ||||
| 	modelRatio := common.GetModelRatio(imageModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	modelRatio := common.GetModelRatio(imageRequest.Model) | ||||
| 	groupRatio := common.GetGroupRatio(meta.Group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | ||||
|  | ||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||
|  | ||||
| @@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	token := c.Request.Header.Get("Authorization") | ||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | ||||
| 	if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication | ||||
| 		token = strings.TrimPrefix(token, "Bearer ") | ||||
| 		req.Header.Set("api-key", token) | ||||
| 	} else { | ||||
| @@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var textResponse openai.ImageResponse | ||||
| 	var imageResponse openai.ImageResponse | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			return | ||||
| 		} | ||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 		err := model.PostConsumeTokenQuota(meta.TokenId, quota) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 		} | ||||
| 		err = model.CacheUpdateUserQuota(userId) | ||||
| 		err = model.CacheUpdateUserQuota(meta.UserId) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString("token_name") | ||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 		} | ||||
| @@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	err = json.Unmarshal(responseBody, &imageResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|   | ||||
| @@ -55,7 +55,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 	var requestBody io.Reader | ||||
| 	if meta.APIType == constant.APITypeOpenAI { | ||||
| 		// no need to convert request for openai | ||||
| 		if isModelMapped { | ||||
| 		shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan | ||||
| 		if shouldResetRequestBody { | ||||
| 			jsonStr, err := json.Marshal(textRequest) | ||||
| 			if err != nil { | ||||
| 				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||
| @@ -82,11 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 		logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) | ||||
| 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 	errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") | ||||
| 	if errorHappened { | ||||
| 		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | ||||
| 		return util.RelayErrorHandler(resp) | ||||
| 	} | ||||
| 	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
|  | ||||
| 	// do response | ||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||
|   | ||||
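
One consequence of reordering the checks above: a streaming request whose upstream answers with a plain application/json body is now treated as a failure before any SSE handling starts, so the pre-consumed quota gets refunded. A minimal reproduction of the predicate:

package main

import (
	"fmt"
	"net/http"
)

// Mirrors the errorHappened predicate from RelayTextHelper.
func errorHappened(resp *http.Response, isStream bool) bool {
	return resp.StatusCode != http.StatusOK ||
		(isStream && resp.Header.Get("Content-Type") == "application/json")
}

func main() {
	resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}}
	resp.Header.Set("Content-Type", "application/json")
	fmt.Println(errorHappened(resp, true)) // true: JSON error body on a stream request
	resp.Header.Set("Content-Type", "text/event-stream")
	fmt.Println(errorHappened(resp, true)) // false: a real event stream
}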
| @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | ||||
| 	{ | ||||
| 		apiRouter.GET("/status", controller.GetStatus) | ||||
| 		apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) | ||||
| 		apiRouter.GET("/notice", controller.GetNotice) | ||||
| 		apiRouter.GET("/about", controller.GetAbout) | ||||
| 		apiRouter.GET("/home_page_content", controller.GetHomePageContent) | ||||
|   | ||||
| @@ -15,7 +15,7 @@ export const CHANNEL_OPTIONS = { | ||||
|     key: 3, | ||||
|     text: 'Azure OpenAI', | ||||
|     value: 3, | ||||
|     color: 'orange' | ||||
|     color: 'secondary' | ||||
|   }, | ||||
|   11: { | ||||
|     key: 11, | ||||
| @@ -29,6 +29,12 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 24, | ||||
|     color: 'orange' | ||||
|   }, | ||||
|   28: { | ||||
|     key: 28, | ||||
|     text: 'Mistral AI', | ||||
|     value: 28, | ||||
|     color: 'orange' | ||||
|   }, | ||||
|   15: { | ||||
|     key: 15, | ||||
|     text: '百度文心千帆', | ||||
| @@ -71,6 +77,24 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 23, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   26: { | ||||
|     key: 26, | ||||
|     text: '百川大模型', | ||||
|     value: 26, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   27: { | ||||
|     key: 27, | ||||
|     text: 'MiniMax', | ||||
|     value: 27, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   29: { | ||||
|     key: 29, | ||||
|     text: 'Groq', | ||||
|     value: 29, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   8: { | ||||
|     key: 8, | ||||
|     text: '自定义渠道', | ||||
|   | ||||
| @@ -67,7 +67,7 @@ const typeConfig = { | ||||
|   }, | ||||
|   16: { | ||||
|     input: { | ||||
|       models: ["chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"], | ||||
|       models: ["glm-4", "glm-4v", "glm-3-turbo", "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"], | ||||
|     }, | ||||
|     modelGroup: "zhipu", | ||||
|   }, | ||||
| @@ -145,6 +145,27 @@ const typeConfig = { | ||||
|     }, | ||||
|     modelGroup: "google gemini", | ||||
|   }, | ||||
|   25: { | ||||
|     input: { | ||||
|       models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'], | ||||
|     }, | ||||
|     modelGroup: "moonshot", | ||||
|   }, | ||||
|   26: { | ||||
|     input: { | ||||
|       models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'], | ||||
|     }, | ||||
|     modelGroup: "baichuan", | ||||
|   }, | ||||
|   27: { | ||||
|     input: { | ||||
|       models: ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'], | ||||
|     }, | ||||
|     modelGroup: "minimax", | ||||
|   }, | ||||
|   29: { | ||||
|     modelGroup: "groq", | ||||
|   }, | ||||
| }; | ||||
|  | ||||
| export { defaultConfig, typeConfig }; | ||||
|   | ||||
| @@ -1,7 +1,16 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | ||||
| import { | ||||
|   API, | ||||
|   loadChannelModels, | ||||
|   setPromptShown, | ||||
|   shouldShowPrompt, | ||||
|   showError, | ||||
|   showInfo, | ||||
|   showSuccess, | ||||
|   timestamp2string | ||||
| } from '../helpers'; | ||||
|  | ||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||
| import { renderGroup, renderNumber } from '../helpers/render'; | ||||
| @@ -95,6 +104,7 @@ const ChannelsTable = () => { | ||||
|       .catch((reason) => { | ||||
|         showError(reason); | ||||
|       }); | ||||
|     loadChannelModels().then(); | ||||
|   }, []); | ||||
|  | ||||
|   const manageChannel = async (id, action, idx, value) => { | ||||
|   | ||||
| @@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||
|   { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, | ||||
|   { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, | ||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||
|   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||
| @@ -11,6 +12,9 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||
|   { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, | ||||
|   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||
|   { key: 26, text: '百川大模型', value: 26, color: 'orange' }, | ||||
|   { key: 27, text: 'MiniMax', value: 27, color: 'red' }, | ||||
|   { key: 29, text: 'Groq', value: 29, color: 'orange' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||
|   | ||||
| @@ -1,11 +1,13 @@ | ||||
| import { toast } from 'react-toastify'; | ||||
| import { toastConstants } from '../constants'; | ||||
| import React from 'react'; | ||||
| import { API } from './api'; | ||||
|  | ||||
| const HTMLToastContent = ({ htmlContent }) => { | ||||
|   return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />; | ||||
| }; | ||||
| export default HTMLToastContent; | ||||
|  | ||||
| export function isAdmin() { | ||||
|   let user = localStorage.getItem('user'); | ||||
|   if (!user) return false; | ||||
| @@ -29,7 +31,7 @@ export function getSystemName() { | ||||
| export function getLogo() { | ||||
|   let logo = localStorage.getItem('logo'); | ||||
|   if (!logo) return '/logo.png'; | ||||
|   return logo | ||||
|   return logo; | ||||
| } | ||||
|  | ||||
| export function getFooterHTML() { | ||||
| @@ -197,3 +199,29 @@ export function shouldShowPrompt(id) { | ||||
| export function setPromptShown(id) { | ||||
|   localStorage.setItem(`prompt-${id}`, 'true'); | ||||
| } | ||||
|  | ||||
| let channelModels = undefined; | ||||
| export async function loadChannelModels() { | ||||
|   const res = await API.get('/api/models'); | ||||
|   const { success, data } = res.data; | ||||
|   if (!success) { | ||||
|     return; | ||||
|   } | ||||
|   channelModels = data; | ||||
|   localStorage.setItem('channel_models', JSON.stringify(data)); | ||||
| } | ||||
|  | ||||
| export function getChannelModels(type) { | ||||
|   if (channelModels !== undefined && type in channelModels) { | ||||
|     return channelModels[type]; | ||||
|   } | ||||
|   let models = localStorage.getItem('channel_models'); | ||||
|   if (!models) { | ||||
|     return []; | ||||
|   } | ||||
|   channelModels = JSON.parse(models); | ||||
|   if (type in channelModels) { | ||||
|     return channelModels[type]; | ||||
|   } | ||||
|   return []; | ||||
| } | ||||
| @@ -1,7 +1,7 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | ||||
| import { useNavigate, useParams } from 'react-router-dom'; | ||||
| import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | ||||
| import { API, copy, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | ||||
| import { CHANNEL_OPTIONS } from '../../constants'; | ||||
|  | ||||
| const MODEL_MAPPING_EXAMPLE = { | ||||
| @@ -56,54 +56,12 @@ const EditChannel = () => { | ||||
|   const [customModel, setCustomModel] = useState(''); | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     if (name === 'type' && inputs.models.length === 0) { | ||||
|       let localModels = []; | ||||
|       switch (value) { | ||||
|         case 14: | ||||
|           localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; | ||||
|           break; | ||||
|         case 11: | ||||
|           localModels = ['PaLM-2']; | ||||
|           break; | ||||
|         case 15: | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; | ||||
|           break; | ||||
|         case 17: | ||||
|           localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; | ||||
|           let withInternetVersion = []; | ||||
|           for (let i = 0; i < localModels.length; i++) { | ||||
|             if (localModels[i].startsWith('qwen-')) { | ||||
|               withInternetVersion.push(localModels[i] + '-internet'); | ||||
|             } | ||||
|           } | ||||
|           localModels = [...localModels, ...withInternetVersion]; | ||||
|           break; | ||||
|         case 16: | ||||
|           localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           break; | ||||
|         case 18: | ||||
|           localModels = [ | ||||
|             'SparkDesk', | ||||
|             'SparkDesk-v1.1', | ||||
|             'SparkDesk-v2.1', | ||||
|             'SparkDesk-v3.1', | ||||
|             'SparkDesk-v3.5' | ||||
|           ]; | ||||
|           break; | ||||
|         case 19: | ||||
|           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||
|           break; | ||||
|         case 23: | ||||
|           localModels = ['hunyuan']; | ||||
|           break; | ||||
|         case 24: | ||||
|           localModels = ['gemini-pro', 'gemini-pro-vision']; | ||||
|           break; | ||||
|         case 25: | ||||
|           localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']; | ||||
|           break; | ||||
|     if (name === 'type') { | ||||
|       let localModels = getChannelModels(value); | ||||
|       if (inputs.models.length === 0) { | ||||
|         setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||
|       } | ||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||
|       setBasicModels(localModels); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
| @@ -256,6 +214,7 @@ const EditChannel = () => { | ||||
|               label='类型' | ||||
|               name='type' | ||||
|               required | ||||
|               search | ||||
|               options={CHANNEL_OPTIONS} | ||||
|               value={inputs.type} | ||||
|               onChange={handleInputChange} | ||||
| @@ -384,6 +343,8 @@ const EditChannel = () => { | ||||
|               required | ||||
|               fluid | ||||
|               multiple | ||||
|               search | ||||
|               onLabelClick={(e, { value }) => {copy(value).then()}} | ||||
|               selection | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.models} | ||||
|   | ||||