mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			10 Commits
		
	
	
		
			v0.6.1-alp
			...
			v0.6.2-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 12440874b0 | ||
|  | 6ebc99460e | ||
|  | 27ad8bfb98 | ||
|  | 8388aa537f | ||
|  | 2346bf70af | ||
|  | f05b403ca5 | ||
|  | b33616df44 | ||
|  | cf16f44970 | ||
|  | bf2e26a48f | ||
|  | 4fb22ad4ce | 
| @@ -78,6 +78,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|    + [x] [百川大模型](https://platform.baichuan-ai.com) |    + [x] [百川大模型](https://platform.baichuan-ai.com) | ||||||
|    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) |    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) | ||||||
|    + [x] [MINIMAX](https://api.minimax.chat/) |    + [x] [MINIMAX](https://api.minimax.chat/) | ||||||
|  |    + [x] [Groq](https://wow.groq.com/) | ||||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| @@ -374,6 +375,9 @@ graph LR | |||||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||||
|  | 19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||||
|  | 20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||||
|  | 21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package blacklist | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
// blackList maps a per-user key to a ban marker. A sync.Map's zero value
// is ready for use, so no init() is needed.
var blackList sync.Map

// userKey converts a user id into the string key under which the ban is stored.
func userKey(id int) string {
	return fmt.Sprintf("userid_%d", id)
}

// BanUser marks the user with the given id as banned.
func BanUser(id int) {
	blackList.Store(userKey(id), true)
}

// UnbanUser removes the ban for the user with the given id, if present.
func UnbanUser(id int) {
	blackList.Delete(userKey(id))
}

// IsUserBanned reports whether the user with the given id is currently banned.
func IsUserBanned(id int) bool {
	_, banned := blackList.Load(userKey(id))
	return banned
}
| @@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{ | |||||||
| } | } | ||||||
|  |  | ||||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||||
|  | var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" | ||||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true | var LogConsumeEnabled = true | ||||||
| @@ -125,3 +126,9 @@ var ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute | var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||||
|  |  | ||||||
|  | var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false) | ||||||
|  | var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10) | ||||||
|  | var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) | ||||||
|  | var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024) | ||||||
|  | var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128) | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ const ( | |||||||
| const ( | const ( | ||||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||||
| 	UserStatusDisabled = 2 // also don't use 0 | 	UserStatusDisabled = 2 // also don't use 0 | ||||||
|  | 	UserStatusDeleted  = 3 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -38,35 +39,38 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ChannelTypeUnknown        = 0 | 	ChannelTypeUnknown = iota | ||||||
| 	ChannelTypeOpenAI         = 1 | 	ChannelTypeOpenAI | ||||||
| 	ChannelTypeAPI2D          = 2 | 	ChannelTypeAPI2D | ||||||
| 	ChannelTypeAzure          = 3 | 	ChannelTypeAzure | ||||||
| 	ChannelTypeCloseAI        = 4 | 	ChannelTypeCloseAI | ||||||
| 	ChannelTypeOpenAISB       = 5 | 	ChannelTypeOpenAISB | ||||||
| 	ChannelTypeOpenAIMax      = 6 | 	ChannelTypeOpenAIMax | ||||||
| 	ChannelTypeOhMyGPT        = 7 | 	ChannelTypeOhMyGPT | ||||||
| 	ChannelTypeCustom         = 8 | 	ChannelTypeCustom | ||||||
| 	ChannelTypeAILS           = 9 | 	ChannelTypeAILS | ||||||
| 	ChannelTypeAIProxy        = 10 | 	ChannelTypeAIProxy | ||||||
| 	ChannelTypePaLM           = 11 | 	ChannelTypePaLM | ||||||
| 	ChannelTypeAPI2GPT        = 12 | 	ChannelTypeAPI2GPT | ||||||
| 	ChannelTypeAIGC2D         = 13 | 	ChannelTypeAIGC2D | ||||||
| 	ChannelTypeAnthropic      = 14 | 	ChannelTypeAnthropic | ||||||
| 	ChannelTypeBaidu          = 15 | 	ChannelTypeBaidu | ||||||
| 	ChannelTypeZhipu          = 16 | 	ChannelTypeZhipu | ||||||
| 	ChannelTypeAli            = 17 | 	ChannelTypeAli | ||||||
| 	ChannelTypeXunfei         = 18 | 	ChannelTypeXunfei | ||||||
| 	ChannelType360            = 19 | 	ChannelType360 | ||||||
| 	ChannelTypeOpenRouter     = 20 | 	ChannelTypeOpenRouter | ||||||
| 	ChannelTypeAIProxyLibrary = 21 | 	ChannelTypeAIProxyLibrary | ||||||
| 	ChannelTypeFastGPT        = 22 | 	ChannelTypeFastGPT | ||||||
| 	ChannelTypeTencent        = 23 | 	ChannelTypeTencent | ||||||
| 	ChannelTypeGemini         = 24 | 	ChannelTypeGemini | ||||||
| 	ChannelTypeMoonshot       = 25 | 	ChannelTypeMoonshot | ||||||
| 	ChannelTypeBaichuan       = 26 | 	ChannelTypeBaichuan | ||||||
| 	ChannelTypeMinimax        = 27 | 	ChannelTypeMinimax | ||||||
| 	ChannelTypeMistral        = 28 | 	ChannelTypeMistral | ||||||
|  | 	ChannelTypeGroq | ||||||
|  |  | ||||||
|  | 	ChannelTypeDummy | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| @@ -99,6 +103,7 @@ var ChannelBaseURLs = []string{ | |||||||
| 	"https://api.baichuan-ai.com",               // 26 | 	"https://api.baichuan-ai.com",               // 26 | ||||||
| 	"https://api.minimax.chat",                  // 27 | 	"https://api.minimax.chat",                  // 27 | ||||||
| 	"https://api.mistral.ai",                    // 28 | 	"https://api.mistral.ai",                    // 28 | ||||||
|  | 	"https://api.groq.com/openai",               // 29 | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|   | |||||||
| @@ -195,6 +195,13 @@ func Max(a int, b int) int { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetOrDefaultEnvBool(env string, defaultValue bool) bool { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return os.Getenv(env) == "true" | ||||||
|  | } | ||||||
|  |  | ||||||
| func GetOrDefaultEnvInt(env string, defaultValue int) int { | func GetOrDefaultEnvInt(env string, defaultValue int) int { | ||||||
| 	if env == "" || os.Getenv(env) == "" { | 	if env == "" || os.Getenv(env) == "" { | ||||||
| 		return defaultValue | 		return defaultValue | ||||||
| @@ -207,6 +214,18 @@ func GetOrDefaultEnvInt(env string, defaultValue int) int { | |||||||
| 	return num | 	return num | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.ParseFloat(os.Getenv(env), 64) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue)) | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
| func GetOrDefaultEnvString(env string, defaultValue string) string { | func GetOrDefaultEnvString(env string, defaultValue string) string { | ||||||
| 	if env == "" || os.Getenv(env) == "" { | 	if env == "" || os.Getenv(env) == "" { | ||||||
| 		return defaultValue | 		return defaultValue | ||||||
|   | |||||||
| @@ -63,12 +63,15 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"text-search-ada-doc-001": 10, | 	"text-search-ada-doc-001": 10, | ||||||
| 	"text-moderation-stable":  0.1, | 	"text-moderation-stable":  0.1, | ||||||
| 	"text-moderation-latest":  0.1, | 	"text-moderation-latest":  0.1, | ||||||
| 	"dall-e-2":                8,     // $0.016 - $0.020 / image | 	"dall-e-2":                8,  // $0.016 - $0.020 / image | ||||||
| 	"dall-e-3":                20,    // $0.040 - $0.120 / image | 	"dall-e-3":                20, // $0.040 - $0.120 / image | ||||||
| 	"claude-instant-1":        0.815, // $1.63 / 1M tokens | 	// https://www.anthropic.com/api#pricing | ||||||
| 	"claude-2":                5.51,  // $11.02 / 1M tokens | 	"claude-instant-1.2":       0.8 / 1000 * USD, | ||||||
| 	"claude-2.0":              5.51,  // $11.02 / 1M tokens | 	"claude-2.0":               8.0 / 1000 * USD, | ||||||
| 	"claude-2.1":              5.51,  // $11.02 / 1M tokens | 	"claude-2.1":               8.0 / 1000 * USD, | ||||||
|  | 	"claude-3-haiku-20240229":  0.25 / 1000 * USD, | ||||||
|  | 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD, | ||||||
|  | 	"claude-3-opus-20240229":   15.0 / 1000 * USD, | ||||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||||
| 	"ERNIE-Bot":         0.8572,     // ¥0.012 / 1k tokens | 	"ERNIE-Bot":         0.8572,     // ¥0.012 / 1k tokens | ||||||
| 	"ERNIE-Bot-turbo":   0.5715,     // ¥0.008 / 1k tokens | 	"ERNIE-Bot-turbo":   0.5715,     // ¥0.008 / 1k tokens | ||||||
| @@ -122,6 +125,11 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"mistral-medium-latest": 2.7 / 1000 * USD, | 	"mistral-medium-latest": 2.7 / 1000 * USD, | ||||||
| 	"mistral-large-latest":  8.0 / 1000 * USD, | 	"mistral-large-latest":  8.0 / 1000 * USD, | ||||||
| 	"mistral-embed":         0.1 / 1000 * USD, | 	"mistral-embed":         0.1 / 1000 * USD, | ||||||
|  | 	// https://wow.groq.com/ | ||||||
|  | 	"llama2-70b-4096":    0.7 / 1000 * USD, | ||||||
|  | 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||||
|  | 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, | ||||||
|  | 	"gemma-7b-it":        0.1 / 1000 * USD, | ||||||
| } | } | ||||||
|  |  | ||||||
| var CompletionRatio = map[string]float64{} | var CompletionRatio = map[string]float64{} | ||||||
| @@ -206,7 +214,7 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 				return 2 | 				return 2 | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		return 1.333333 | 		return 4.0 / 3.0 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "gpt-4") { | 	if strings.HasPrefix(name, "gpt-4") { | ||||||
| 		if strings.HasSuffix(name, "preview") { | 		if strings.HasSuffix(name, "preview") { | ||||||
| @@ -214,14 +222,18 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 		} | 		} | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "claude-instant-1") { | 	if strings.HasPrefix(name, "claude-3") { | ||||||
| 		return 3.38 | 		return 5 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "claude-2") { | 	if strings.HasPrefix(name, "claude-") { | ||||||
| 		return 2.965517 | 		return 3 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "mistral-") { | 	if strings.HasPrefix(name, "mistral-") { | ||||||
| 		return 3 | 		return 3 | ||||||
| 	} | 	} | ||||||
|  | 	switch name { | ||||||
|  | 	case "llama2-70b-4096": | ||||||
|  | 		return 0.8 / 0.7 | ||||||
|  | 	} | ||||||
| 	return 1 | 	return 1 | ||||||
| } | } | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { | |||||||
| 		} else { | 		} else { | ||||||
| 			// err is nil & balance <= 0 means quota is used up | 			// err is nil & balance <= 0 means quota is used up | ||||||
| 			if balance <= 0 { | 			if balance <= 0 { | ||||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | 				monitor.DisableChannel(channel.Id, channel.Name, "余额不足") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(config.RequestInterval) | 		time.Sleep(config.RequestInterval) | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/middleware" | 	"github.com/songquanpeng/one-api/middleware" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" | 	"github.com/songquanpeng/one-api/relay/helper" | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
| @@ -148,32 +149,6 @@ func TestChannel(c *gin.Context) { | |||||||
| var testAllChannelsLock sync.Mutex | var testAllChannelsLock sync.Mutex | ||||||
| var testAllChannelsRunning bool = false | var testAllChannelsRunning bool = false | ||||||
|  |  | ||||||
| func notifyRootUser(subject string, content string) { |  | ||||||
| 	if config.RootUserEmail == "" { |  | ||||||
| 		config.RootUserEmail = model.GetRootUserEmail() |  | ||||||
| 	} |  | ||||||
| 	err := common.SendEmail(subject, config.RootUserEmail, content) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // disable & notify |  | ||||||
| func disableChannel(channelId int, channelName string, reason string) { |  | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) |  | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) |  | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) |  | ||||||
| 	notifyRootUser(subject, content) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // enable & notify |  | ||||||
| func enableChannel(channelId int, channelName string) { |  | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) |  | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) |  | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) |  | ||||||
| 	notifyRootUser(subject, content) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func testAllChannels(notify bool) error { | func testAllChannels(notify bool) error { | ||||||
| 	if config.RootUserEmail == "" { | 	if config.RootUserEmail == "" { | ||||||
| 		config.RootUserEmail = model.GetRootUserEmail() | 		config.RootUserEmail = model.GetRootUserEmail() | ||||||
| @@ -202,13 +177,13 @@ func testAllChannels(notify bool) error { | |||||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | 			milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | 			if isChannelEnabled && milliseconds > disableThreshold { | ||||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | ||||||
| 				enableChannel(channel.Id, channel.Name) | 				monitor.EnableChannel(channel.Id, channel.Name) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
| 			time.Sleep(config.RequestInterval) | 			time.Sleep(config.RequestInterval) | ||||||
|   | |||||||
| @@ -3,14 +3,13 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" | 	"github.com/songquanpeng/one-api/relay/helper" | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | 	"net/http" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/models/list | // https://platform.openai.com/docs/api-reference/models/list | ||||||
| @@ -42,6 +41,7 @@ type OpenAIModels struct { | |||||||
|  |  | ||||||
| var openAIModels []OpenAIModels | var openAIModels []OpenAIModels | ||||||
| var openAIModelsMap map[string]OpenAIModels | var openAIModelsMap map[string]OpenAIModels | ||||||
|  | var channelId2Models map[int][]string | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	var permission []OpenAIModelPermission | 	var permission []OpenAIModelPermission | ||||||
| @@ -79,65 +79,44 @@ func init() { | |||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	for _, modelName := range ai360.ModelList { | 	for _, channelType := range openai.CompatibleChannels { | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | 		if channelType == common.ChannelTypeAzure { | ||||||
| 			Id:         modelName, | 			continue | ||||||
| 			Object:     "model", | 		} | ||||||
| 			Created:    1626777600, | 		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) | ||||||
| 			OwnedBy:    "360", | 		for _, modelName := range channelModelList { | ||||||
| 			Permission: permission, | 			openAIModels = append(openAIModels, OpenAIModels{ | ||||||
| 			Root:       modelName, | 				Id:         modelName, | ||||||
| 			Parent:     nil, | 				Object:     "model", | ||||||
| 		}) | 				Created:    1626777600, | ||||||
| 	} | 				OwnedBy:    channelName, | ||||||
| 	for _, modelName := range moonshot.ModelList { | 				Permission: permission, | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | 				Root:       modelName, | ||||||
| 			Id:         modelName, | 				Parent:     nil, | ||||||
| 			Object:     "model", | 			}) | ||||||
| 			Created:    1626777600, | 		} | ||||||
| 			OwnedBy:    "moonshot", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       modelName, |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	for _, modelName := range baichuan.ModelList { |  | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ |  | ||||||
| 			Id:         modelName, |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1626777600, |  | ||||||
| 			OwnedBy:    "baichuan", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       modelName, |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	for _, modelName := range minimax.ModelList { |  | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ |  | ||||||
| 			Id:         modelName, |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1626777600, |  | ||||||
| 			OwnedBy:    "minimax", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       modelName, |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	for _, modelName := range mistral.ModelList { |  | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ |  | ||||||
| 			Id:         modelName, |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1626777600, |  | ||||||
| 			OwnedBy:    "mistralai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       modelName, |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | 	openAIModelsMap = make(map[string]OpenAIModels) | ||||||
| 	for _, model := range openAIModels { | 	for _, model := range openAIModels { | ||||||
| 		openAIModelsMap[model.Id] = model | 		openAIModelsMap[model.Id] = model | ||||||
| 	} | 	} | ||||||
|  | 	channelId2Models = make(map[int][]string) | ||||||
|  | 	for i := 1; i < common.ChannelTypeDummy; i++ { | ||||||
|  | 		adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) | ||||||
|  | 		meta := &util.RelayMeta{ | ||||||
|  | 			ChannelType: i, | ||||||
|  | 		} | ||||||
|  | 		adaptor.Init(meta) | ||||||
|  | 		channelId2Models[i] = adaptor.GetModelList() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DashboardListModels(c *gin.Context) { | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    channelId2Models, | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func ListModels(c *gin.Context) { | func ListModels(c *gin.Context) { | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/middleware" | 	"github.com/songquanpeng/one-api/middleware" | ||||||
| 	dbmodel "github.com/songquanpeng/one-api/model" | 	dbmodel "github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/controller" | 	"github.com/songquanpeng/one-api/relay/controller" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| @@ -45,11 +46,12 @@ func Relay(c *gin.Context) { | |||||||
| 		requestBody, _ := common.GetRequestBody(c) | 		requestBody, _ := common.GetRequestBody(c) | ||||||
| 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||||
| 	} | 	} | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
| 	bizErr := relay(c, relayMode) | 	bizErr := relay(c, relayMode) | ||||||
| 	if bizErr == nil { | 	if bizErr == nil { | ||||||
|  | 		monitor.Emit(channelId, true) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	lastFailedChannelId := channelId | 	lastFailedChannelId := channelId | ||||||
| 	channelName := c.GetString("channel_name") | 	channelName := c.GetString("channel_name") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| @@ -117,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st | |||||||
| 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | ||||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||||
| 		disableChannel(channelId, channelName, err.Message) | 		monitor.DisableChannel(channelId, channelName, err.Message) | ||||||
|  | 	} else { | ||||||
|  | 		monitor.Emit(channelId, false) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								main.go
									
									
									
									
									
								
							| @@ -83,6 +83,9 @@ func main() { | |||||||
| 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | ||||||
| 		model.InitBatchUpdater() | 		model.InitBatchUpdater() | ||||||
| 	} | 	} | ||||||
|  | 	if config.EnableMetric { | ||||||
|  | 		logger.SysLog("metric enabled, will disable channel if too much request failed") | ||||||
|  | 	} | ||||||
| 	openai.InitTokenEncoders() | 	openai.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/blacklist" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if status.(int) == common.UserStatusDisabled { | 	if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "用户已被封禁", | 			"message": "用户已被封禁", | ||||||
| 		}) | 		}) | ||||||
|  | 		session := sessions.Default(c) | ||||||
|  | 		session.Clear() | ||||||
|  | 		_ = session.Save() | ||||||
| 		c.Abort() | 		c.Abort() | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if !userEnabled { | 		if !userEnabled || blacklist.IsUserBanned(token.UserId) { | ||||||
| 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -72,7 +72,7 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| func InitDB() (err error) { | func InitDB() (err error) { | ||||||
| 	db, err := chooseDB() | 	db, err := chooseDB() | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		if config.DebugEnabled { | 		if config.DebugSQLEnabled { | ||||||
| 			db = db.Debug() | 			db = db.Debug() | ||||||
| 		} | 		} | ||||||
| 		DB = db | 		DB = db | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/blacklist" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| @@ -40,7 +41,7 @@ func GetMaxUserId() int { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUsers(startIdx int, num int) (users []*User, err error) { | func GetAllUsers(startIdx int, num int) (users []*User, err error) { | ||||||
| 	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error | 	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error | ||||||
| 	return users, err | 	return users, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -123,6 +124,11 @@ func (user *User) Update(updatePassword bool) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	if user.Status == common.UserStatusDisabled { | ||||||
|  | 		blacklist.BanUser(user.Id) | ||||||
|  | 	} else if user.Status == common.UserStatusEnabled { | ||||||
|  | 		blacklist.UnbanUser(user.Id) | ||||||
|  | 	} | ||||||
| 	err = DB.Model(user).Updates(user).Error | 	err = DB.Model(user).Updates(user).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -131,7 +137,10 @@ func (user *User) Delete() error { | |||||||
| 	if user.Id == 0 { | 	if user.Id == 0 { | ||||||
| 		return errors.New("id 为空!") | 		return errors.New("id 为空!") | ||||||
| 	} | 	} | ||||||
| 	err := DB.Delete(user).Error | 	blacklist.BanUser(user.Id) | ||||||
|  | 	user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) | ||||||
|  | 	user.Status = common.UserStatusDeleted | ||||||
|  | 	err := DB.Model(user).Updates(user).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										46
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func notifyRootUser(subject string, content string) { | ||||||
|  | 	if config.RootUserEmail == "" { | ||||||
|  | 		config.RootUserEmail = model.GetRootUserEmail() | ||||||
|  | 	} | ||||||
|  | 	err := common.SendEmail(subject, config.RootUserEmail, content) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DisableChannel disable & notify | ||||||
|  | func DisableChannel(channelId int, channelName string, reason string) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) | ||||||
|  | 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||||
|  | 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func MetricDisableChannel(channelId int, successRate float64) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) | ||||||
|  | 	subject := fmt.Sprintf("通道 #%d 已被禁用", channelId) | ||||||
|  | 	content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", | ||||||
|  | 		config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // EnableChannel enable & notify | ||||||
|  | func EnableChannel(channelId int, channelName string) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) | ||||||
|  | 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||||
|  | 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
							
								
								
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var store = make(map[int][]bool) | ||||||
|  | var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) | ||||||
|  | var metricFailChan = make(chan int, config.MetricFailChanSize) | ||||||
|  |  | ||||||
|  | func consumeSuccess(channelId int) { | ||||||
|  | 	if len(store[channelId]) > config.MetricQueueSize { | ||||||
|  | 		store[channelId] = store[channelId][1:] | ||||||
|  | 	} | ||||||
|  | 	store[channelId] = append(store[channelId], true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func consumeFail(channelId int) (bool, float64) { | ||||||
|  | 	if len(store[channelId]) > config.MetricQueueSize { | ||||||
|  | 		store[channelId] = store[channelId][1:] | ||||||
|  | 	} | ||||||
|  | 	store[channelId] = append(store[channelId], false) | ||||||
|  | 	successCount := 0 | ||||||
|  | 	for _, success := range store[channelId] { | ||||||
|  | 		if success { | ||||||
|  | 			successCount++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	successRate := float64(successCount) / float64(len(store[channelId])) | ||||||
|  | 	if len(store[channelId]) < config.MetricQueueSize { | ||||||
|  | 		return false, successRate | ||||||
|  | 	} | ||||||
|  | 	if successRate < config.MetricSuccessRateThreshold { | ||||||
|  | 		store[channelId] = make([]bool, 0) | ||||||
|  | 		return true, successRate | ||||||
|  | 	} | ||||||
|  | 	return false, successRate | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metricSuccessConsumer() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case channelId := <-metricSuccessChan: | ||||||
|  | 			consumeSuccess(channelId) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metricFailConsumer() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case channelId := <-metricFailChan: | ||||||
|  | 			disable, successRate := consumeFail(channelId) | ||||||
|  | 			if disable { | ||||||
|  | 				go MetricDisableChannel(channelId, successRate) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	if config.EnableMetric { | ||||||
|  | 		go metricSuccessConsumer() | ||||||
|  | 		go metricFailConsumer() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Emit(channelId int, success bool) { | ||||||
|  | 	if !config.EnableMetric { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	go func() { | ||||||
|  | 		if success { | ||||||
|  | 			metricSuccessChan <- channelId | ||||||
|  | 		} else { | ||||||
|  | 			metricFailChan <- channelId | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
| @@ -5,7 +5,6 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -20,7 +19,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
| 	return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil | 	return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||||
| @@ -31,6 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut | |||||||
| 		anthropicVersion = "2023-06-01" | 		anthropicVersion = "2023-06-01" | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("anthropic-version", anthropicVersion) | 	req.Header.Set("anthropic-version", anthropicVersion) | ||||||
|  | 	req.Header.Set("anthropic-beta", "messages-2023-12-15") | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -47,9 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | |||||||
|  |  | ||||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		var responseText string | 		err, usage = StreamHandler(c, resp) | ||||||
| 		err, responseText = StreamHandler(c, resp) |  | ||||||
| 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) |  | ||||||
| 	} else { | 	} else { | ||||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,5 +1,8 @@ | |||||||
| package anthropic | package anthropic | ||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", | 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||||
|  | 	"claude-3-haiku-20240229", | ||||||
|  | 	"claude-3-sonnet-20240229", | ||||||
|  | 	"claude-3-opus-20240229", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/image" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| @@ -15,73 +16,135 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func stopReasonClaude2OpenAI(reason string) string { | func stopReasonClaude2OpenAI(reason *string) string { | ||||||
| 	switch reason { | 	if reason == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	switch *reason { | ||||||
|  | 	case "end_turn": | ||||||
|  | 		return "stop" | ||||||
| 	case "stop_sequence": | 	case "stop_sequence": | ||||||
| 		return "stop" | 		return "stop" | ||||||
| 	case "max_tokens": | 	case "max_tokens": | ||||||
| 		return "length" | 		return "length" | ||||||
| 	default: | 	default: | ||||||
| 		return reason | 		return *reason | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||||
| 	claudeRequest := Request{ | 	claudeRequest := Request{ | ||||||
| 		Model:             textRequest.Model, | 		Model:       textRequest.Model, | ||||||
| 		Prompt:            "", | 		MaxTokens:   textRequest.MaxTokens, | ||||||
| 		MaxTokensToSample: textRequest.MaxTokens, | 		Temperature: textRequest.Temperature, | ||||||
| 		StopSequences:     nil, | 		TopP:        textRequest.TopP, | ||||||
| 		Temperature:       textRequest.Temperature, | 		Stream:      textRequest.Stream, | ||||||
| 		TopP:              textRequest.TopP, |  | ||||||
| 		Stream:            textRequest.Stream, |  | ||||||
| 	} | 	} | ||||||
| 	if claudeRequest.MaxTokensToSample == 0 { | 	if claudeRequest.MaxTokens == 0 { | ||||||
| 		claudeRequest.MaxTokensToSample = 1000000 | 		claudeRequest.MaxTokens = 4096 | ||||||
|  | 	} | ||||||
|  | 	// legacy model name mapping | ||||||
|  | 	if claudeRequest.Model == "claude-instant-1" { | ||||||
|  | 		claudeRequest.Model = "claude-instant-1.1" | ||||||
|  | 	} else if claudeRequest.Model == "claude-2" { | ||||||
|  | 		claudeRequest.Model = "claude-2.1" | ||||||
| 	} | 	} | ||||||
| 	prompt := "" |  | ||||||
| 	for _, message := range textRequest.Messages { | 	for _, message := range textRequest.Messages { | ||||||
| 		if message.Role == "user" { | 		if message.Role == "system" && claudeRequest.System == "" { | ||||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | 			claudeRequest.System = message.StringContent() | ||||||
| 		} else if message.Role == "assistant" { | 			continue | ||||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) |  | ||||||
| 		} else if message.Role == "system" { |  | ||||||
| 			if prompt == "" { |  | ||||||
| 				prompt = message.StringContent() |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
|  | 		claudeMessage := Message{ | ||||||
|  | 			Role: message.Role, | ||||||
|  | 		} | ||||||
|  | 		var content Content | ||||||
|  | 		if message.IsStringContent() { | ||||||
|  | 			content.Type = "text" | ||||||
|  | 			content.Text = message.StringContent() | ||||||
|  | 			claudeMessage.Content = append(claudeMessage.Content, content) | ||||||
|  | 			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		var contents []Content | ||||||
|  | 		openaiContent := message.ParseContent() | ||||||
|  | 		for _, part := range openaiContent { | ||||||
|  | 			var content Content | ||||||
|  | 			if part.Type == model.ContentTypeText { | ||||||
|  | 				content.Type = "text" | ||||||
|  | 				content.Text = part.Text | ||||||
|  | 			} else if part.Type == model.ContentTypeImageURL { | ||||||
|  | 				content.Type = "image" | ||||||
|  | 				content.Source = &ImageSource{ | ||||||
|  | 					Type: "base64", | ||||||
|  | 				} | ||||||
|  | 				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) | ||||||
|  | 				content.Source.MediaType = mimeType | ||||||
|  | 				content.Source.Data = data | ||||||
|  | 			} | ||||||
|  | 			contents = append(contents, content) | ||||||
|  | 		} | ||||||
|  | 		claudeMessage.Content = contents | ||||||
|  | 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||||
| 	} | 	} | ||||||
| 	prompt += "\n\nAssistant:" |  | ||||||
| 	claudeRequest.Prompt = prompt |  | ||||||
| 	return &claudeRequest | 	return &claudeRequest | ||||||
| } | } | ||||||
|  |  | ||||||
| func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { | // https://docs.anthropic.com/claude/reference/messages-streaming | ||||||
|  | func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { | ||||||
|  | 	var response *Response | ||||||
|  | 	var responseText string | ||||||
|  | 	var stopReason string | ||||||
|  | 	switch claudeResponse.Type { | ||||||
|  | 	case "message_start": | ||||||
|  | 		return nil, claudeResponse.Message | ||||||
|  | 	case "content_block_start": | ||||||
|  | 		if claudeResponse.ContentBlock != nil { | ||||||
|  | 			responseText = claudeResponse.ContentBlock.Text | ||||||
|  | 		} | ||||||
|  | 	case "content_block_delta": | ||||||
|  | 		if claudeResponse.Delta != nil { | ||||||
|  | 			responseText = claudeResponse.Delta.Text | ||||||
|  | 		} | ||||||
|  | 	case "message_delta": | ||||||
|  | 		if claudeResponse.Usage != nil { | ||||||
|  | 			response = &Response{ | ||||||
|  | 				Usage: *claudeResponse.Usage, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { | ||||||
|  | 			stopReason = *claudeResponse.Delta.StopReason | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = claudeResponse.Completion | 	choice.Delta.Content = responseText | ||||||
| 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | 	choice.Delta.Role = "assistant" | ||||||
|  | 	finishReason := stopReasonClaude2OpenAI(&stopReason) | ||||||
| 	if finishReason != "null" { | 	if finishReason != "null" { | ||||||
| 		choice.FinishReason = &finishReason | 		choice.FinishReason = &finishReason | ||||||
| 	} | 	} | ||||||
| 	var response openai.ChatCompletionsStreamResponse | 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||||
| 	response.Object = "chat.completion.chunk" | 	openaiResponse.Object = "chat.completion.chunk" | ||||||
| 	response.Model = claudeResponse.Model | 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||||
| 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | 	return &openaiResponse, response | ||||||
| 	return &response |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||||
|  | 	var responseText string | ||||||
|  | 	if len(claudeResponse.Content) > 0 { | ||||||
|  | 		responseText = claudeResponse.Content[0].Text | ||||||
|  | 	} | ||||||
| 	choice := openai.TextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: model.Message{ | 		Message: model.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | 			Content: responseText, | ||||||
| 			Name:    nil, | 			Name:    nil, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := openai.TextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | 		Id:      fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), | ||||||
|  | 		Model:   claudeResponse.Model, | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Choices: []openai.TextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| @@ -89,17 +152,15 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
|  |  | ||||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
| 	responseText := "" |  | ||||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) |  | ||||||
| 	createdTime := helper.GetTimestamp() | 	createdTime := helper.GetTimestamp() | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| 		if atEOF && len(data) == 0 { | 		if atEOF && len(data) == 0 { | ||||||
| 			return 0, nil, nil | 			return 0, nil, nil | ||||||
| 		} | 		} | ||||||
| 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
| 			return i + 4, data[0:i], nil | 			return i + 1, data[0:i], nil | ||||||
| 		} | 		} | ||||||
| 		if atEOF { | 		if atEOF { | ||||||
| 			return len(data), data, nil | 			return len(data), data, nil | ||||||
| @@ -111,29 +172,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | |||||||
| 	go func() { | 	go func() { | ||||||
| 		for scanner.Scan() { | 		for scanner.Scan() { | ||||||
| 			data := scanner.Text() | 			data := scanner.Text() | ||||||
| 			if !strings.HasPrefix(data, "event: completion") { | 			if len(data) < 6 { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | 			if !strings.HasPrefix(data, "data: ") { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = strings.TrimPrefix(data, "data: ") | ||||||
| 			dataChan <- data | 			dataChan <- data | ||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	common.SetEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
|  | 	var usage model.Usage | ||||||
|  | 	var modelName string | ||||||
|  | 	var id string | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			// some implementations may add \r at the end of data | 			// some implementations may add \r at the end of data | ||||||
| 			data = strings.TrimSuffix(data, "\r") | 			data = strings.TrimSuffix(data, "\r") | ||||||
| 			var claudeResponse Response | 			var claudeResponse StreamResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| 				return true | 				return true | ||||||
| 			} | 			} | ||||||
| 			responseText += claudeResponse.Completion | 			response, meta := streamResponseClaude2OpenAI(&claudeResponse) | ||||||
| 			response := streamResponseClaude2OpenAI(&claudeResponse) | 			if meta != nil { | ||||||
| 			response.Id = responseId | 				usage.PromptTokens += meta.Usage.InputTokens | ||||||
|  | 				usage.CompletionTokens += meta.Usage.OutputTokens | ||||||
|  | 				modelName = meta.Model | ||||||
|  | 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if response == nil { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			response.Id = id | ||||||
|  | 			response.Model = modelName | ||||||
| 			response.Created = createdTime | 			response.Created = createdTime | ||||||
| 			jsonStr, err := json.Marshal(response) | 			jsonStr, err := json.Marshal(response) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -147,11 +224,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | |||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	_ = resp.Body.Close() | ||||||
| 	if err != nil { | 	return nil, &usage | ||||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
| @@ -181,11 +255,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st | |||||||
| 	} | 	} | ||||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||||
| 	fullTextResponse.Model = modelName | 	fullTextResponse.Model = modelName | ||||||
| 	completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName) |  | ||||||
| 	usage := model.Usage{ | 	usage := model.Usage{ | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     claudeResponse.Usage.InputTokens, | ||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: claudeResponse.Usage.OutputTokens, | ||||||
| 		TotalTokens:      promptTokens + completionTokens, | 		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse.Usage = usage | 	fullTextResponse.Usage = usage | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|   | |||||||
| @@ -1,19 +1,44 @@ | |||||||
| package anthropic | package anthropic | ||||||
|  |  | ||||||
|  | // https://docs.anthropic.com/claude/reference/messages_post | ||||||
|  |  | ||||||
| type Metadata struct { | type Metadata struct { | ||||||
| 	UserId string `json:"user_id"` | 	UserId string `json:"user_id"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type ImageSource struct { | ||||||
|  | 	Type      string `json:"type"` | ||||||
|  | 	MediaType string `json:"media_type"` | ||||||
|  | 	Data      string `json:"data"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Content struct { | ||||||
|  | 	Type   string       `json:"type"` | ||||||
|  | 	Text   string       `json:"text,omitempty"` | ||||||
|  | 	Source *ImageSource `json:"source,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string    `json:"role"` | ||||||
|  | 	Content []Content `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type Request struct { | type Request struct { | ||||||
| 	Model             string   `json:"model"` | 	Model         string    `json:"model"` | ||||||
| 	Prompt            string   `json:"prompt"` | 	Messages      []Message `json:"messages"` | ||||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | 	System        string    `json:"system,omitempty"` | ||||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||||
| 	Temperature       float64  `json:"temperature,omitempty"` | 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||||
| 	TopP              float64  `json:"top_p,omitempty"` | 	Stream        bool      `json:"stream,omitempty"` | ||||||
| 	TopK              int      `json:"top_k,omitempty"` | 	Temperature   float64   `json:"temperature,omitempty"` | ||||||
|  | 	TopP          float64   `json:"top_p,omitempty"` | ||||||
|  | 	TopK          int       `json:"top_k,omitempty"` | ||||||
| 	//Metadata    `json:"metadata,omitempty"` | 	//Metadata    `json:"metadata,omitempty"` | ||||||
| 	Stream bool `json:"stream,omitempty"` | } | ||||||
|  |  | ||||||
|  | type Usage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type Error struct { | type Error struct { | ||||||
| @@ -22,8 +47,29 @@ type Error struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type Response struct { | type Response struct { | ||||||
| 	Completion string `json:"completion"` | 	Id           string    `json:"id"` | ||||||
| 	StopReason string `json:"stop_reason"` | 	Type         string    `json:"type"` | ||||||
| 	Model      string `json:"model"` | 	Role         string    `json:"role"` | ||||||
| 	Error      Error  `json:"error"` | 	Content      []Content `json:"content"` | ||||||
|  | 	Model        string    `json:"model"` | ||||||
|  | 	StopReason   *string   `json:"stop_reason"` | ||||||
|  | 	StopSequence *string   `json:"stop_sequence"` | ||||||
|  | 	Usage        Usage     `json:"usage"` | ||||||
|  | 	Error        Error     `json:"error"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Delta struct { | ||||||
|  | 	Type         string  `json:"type"` | ||||||
|  | 	Text         string  `json:"text"` | ||||||
|  | 	StopReason   *string `json:"stop_reason"` | ||||||
|  | 	StopSequence *string `json:"stop_sequence"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type StreamResponse struct { | ||||||
|  | 	Type         string    `json:"type"` | ||||||
|  | 	Message      *Response `json:"message"` | ||||||
|  | 	Index        int       `json:"index"` | ||||||
|  | 	ContentBlock *Content  `json:"content_block"` | ||||||
|  | 	Delta        *Delta    `json:"delta"` | ||||||
|  | 	Usage        *Usage    `json:"usage"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package baidu | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| @@ -9,6 +10,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Adaptor struct { | type Adaptor struct { | ||||||
| @@ -20,23 +22,33 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | |||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | ||||||
| 	var fullRequestURL string | 	suffix := "chat/" | ||||||
| 	switch meta.ActualModelName { | 	if strings.HasPrefix("Embedding", meta.ActualModelName) { | ||||||
| 	case "ERNIE-Bot-4": | 		suffix = "embeddings/" | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" |  | ||||||
| 	case "ERNIE-Bot-8K": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" |  | ||||||
| 	case "ERNIE-Bot": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" |  | ||||||
| 	case "ERNIE-Speed": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" |  | ||||||
| 	case "ERNIE-Bot-turbo": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" |  | ||||||
| 	case "BLOOMZ-7B": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" |  | ||||||
| 	case "Embedding-V1": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" |  | ||||||
| 	} | 	} | ||||||
|  | 	switch meta.ActualModelName { | ||||||
|  | 	case "ERNIE-4.0": | ||||||
|  | 		suffix += "completions_pro" | ||||||
|  | 	case "ERNIE-Bot-4": | ||||||
|  | 		suffix += "completions_pro" | ||||||
|  | 	case "ERNIE-3.5-8K": | ||||||
|  | 		suffix += "completions" | ||||||
|  | 	case "ERNIE-Bot-8K": | ||||||
|  | 		suffix += "ernie_bot_8k" | ||||||
|  | 	case "ERNIE-Bot": | ||||||
|  | 		suffix += "completions" | ||||||
|  | 	case "ERNIE-Speed": | ||||||
|  | 		suffix += "ernie_speed" | ||||||
|  | 	case "ERNIE-Bot-turbo": | ||||||
|  | 		suffix += "eb-instant" | ||||||
|  | 	case "BLOOMZ-7B": | ||||||
|  | 		suffix += "bloomz_7b1" | ||||||
|  | 	case "Embedding-V1": | ||||||
|  | 		suffix += "embedding-v1" | ||||||
|  | 	default: | ||||||
|  | 		suffix += meta.ActualModelName | ||||||
|  | 	} | ||||||
|  | 	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) | ||||||
| 	var accessToken string | 	var accessToken string | ||||||
| 	var err error | 	var err error | ||||||
| 	if accessToken, err = GetAccessToken(meta.APIKey); err != nil { | 	if accessToken, err = GetAccessToken(meta.APIKey); err != nil { | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								relay/channel/groq/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/channel/groq/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | package groq | ||||||
|  |  | ||||||
// ModelList enumerates the models served by the Groq channel.
// Reference: https://console.groq.com/docs/models
var ModelList = []string{
	"gemma-7b-it",
	"llama2-7b-2048",
	"llama2-70b-4096",
	"mixtral-8x7b-32768",
}
| @@ -6,11 +6,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/baichuan" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -86,37 +82,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetModelList() []string { | func (a *Adaptor) GetModelList() []string { | ||||||
| 	switch a.ChannelType { | 	_, modelList := GetCompatibleChannelMeta(a.ChannelType) | ||||||
| 	case common.ChannelType360: | 	return modelList | ||||||
| 		return ai360.ModelList |  | ||||||
| 	case common.ChannelTypeMoonshot: |  | ||||||
| 		return moonshot.ModelList |  | ||||||
| 	case common.ChannelTypeBaichuan: |  | ||||||
| 		return baichuan.ModelList |  | ||||||
| 	case common.ChannelTypeMinimax: |  | ||||||
| 		return minimax.ModelList |  | ||||||
| 	case common.ChannelTypeMistral: |  | ||||||
| 		return mistral.ModelList |  | ||||||
| 	default: |  | ||||||
| 		return ModelList |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetChannelName() string { | func (a *Adaptor) GetChannelName() string { | ||||||
| 	switch a.ChannelType { | 	channelName, _ := GetCompatibleChannelMeta(a.ChannelType) | ||||||
| 	case common.ChannelTypeAzure: | 	return channelName | ||||||
| 		return "azure" |  | ||||||
| 	case common.ChannelType360: |  | ||||||
| 		return "360" |  | ||||||
| 	case common.ChannelTypeMoonshot: |  | ||||||
| 		return "moonshot" |  | ||||||
| 	case common.ChannelTypeBaichuan: |  | ||||||
| 		return "baichuan" |  | ||||||
| 	case common.ChannelTypeMinimax: |  | ||||||
| 		return "minimax" |  | ||||||
| 	case common.ChannelTypeMistral: |  | ||||||
| 		return "mistralai" |  | ||||||
| 	default: |  | ||||||
| 		return "openai" |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										42
									
								
								relay/channel/openai/compatible.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								relay/channel/openai/compatible.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | |||||||
|  | package openai | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/groq" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||||
|  | ) | ||||||
|  |  | ||||||
// CompatibleChannels lists the channel types that speak the OpenAI-compatible
// API and are therefore handled by this adaptor instead of a dedicated one.
// Keep this in sync with the cases in GetCompatibleChannelMeta.
var CompatibleChannels = []int{
	common.ChannelTypeAzure,
	common.ChannelType360,
	common.ChannelTypeMoonshot,
	common.ChannelTypeBaichuan,
	common.ChannelTypeMinimax,
	common.ChannelTypeMistral,
	common.ChannelTypeGroq,
}
|  |  | ||||||
|  | func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||||
|  | 	switch channelType { | ||||||
|  | 	case common.ChannelTypeAzure: | ||||||
|  | 		return "azure", ModelList | ||||||
|  | 	case common.ChannelType360: | ||||||
|  | 		return "360", ai360.ModelList | ||||||
|  | 	case common.ChannelTypeMoonshot: | ||||||
|  | 		return "moonshot", moonshot.ModelList | ||||||
|  | 	case common.ChannelTypeBaichuan: | ||||||
|  | 		return "baichuan", baichuan.ModelList | ||||||
|  | 	case common.ChannelTypeMinimax: | ||||||
|  | 		return "minimax", minimax.ModelList | ||||||
|  | 	case common.ChannelTypeMistral: | ||||||
|  | 		return "mistralai", mistral.ModelList | ||||||
|  | 	case common.ChannelTypeGroq: | ||||||
|  | 		return "groq", groq.ModelList | ||||||
|  | 	default: | ||||||
|  | 		return "openai", ModelList | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -28,17 +28,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | |||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for i := 0; i < len(request.Messages); i++ { | 	for i := 0; i < len(request.Messages); i++ { | ||||||
| 		message := request.Messages[i] | 		message := request.Messages[i] | ||||||
| 		if message.Role == "system" { |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		messages = append(messages, Message{ | 		messages = append(messages, Message{ | ||||||
| 			Content: message.StringContent(), | 			Content: message.StringContent(), | ||||||
| 			Role:    message.Role, | 			Role:    message.Role, | ||||||
|   | |||||||
| @@ -27,21 +27,10 @@ import ( | |||||||
| func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		messages = append(messages, Message{ | ||||||
| 			messages = append(messages, Message{ | 			Role:    message.Role, | ||||||
| 				Role:    "user", | 			Content: message.StringContent(), | ||||||
| 				Content: message.StringContent(), | 		}) | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	xunfeiRequest := ChatRequest{} | 	xunfeiRequest := ChatRequest{} | ||||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||||
|   | |||||||
| @@ -76,21 +76,10 @@ func GetToken(apikey string) string { | |||||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *Request { | func ConvertRequest(request model.GeneralOpenAIRequest) *Request { | ||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		messages = append(messages, Message{ | ||||||
| 			messages = append(messages, Message{ | 			Role:    message.Role, | ||||||
| 				Role:    "system", | 			Content: message.StringContent(), | ||||||
| 				Content: message.StringContent(), | 		}) | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	return &Request{ | 	return &Request{ | ||||||
| 		Prompt:      messages, | 		Prompt:      messages, | ||||||
|   | |||||||
| @@ -83,11 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 		logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) | 		logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) | ||||||
| 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | 	errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if errorHappened { | ||||||
| 		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | 		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | ||||||
| 		return util.RelayErrorHandler(resp) | 		return util.RelayErrorHandler(resp) | ||||||
| 	} | 	} | ||||||
|  | 	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
|  |  | ||||||
| 	// do response | 	// do response | ||||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) | 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | ||||||
| 	{ | 	{ | ||||||
| 		apiRouter.GET("/status", controller.GetStatus) | 		apiRouter.GET("/status", controller.GetStatus) | ||||||
|  | 		apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) | ||||||
| 		apiRouter.GET("/notice", controller.GetNotice) | 		apiRouter.GET("/notice", controller.GetNotice) | ||||||
| 		apiRouter.GET("/about", controller.GetAbout) | 		apiRouter.GET("/about", controller.GetAbout) | ||||||
| 		apiRouter.GET("/home_page_content", controller.GetHomePageContent) | 		apiRouter.GET("/home_page_content", controller.GetHomePageContent) | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ export const CHANNEL_OPTIONS = { | |||||||
|     key: 3, |     key: 3, | ||||||
|     text: 'Azure OpenAI', |     text: 'Azure OpenAI', | ||||||
|     value: 3, |     value: 3, | ||||||
|     color: 'orange' |     color: 'secondary' | ||||||
|   }, |   }, | ||||||
|   11: { |   11: { | ||||||
|     key: 11, |     key: 11, | ||||||
| @@ -89,6 +89,12 @@ export const CHANNEL_OPTIONS = { | |||||||
|     value: 27, |     value: 27, | ||||||
|     color: 'default' |     color: 'default' | ||||||
|   }, |   }, | ||||||
|  |   29: { | ||||||
|  |     key: 29, | ||||||
|  |     text: 'Groq', | ||||||
|  |     value: 29, | ||||||
|  |     color: 'default' | ||||||
|  |   }, | ||||||
|   8: { |   8: { | ||||||
|     key: 8, |     key: 8, | ||||||
|     text: '自定义渠道', |     text: '自定义渠道', | ||||||
|   | |||||||
| @@ -163,6 +163,9 @@ const typeConfig = { | |||||||
|     }, |     }, | ||||||
|     modelGroup: "minimax", |     modelGroup: "minimax", | ||||||
|   }, |   }, | ||||||
|  |   29: { | ||||||
|  |     modelGroup: "groq", | ||||||
|  |   }, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| export { defaultConfig, typeConfig }; | export { defaultConfig, typeConfig }; | ||||||
|   | |||||||
| @@ -1,7 +1,16 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | import { | ||||||
|  |   API, | ||||||
|  |   loadChannelModels, | ||||||
|  |   setPromptShown, | ||||||
|  |   shouldShowPrompt, | ||||||
|  |   showError, | ||||||
|  |   showInfo, | ||||||
|  |   showSuccess, | ||||||
|  |   timestamp2string | ||||||
|  | } from '../helpers'; | ||||||
|  |  | ||||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||||
| import { renderGroup, renderNumber } from '../helpers/render'; | import { renderGroup, renderNumber } from '../helpers/render'; | ||||||
| @@ -95,6 +104,7 @@ const ChannelsTable = () => { | |||||||
|       .catch((reason) => { |       .catch((reason) => { | ||||||
|         showError(reason); |         showError(reason); | ||||||
|       }); |       }); | ||||||
|  |     loadChannelModels().then(); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|   const manageChannel = async (id, action, idx, value) => { |   const manageChannel = async (id, action, idx, value) => { | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ export const CHANNEL_OPTIONS = [ | |||||||
|   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, |   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||||
|   { key: 26, text: '百川大模型', value: 26, color: 'orange' }, |   { key: 26, text: '百川大模型', value: 26, color: 'orange' }, | ||||||
|   { key: 27, text: 'MiniMax', value: 27, color: 'red' }, |   { key: 27, text: 'MiniMax', value: 27, color: 'red' }, | ||||||
|  |   { key: 29, text: 'Groq', value: 29, color: 'orange' }, | ||||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, |   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, |   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|   | |||||||
| @@ -1,11 +1,13 @@ | |||||||
| import { toast } from 'react-toastify'; | import { toast } from 'react-toastify'; | ||||||
| import { toastConstants } from '../constants'; | import { toastConstants } from '../constants'; | ||||||
| import React from 'react'; | import React from 'react'; | ||||||
|  | import { API } from './api'; | ||||||
|  |  | ||||||
| const HTMLToastContent = ({ htmlContent }) => { | const HTMLToastContent = ({ htmlContent }) => { | ||||||
|   return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />; |   return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />; | ||||||
| }; | }; | ||||||
| export default HTMLToastContent; | export default HTMLToastContent; | ||||||
|  |  | ||||||
| export function isAdmin() { | export function isAdmin() { | ||||||
|   let user = localStorage.getItem('user'); |   let user = localStorage.getItem('user'); | ||||||
|   if (!user) return false; |   if (!user) return false; | ||||||
| @@ -29,7 +31,7 @@ export function getSystemName() { | |||||||
| export function getLogo() { | export function getLogo() { | ||||||
|   let logo = localStorage.getItem('logo'); |   let logo = localStorage.getItem('logo'); | ||||||
|   if (!logo) return '/logo.png'; |   if (!logo) return '/logo.png'; | ||||||
|   return logo |   return logo; | ||||||
| } | } | ||||||
|  |  | ||||||
| export function getFooterHTML() { | export function getFooterHTML() { | ||||||
| @@ -196,4 +198,30 @@ export function shouldShowPrompt(id) { | |||||||
|  |  | ||||||
| export function setPromptShown(id) { | export function setPromptShown(id) { | ||||||
|   localStorage.setItem(`prompt-${id}`, 'true'); |   localStorage.setItem(`prompt-${id}`, 'true'); | ||||||
|  | } | ||||||
|  |  | ||||||
// In-memory cache of the per-channel-type model lists fetched from /api/models.
let channelModels = undefined;

// Fetch the per-channel-type model lists from the backend and cache them both
// in memory and in localStorage, so the channel editor can suggest models
// without an extra round trip.
// Fix: callers invoke loadChannelModels().then() without a catch handler, so a
// network failure previously surfaced as an unhandled promise rejection.
export async function loadChannelModels() {
  try {
    const res = await API.get('/api/models');
    const { success, data } = res.data;
    if (!success) {
      return;
    }
    channelModels = data;
    localStorage.setItem('channel_models', JSON.stringify(data));
  } catch (e) {
    // Best-effort cache refresh: log and keep whatever is already cached.
    console.error('failed to load channel models:', e);
  }
}
|  |  | ||||||
// Return the cached model list for the given channel type, or [] when nothing
// is known. Falls back to the localStorage copy when the in-memory cache has
// not been populated yet.
// Fix: JSON.parse on a corrupt localStorage entry previously threw on every
// call; parse defensively and evict the bad entry instead.
export function getChannelModels(type) {
  if (channelModels !== undefined && type in channelModels) {
    return channelModels[type];
  }
  let models = localStorage.getItem('channel_models');
  if (!models) {
    return [];
  }
  try {
    channelModels = JSON.parse(models);
  } catch (e) {
    // Corrupt cache entry: drop it so subsequent calls take the fast path.
    localStorage.removeItem('channel_models');
    channelModels = undefined;
    return [];
  }
  if (type in channelModels) {
    return channelModels[type];
  }
  return [];
}
| @@ -1,7 +1,7 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | ||||||
| import { useNavigate, useParams } from 'react-router-dom'; | import { useNavigate, useParams } from 'react-router-dom'; | ||||||
| import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | import { API, copy, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | ||||||
| import { CHANNEL_OPTIONS } from '../../constants'; | import { CHANNEL_OPTIONS } from '../../constants'; | ||||||
|  |  | ||||||
| const MODEL_MAPPING_EXAMPLE = { | const MODEL_MAPPING_EXAMPLE = { | ||||||
| @@ -56,60 +56,12 @@ const EditChannel = () => { | |||||||
|   const [customModel, setCustomModel] = useState(''); |   const [customModel, setCustomModel] = useState(''); | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|     if (name === 'type' && inputs.models.length === 0) { |     if (name === 'type') { | ||||||
|       let localModels = []; |       let localModels = getChannelModels(value); | ||||||
|       switch (value) { |       if (inputs.models.length === 0) { | ||||||
|         case 14: |         setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||||
|           localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; |  | ||||||
|           break; |  | ||||||
|         case 11: |  | ||||||
|           localModels = ['PaLM-2']; |  | ||||||
|           break; |  | ||||||
|         case 15: |  | ||||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; |  | ||||||
|           break; |  | ||||||
|         case 17: |  | ||||||
|           localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; |  | ||||||
|           let withInternetVersion = []; |  | ||||||
|           for (let i = 0; i < localModels.length; i++) { |  | ||||||
|             if (localModels[i].startsWith('qwen-')) { |  | ||||||
|               withInternetVersion.push(localModels[i] + '-internet'); |  | ||||||
|             } |  | ||||||
|           } |  | ||||||
|           localModels = [...localModels, ...withInternetVersion]; |  | ||||||
|           break; |  | ||||||
|         case 16: |  | ||||||
|           localModels = ["glm-4", "glm-4v", "glm-3-turbo",'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; |  | ||||||
|           break; |  | ||||||
|         case 18: |  | ||||||
|           localModels = [ |  | ||||||
|             'SparkDesk', |  | ||||||
|             'SparkDesk-v1.1', |  | ||||||
|             'SparkDesk-v2.1', |  | ||||||
|             'SparkDesk-v3.1', |  | ||||||
|             'SparkDesk-v3.5' |  | ||||||
|           ]; |  | ||||||
|           break; |  | ||||||
|         case 19: |  | ||||||
|           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; |  | ||||||
|           break; |  | ||||||
|         case 23: |  | ||||||
|           localModels = ['hunyuan']; |  | ||||||
|           break; |  | ||||||
|         case 24: |  | ||||||
|           localModels = ['gemini-pro', 'gemini-pro-vision']; |  | ||||||
|           break; |  | ||||||
|         case 25: |  | ||||||
|           localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']; |  | ||||||
|           break; |  | ||||||
|         case 26: |  | ||||||
|           localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding']; |  | ||||||
|           break; |  | ||||||
|         case 27: |  | ||||||
|           localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat']; |  | ||||||
|           break; |  | ||||||
|       } |       } | ||||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); |       setBasicModels(localModels); | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
| @@ -262,6 +214,7 @@ const EditChannel = () => { | |||||||
|               label='类型' |               label='类型' | ||||||
|               name='type' |               name='type' | ||||||
|               required |               required | ||||||
|  |               search | ||||||
|               options={CHANNEL_OPTIONS} |               options={CHANNEL_OPTIONS} | ||||||
|               value={inputs.type} |               value={inputs.type} | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
| @@ -390,6 +343,8 @@ const EditChannel = () => { | |||||||
|               required |               required | ||||||
|               fluid |               fluid | ||||||
|               multiple |               multiple | ||||||
|  |               search | ||||||
|  |               onLabelClick={(e, { value }) => {copy(value).then()}} | ||||||
|               selection |               selection | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={inputs.models} |               value={inputs.models} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user