mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 05:13:41 +08:00 
			
		
		
		
	Merge remote-tracking branch 'origin/main'
# Conflicts: # controller/relay.go # main.go # middleware/distributor.go
This commit is contained in:
		
							
								
								
									
										15
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								README.md
									
									
									
									
									
								
							@@ -68,6 +68,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
 | 
			
		||||
   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
 | 
			
		||||
   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
			
		||||
   + [x] [360 智脑](https://ai.360.cn)
 | 
			
		||||
2. 支持配置镜像以及众多第三方代理服务:
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
@@ -108,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
 | 
			
		||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 | 
			
		||||
 | 
			
		||||
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
 | 
			
		||||
 | 
			
		||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 | 
			
		||||
 | 
			
		||||
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
 | 
			
		||||
@@ -274,8 +277,9 @@ graph LR
 | 
			
		||||
不加的话将会使用负载均衡的方式使用多个渠道。
 | 
			
		||||
 | 
			
		||||
### 环境变量
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
 | 
			
		||||
   + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
			
		||||
   + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
 | 
			
		||||
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
 | 
			
		||||
   + 例子:`SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
 | 
			
		||||
@@ -302,6 +306,14 @@ graph LR
 | 
			
		||||
   + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 | 
			
		||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
 | 
			
		||||
   + 例子:`POLLING_INTERVAL=5`
 | 
			
		||||
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
			
		||||
    + 例子:`BATCH_UPDATE_ENABLED=true`
 | 
			
		||||
    + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
 | 
			
		||||
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
 | 
			
		||||
    + 例子:`BATCH_UPDATE_INTERVAL=5`
 | 
			
		||||
12. 请求频率限制:
 | 
			
		||||
    + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
 | 
			
		||||
    + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
 | 
			
		||||
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
@@ -338,6 +350,7 @@ https://openai.justsong.cn
 | 
			
		||||
5. ChatGPT Next Web 报错:`Failed to fetch`
 | 
			
		||||
   + 部署的时候不要设置 `BASE_URL`。
 | 
			
		||||
   + 检查你的接口地址和 API Key 有没有填对。
 | 
			
		||||
   + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
 | 
			
		||||
6. 报错:`当前分组负载已饱和,请稍后再试`
 | 
			
		||||
   + 上游通道 429 了。
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -98,6 +98,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
			
		||||
 | 
			
		||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 | 
			
		||||
 | 
			
		||||
var BatchUpdateEnabled = false
 | 
			
		||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
@@ -115,10 +118,10 @@ var (
 | 
			
		||||
// All duration's unit is seconds
 | 
			
		||||
// Shouldn't larger then RateLimitKeyExpirationDuration
 | 
			
		||||
var (
 | 
			
		||||
	GlobalApiRateLimitNum            = 180
 | 
			
		||||
	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 | 
			
		||||
	GlobalApiRateLimitDuration int64 = 3 * 60
 | 
			
		||||
 | 
			
		||||
	GlobalWebRateLimitNum            = 60
 | 
			
		||||
	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 | 
			
		||||
	GlobalWebRateLimitDuration int64 = 3 * 60
 | 
			
		||||
 | 
			
		||||
	UploadRateLimitNum            = 10
 | 
			
		||||
@@ -158,45 +161,53 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ChannelTypeUnknown   = 0
 | 
			
		||||
	ChannelTypeOpenAI    = 1
 | 
			
		||||
	ChannelTypeAPI2D     = 2
 | 
			
		||||
	ChannelTypeAzure     = 3
 | 
			
		||||
	ChannelTypeCloseAI   = 4
 | 
			
		||||
	ChannelTypeOpenAISB  = 5
 | 
			
		||||
	ChannelTypeOpenAIMax = 6
 | 
			
		||||
	ChannelTypeOhMyGPT   = 7
 | 
			
		||||
	ChannelTypeCustom    = 8
 | 
			
		||||
	ChannelTypeAILS      = 9
 | 
			
		||||
	ChannelTypeAIProxy   = 10
 | 
			
		||||
	ChannelTypePaLM      = 11
 | 
			
		||||
	ChannelTypeAPI2GPT   = 12
 | 
			
		||||
	ChannelTypeAIGC2D    = 13
 | 
			
		||||
	ChannelTypeAnthropic = 14
 | 
			
		||||
	ChannelTypeBaidu     = 15
 | 
			
		||||
	ChannelTypeZhipu     = 16
 | 
			
		||||
	ChannelTypeAli       = 17
 | 
			
		||||
	ChannelTypeXunfei    = 18
 | 
			
		||||
	ChannelTypeUnknown        = 0
 | 
			
		||||
	ChannelTypeOpenAI         = 1
 | 
			
		||||
	ChannelTypeAPI2D          = 2
 | 
			
		||||
	ChannelTypeAzure          = 3
 | 
			
		||||
	ChannelTypeCloseAI        = 4
 | 
			
		||||
	ChannelTypeOpenAISB       = 5
 | 
			
		||||
	ChannelTypeOpenAIMax      = 6
 | 
			
		||||
	ChannelTypeOhMyGPT        = 7
 | 
			
		||||
	ChannelTypeCustom         = 8
 | 
			
		||||
	ChannelTypeAILS           = 9
 | 
			
		||||
	ChannelTypeAIProxy        = 10
 | 
			
		||||
	ChannelTypePaLM           = 11
 | 
			
		||||
	ChannelTypeAPI2GPT        = 12
 | 
			
		||||
	ChannelTypeAIGC2D         = 13
 | 
			
		||||
	ChannelTypeAnthropic      = 14
 | 
			
		||||
	ChannelTypeBaidu          = 15
 | 
			
		||||
	ChannelTypeZhipu          = 16
 | 
			
		||||
	ChannelTypeAli            = 17
 | 
			
		||||
	ChannelTypeXunfei         = 18
 | 
			
		||||
	ChannelType360            = 19
 | 
			
		||||
	ChannelTypeOpenRouter     = 20
 | 
			
		||||
	ChannelTypeAIProxyLibrary = 21
 | 
			
		||||
	ChannelTypeFastGPT        = 22
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                               // 0
 | 
			
		||||
	"https://api.openai.com",         // 1
 | 
			
		||||
	"https://oa.api2d.net",           // 2
 | 
			
		||||
	"",                               // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz",  // 4
 | 
			
		||||
	"https://api.openai-sb.com",      // 5
 | 
			
		||||
	"https://api.openaimax.com",      // 6
 | 
			
		||||
	"https://api.ohmygpt.com",        // 7
 | 
			
		||||
	"",                               // 8
 | 
			
		||||
	"https://api.caipacity.com",      // 9
 | 
			
		||||
	"https://api.aiproxy.io",         // 10
 | 
			
		||||
	"",                               // 11
 | 
			
		||||
	"https://api.api2gpt.com",        // 12
 | 
			
		||||
	"https://api.aigc2d.com",         // 13
 | 
			
		||||
	"https://api.anthropic.com",      // 14
 | 
			
		||||
	"https://aip.baidubce.com",       // 15
 | 
			
		||||
	"https://open.bigmodel.cn",       // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com", // 17
 | 
			
		||||
	"",                               // 18
 | 
			
		||||
	"",                                // 0
 | 
			
		||||
	"https://api.openai.com",          // 1
 | 
			
		||||
	"https://oa.api2d.net",            // 2
 | 
			
		||||
	"",                                // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz",   // 4
 | 
			
		||||
	"https://api.openai-sb.com",       // 5
 | 
			
		||||
	"https://api.openaimax.com",       // 6
 | 
			
		||||
	"https://api.ohmygpt.com",         // 7
 | 
			
		||||
	"",                                // 8
 | 
			
		||||
	"https://api.caipacity.com",       // 9
 | 
			
		||||
	"https://api.aiproxy.io",          // 10
 | 
			
		||||
	"",                                // 11
 | 
			
		||||
	"https://api.api2gpt.com",         // 12
 | 
			
		||||
	"https://api.aigc2d.com",          // 13
 | 
			
		||||
	"https://api.anthropic.com",       // 14
 | 
			
		||||
	"https://aip.baidubce.com",        // 15
 | 
			
		||||
	"https://open.bigmodel.cn",        // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com",  // 17
 | 
			
		||||
	"",                                // 18
 | 
			
		||||
	"https://ai.360.cn",               // 19
 | 
			
		||||
	"https://openrouter.ai/api",       // 20
 | 
			
		||||
	"https://api.aiproxy.io",          // 21
 | 
			
		||||
	"https://fastgpt.run/api/openapi", // 22
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,46 +13,52 @@ import (
 | 
			
		||||
// 1 === $0.002 / 1K tokens
 | 
			
		||||
// 1 === ¥0.014 / 1k tokens
 | 
			
		||||
var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-4":                   15,
 | 
			
		||||
	"gpt-4-0314":              15,
 | 
			
		||||
	"gpt-4-0613":              15,
 | 
			
		||||
	"gpt-4-32k":               30,
 | 
			
		||||
	"gpt-4-32k-0314":          30,
 | 
			
		||||
	"gpt-4-32k-0613":          30,
 | 
			
		||||
	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":      0.75,
 | 
			
		||||
	"gpt-3.5-turbo-0613":      0.75,
 | 
			
		||||
	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-16k-0613":  1.5,
 | 
			
		||||
	"text-ada-001":            0.2,
 | 
			
		||||
	"text-babbage-001":        0.25,
 | 
			
		||||
	"text-curie-001":          1,
 | 
			
		||||
	"text-davinci-002":        10,
 | 
			
		||||
	"text-davinci-003":        10,
 | 
			
		||||
	"text-davinci-edit-001":   10,
 | 
			
		||||
	"code-davinci-edit-001":   10,
 | 
			
		||||
	"whisper-1":               10,
 | 
			
		||||
	"davinci":                 10,
 | 
			
		||||
	"curie":                   10,
 | 
			
		||||
	"babbage":                 10,
 | 
			
		||||
	"ada":                     10,
 | 
			
		||||
	"text-embedding-ada-002":  0.05,
 | 
			
		||||
	"text-search-ada-doc-001": 10,
 | 
			
		||||
	"text-moderation-stable":  0.1,
 | 
			
		||||
	"text-moderation-latest":  0.1,
 | 
			
		||||
	"dall-e":                  8,
 | 
			
		||||
	"claude-instant-1":        0.815,  // $1.63 / 1M tokens
 | 
			
		||||
	"claude-2":                5.51,   // $11.02 / 1M tokens
 | 
			
		||||
	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                  1,
 | 
			
		||||
	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
 | 
			
		||||
	"qwen-plus-v1":            0.5715, // Same as above
 | 
			
		||||
	"SparkDesk":               0.8572, // TBD
 | 
			
		||||
	"gpt-4":                     15,
 | 
			
		||||
	"gpt-4-0314":                15,
 | 
			
		||||
	"gpt-4-0613":                15,
 | 
			
		||||
	"gpt-4-32k":                 30,
 | 
			
		||||
	"gpt-4-32k-0314":            30,
 | 
			
		||||
	"gpt-4-32k-0613":            30,
 | 
			
		||||
	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":        0.75,
 | 
			
		||||
	"gpt-3.5-turbo-0613":        0.75,
 | 
			
		||||
	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-16k-0613":    1.5,
 | 
			
		||||
	"text-ada-001":              0.2,
 | 
			
		||||
	"text-babbage-001":          0.25,
 | 
			
		||||
	"text-curie-001":            1,
 | 
			
		||||
	"text-davinci-002":          10,
 | 
			
		||||
	"text-davinci-003":          10,
 | 
			
		||||
	"text-davinci-edit-001":     10,
 | 
			
		||||
	"code-davinci-edit-001":     10,
 | 
			
		||||
	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 | 
			
		||||
	"davinci":                   10,
 | 
			
		||||
	"curie":                     10,
 | 
			
		||||
	"babbage":                   10,
 | 
			
		||||
	"ada":                       10,
 | 
			
		||||
	"text-embedding-ada-002":    0.05,
 | 
			
		||||
	"text-search-ada-doc-001":   10,
 | 
			
		||||
	"text-moderation-stable":    0.1,
 | 
			
		||||
	"text-moderation-latest":    0.1,
 | 
			
		||||
	"dall-e":                    8,
 | 
			
		||||
	"claude-instant-1":          0.815,  // $1.63 / 1M tokens
 | 
			
		||||
	"claude-2":                  5.51,   // $11.02 / 1M tokens
 | 
			
		||||
	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                    1,
 | 
			
		||||
	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens
 | 
			
		||||
	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
 | 
			
		||||
	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelRatio2JSONString() string {
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypePaLM:
 | 
			
		||||
		fallthrough
 | 
			
		||||
@@ -24,10 +24,19 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeZhipu:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeAli:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelType360:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeXunfei:
 | 
			
		||||
		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		request.Model = "gpt-35-turbo"
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	default:
 | 
			
		||||
		request.Model = "gpt-3.5-turbo"
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	channel.CreatedTime = common.GetTimestamp()
 | 
			
		||||
	keys := strings.Split(channel.Key, "\n")
 | 
			
		||||
	channels := make([]model.Channel, 0)
 | 
			
		||||
	channels := make([]model.Channel, 0, len(keys))
 | 
			
		||||
	for _, key := range keys {
 | 
			
		||||
		if key == "" {
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
@@ -63,6 +63,15 @@ func init() {
 | 
			
		||||
			Root:       "dall-e",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "whisper-1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "whisper-1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -351,6 +360,15 @@ func init() {
 | 
			
		||||
			Root:       "qwen-plus-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-embedding-v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "ali",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-embedding-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "SparkDesk",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -360,6 +378,51 @@ func init() {
 | 
			
		||||
			Root:       "SparkDesk",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "360GPT_S2_V9",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "360GPT_S2_V9",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "embedding-bert-512-v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "embedding-bert-512-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "embedding_s1_v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "embedding_s1_v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "semantic_similarity_s1_v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "semantic_similarity_s1_v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "360GPT_S2_V9.4",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "360GPT_S2_V9.4",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,220 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryRequest struct {
 | 
			
		||||
	Model     string `json:"model"`
 | 
			
		||||
	Query     string `json:"query"`
 | 
			
		||||
	LibraryId string `json:"libraryId"`
 | 
			
		||||
	Stream    bool   `json:"stream"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryError struct {
 | 
			
		||||
	ErrCode int    `json:"errCode"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryDocument struct {
 | 
			
		||||
	Title string `json:"title"`
 | 
			
		||||
	URL   string `json:"url"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryResponse struct {
 | 
			
		||||
	Success   bool                     `json:"success"`
 | 
			
		||||
	Answer    string                   `json:"answer"`
 | 
			
		||||
	Documents []AIProxyLibraryDocument `json:"documents"`
 | 
			
		||||
	AIProxyLibraryError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryStreamResponse struct {
 | 
			
		||||
	Content   string                   `json:"content"`
 | 
			
		||||
	Finish    bool                     `json:"finish"`
 | 
			
		||||
	Model     string                   `json:"model"`
 | 
			
		||||
	Documents []AIProxyLibraryDocument `json:"documents"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 | 
			
		||||
	query := ""
 | 
			
		||||
	if len(request.Messages) != 0 {
 | 
			
		||||
		query = request.Messages[len(request.Messages)-1].Content
 | 
			
		||||
	}
 | 
			
		||||
	return &AIProxyLibraryRequest{
 | 
			
		||||
		Model:  request.Model,
 | 
			
		||||
		Stream: request.Stream,
 | 
			
		||||
		Query:  query,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 | 
			
		||||
	if len(documents) == 0 {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	content := "\n\n参考文档:\n"
 | 
			
		||||
	for i, document := range documents {
 | 
			
		||||
		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
 | 
			
		||||
	}
 | 
			
		||||
	return content
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
 | 
			
		||||
	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: content,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: "stop",
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Id:      common.GetUUID(),
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
 | 
			
		||||
	choice.FinishReason = &stopFinishReason
 | 
			
		||||
	return &ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      common.GetUUID(),
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = response.Content
 | 
			
		||||
	return &ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      common.GetUUID(),
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   response.Model,
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	var documents []AIProxyLibraryDocument
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
			
		||||
				documents = AIProxyLibraryResponse.Documents
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			response := documentsAIProxyLibrary(documents)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var AIProxyLibraryResponse AIProxyLibraryResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if AIProxyLibraryResponse.ErrCode != 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: AIProxyLibraryResponse.Message,
 | 
			
		||||
				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
 | 
			
		||||
				Code:    AIProxyLibraryResponse.ErrCode,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -35,6 +35,29 @@ type AliChatRequest struct {
 | 
			
		||||
	Parameters AliParameters `json:"parameters,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliEmbeddingRequest struct {
 | 
			
		||||
	Model string `json:"model"`
 | 
			
		||||
	Input struct {
 | 
			
		||||
		Texts []string `json:"texts"`
 | 
			
		||||
	} `json:"input"`
 | 
			
		||||
	Parameters *struct {
 | 
			
		||||
		TextType string `json:"text_type,omitempty"`
 | 
			
		||||
	} `json:"parameters,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliEmbedding struct {
 | 
			
		||||
	Embedding []float64 `json:"embedding"`
 | 
			
		||||
	TextIndex int       `json:"text_index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliEmbeddingResponse struct {
 | 
			
		||||
	Output struct {
 | 
			
		||||
		Embeddings []AliEmbedding `json:"embeddings"`
 | 
			
		||||
	} `json:"output"`
 | 
			
		||||
	Usage AliUsage `json:"usage"`
 | 
			
		||||
	AliError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliError struct {
 | 
			
		||||
	Code      string `json:"code"`
 | 
			
		||||
	Message   string `json:"message"`
 | 
			
		||||
@@ -44,6 +67,7 @@ type AliError struct {
 | 
			
		||||
type AliUsage struct {
 | 
			
		||||
	InputTokens  int `json:"input_tokens"`
 | 
			
		||||
	OutputTokens int `json:"output_tokens"`
 | 
			
		||||
	TotalTokens  int `json:"total_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliOutput struct {
 | 
			
		||||
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
 | 
			
		||||
	return &AliEmbeddingRequest{
 | 
			
		||||
		Model: "text-embedding-v1",
 | 
			
		||||
		Input: struct {
 | 
			
		||||
			Texts []string `json:"texts"`
 | 
			
		||||
		}{
 | 
			
		||||
			Texts: request.ParseInput(),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var aliResponse AliEmbeddingResponse
 | 
			
		||||
	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if aliResponse.Code != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: aliResponse.Message,
 | 
			
		||||
				Type:    aliResponse.Code,
 | 
			
		||||
				Param:   aliResponse.RequestId,
 | 
			
		||||
				Code:    aliResponse.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
			
		||||
	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 | 
			
		||||
		Model:  "text-embedding-v1",
 | 
			
		||||
		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, item := range response.Output.Embeddings {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
			
		||||
			Object:    `embedding`,
 | 
			
		||||
			Index:     item.TextIndex,
 | 
			
		||||
			Embedding: item.Embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										147
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,147 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	audioModel := "whisper-1"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	modelRatio := common.GetModelRatio(audioModel)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	if userQuota > 100*preConsumedQuota {
 | 
			
		||||
		// in this case, we do not pre-consume quota
 | 
			
		||||
		// because the user has enough quota
 | 
			
		||||
		preConsumedQuota = 0
 | 
			
		||||
	}
 | 
			
		||||
	if preConsumedQuota > 0 {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	if modelMapping != "" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap[audioModel] != "" {
 | 
			
		||||
			audioModel = modelMap[audioModel]
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	requestBody := c.Request.Body
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
 | 
			
		||||
	resp, err := httpClient.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	var audioResponse AudioResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		go func() {
 | 
			
		||||
			quota := countTokenText(audioResponse.Text, audioModel)
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &audioResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
			
		||||
	baiduEmbeddingRequest := BaiduEmbeddingRequest{
 | 
			
		||||
		Input: nil,
 | 
			
		||||
	return &BaiduEmbeddingRequest{
 | 
			
		||||
		Input: request.ParseInput(),
 | 
			
		||||
	}
 | 
			
		||||
	switch request.Input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
 | 
			
		||||
	case []any:
 | 
			
		||||
		for _, item := range request.Input.([]any) {
 | 
			
		||||
			if str, ok := item.(string); ok {
 | 
			
		||||
				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &baiduEmbeddingRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,7 @@ const (
 | 
			
		||||
	APITypeZhipu
 | 
			
		||||
	APITypeAli
 | 
			
		||||
	APITypeXunfei
 | 
			
		||||
	APITypeAIProxyLibrary
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var httpClient *http.Client
 | 
			
		||||
@@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		apiType = APITypeAli
 | 
			
		||||
	case common.ChannelTypeXunfei:
 | 
			
		||||
		apiType = APITypeXunfei
 | 
			
		||||
	case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
		apiType = APITypeAIProxyLibrary
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
@@ -172,6 +175,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
			
		||||
		if relayMode == RelayModeEmbeddings {
 | 
			
		||||
			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 | 
			
		||||
	}
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	var completionTokens int
 | 
			
		||||
@@ -258,8 +266,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		aliRequest := requestOpenAI2Ali(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(aliRequest)
 | 
			
		||||
		var jsonStr []byte
 | 
			
		||||
		var err error
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case RelayModeEmbeddings:
 | 
			
		||||
			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
 | 
			
		||||
			jsonStr, err = json.Marshal(aliEmbeddingRequest)
 | 
			
		||||
		default:
 | 
			
		||||
			aliRequest := requestOpenAI2Ali(textRequest)
 | 
			
		||||
			jsonStr, err = json.Marshal(aliRequest)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
 | 
			
		||||
		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
 | 
			
		||||
		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
@@ -287,6 +311,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
				req.Header.Set("api-key", apiKey)
 | 
			
		||||
			} else {
 | 
			
		||||
				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
				if channelType == common.ChannelTypeOpenRouter {
 | 
			
		||||
					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 | 
			
		||||
					req.Header.Set("X-Title", "One API")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		case APITypeClaude:
 | 
			
		||||
			req.Header.Set("x-api-key", apiKey)
 | 
			
		||||
@@ -303,6 +331,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if textRequest.Stream {
 | 
			
		||||
				req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
			
		||||
		}
 | 
			
		||||
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
@@ -365,7 +395,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
 | 
			
		||||
					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
 | 
			
		||||
					model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
@@ -491,7 +520,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := aliHandler(c, resp)
 | 
			
		||||
			var err *OpenAIErrorWithStatusCode
 | 
			
		||||
			var usage *Usage
 | 
			
		||||
			switch relayMode {
 | 
			
		||||
			case RelayModeEmbeddings:
 | 
			
		||||
				err, usage = aliEmbeddingHandler(c, resp)
 | 
			
		||||
			default:
 | 
			
		||||
				err, usage = aliHandler(c, resp)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
@@ -519,6 +555,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		} else {
 | 
			
		||||
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := aiProxyLibraryStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := aiProxyLibraryHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,24 @@ var stopFinishReason = "stop"
 | 
			
		||||
 | 
			
		||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
			
		||||
 | 
			
		||||
func InitTokenEncoders() {
 | 
			
		||||
	common.SysLog("initializing token encoders")
 | 
			
		||||
	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
	for model, _ := range common.ModelRatio {
 | 
			
		||||
		tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
 | 
			
		||||
			tokenEncoderMap[model] = fallbackTokenEncoder
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		tokenEncoderMap[model] = tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("token encoders initialized")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
			
		||||
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
			
		||||
		return tokenEncoder
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -29,6 +28,7 @@ const (
 | 
			
		||||
	RelayModeMidjourneyChange
 | 
			
		||||
	RelayModeMidjourneyNotify
 | 
			
		||||
	RelayModeMidjourneyTaskFetch
 | 
			
		||||
	RelayModeAudio
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/chat
 | 
			
		||||
@@ -45,6 +45,26 @@ type GeneralOpenAIRequest struct {
 | 
			
		||||
	Input       any       `json:"input,omitempty"`
 | 
			
		||||
	Instruction string    `json:"instruction,omitempty"`
 | 
			
		||||
	Size        string    `json:"size,omitempty"`
 | 
			
		||||
	Functions   any       `json:"functions,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
			
		||||
	if r.Input == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	var input []string
 | 
			
		||||
	switch r.Input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		input = []string{r.Input.(string)}
 | 
			
		||||
	case []any:
 | 
			
		||||
		input = make([]string, 0, len(r.Input.([]any)))
 | 
			
		||||
		for _, item := range r.Input.([]any) {
 | 
			
		||||
			if str, ok := item.(string); ok {
 | 
			
		||||
				input = append(input, str)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return input
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
@@ -67,6 +87,10 @@ type ImageRequest struct {
 | 
			
		||||
	Size   string `json:"size"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AudioResponse struct {
 | 
			
		||||
	Text string `json:"text,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Usage struct {
 | 
			
		||||
	PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
	CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
@@ -147,23 +171,6 @@ type CompletionsStreamResponse struct {
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidjourneyRequest struct {
 | 
			
		||||
	Prompt      string   `json:"prompt"`
 | 
			
		||||
	NotifyHook  string   `json:"notifyHook"`
 | 
			
		||||
	Action      string   `json:"action"`
 | 
			
		||||
	Index       int      `json:"index"`
 | 
			
		||||
	State       string   `json:"state"`
 | 
			
		||||
	TaskId      string   `json:"taskId"`
 | 
			
		||||
	Base64Array []string `json:"base64Array"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidjourneyResponse struct {
 | 
			
		||||
	Code        int         `json:"code"`
 | 
			
		||||
	Description string      `json:"description"`
 | 
			
		||||
	Properties  interface{} `json:"properties"`
 | 
			
		||||
	Result      string      `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Relay(c *gin.Context) {
 | 
			
		||||
	relayMode := RelayModeUnknown
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
 | 
			
		||||
 
 | 
			
		||||
@@ -523,5 +523,6 @@
 | 
			
		||||
  "按照如下格式输入:": "Enter in the following format:",
 | 
			
		||||
  "模型版本": "Model version",
 | 
			
		||||
  "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
 | 
			
		||||
  "点击查看": "click to view"
 | 
			
		||||
  "点击查看": "click to view",
 | 
			
		||||
  "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.go
									
									
									
									
									
								
							@@ -78,6 +78,12 @@ func main() {
 | 
			
		||||
		go controller.AutomaticallyTestChannels(frequency)
 | 
			
		||||
	}
 | 
			
		||||
	go controller.UpdateMidjourneyTask()
 | 
			
		||||
	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 | 
			
		||||
		common.BatchUpdateEnabled = true
 | 
			
		||||
		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 | 
			
		||||
		model.InitBatchUpdater()
 | 
			
		||||
	}
 | 
			
		||||
	controller.InitTokenEncoders()
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
 
 | 
			
		||||
@@ -109,7 +109,18 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !model.CacheIsUserEnabled(token.UserId) {
 | 
			
		||||
		userEnabled, err := model.IsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled {
 | 
			
		||||
			c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": "用户已被封禁",
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
@@ -22,7 +21,6 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
		c.Set("group", userGroup)
 | 
			
		||||
		var channel *model.Channel
 | 
			
		||||
		var err error
 | 
			
		||||
		channelId, ok := c.Get("channelId")
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
@@ -58,7 +56,6 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
 | 
			
		||||
			// Select a channel for the user
 | 
			
		||||
			var modelRequest ModelRequest
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
 | 
			
		||||
@@ -79,7 +76,6 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "text-moderation-stable"
 | 
			
		||||
@@ -95,6 +91,11 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					modelRequest.Model = "dall-e"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "whisper-1"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
			
		||||
@@ -118,8 +119,13 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		c.Set("model_mapping", channel.ModelMapping)
 | 
			
		||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
		c.Set("base_url", channel.BaseURL)
 | 
			
		||||
		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
 | 
			
		||||
		switch channel.Type {
 | 
			
		||||
		case common.ChannelTypeAzure:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		case common.ChannelTypeXunfei:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
			c.Set("library_id", channel.Other)
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CacheIsUserEnabled(userId int) bool {
 | 
			
		||||
func CacheIsUserEnabled(userId int) (bool, error) {
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
		return IsUserEnabled(userId)
 | 
			
		||||
	}
 | 
			
		||||
	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		status := common.UserStatusDisabled
 | 
			
		||||
		if IsUserEnabled(userId) {
 | 
			
		||||
			status = common.UserStatusEnabled
 | 
			
		||||
		}
 | 
			
		||||
		enabled = fmt.Sprintf("%d", status)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("Redis set user enabled error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return enabled == "1", nil
 | 
			
		||||
	}
 | 
			
		||||
	return enabled == "1"
 | 
			
		||||
 | 
			
		||||
	userEnabled, err := IsUserEnabled(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	enabled = "0"
 | 
			
		||||
	if userEnabled {
 | 
			
		||||
		enabled = "1"
 | 
			
		||||
	}
 | 
			
		||||
	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("Redis set user enabled error: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	return userEnabled, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var group2model2channels map[string]map[string][]*Channel
 | 
			
		||||
 
 | 
			
		||||
@@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	updateChannelUsedQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update channel used quota: " + err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
			
		||||
	}
 | 
			
		||||
	token, err = CacheGetTokenByKey(key)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		if token.Status == common.TokenStatusExhausted {
 | 
			
		||||
			return nil, errors.New("该令牌额度已用尽")
 | 
			
		||||
		} else if token.Status == common.TokenStatusExpired {
 | 
			
		||||
			return nil, errors.New("该令牌已过期")
 | 
			
		||||
		}
 | 
			
		||||
		if token.Status != common.TokenStatusEnabled {
 | 
			
		||||
			return nil, errors.New("该令牌状态不可用")
 | 
			
		||||
		}
 | 
			
		||||
		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
 | 
			
		||||
			token.Status = common.TokenStatusExpired
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
			if !common.RedisEnabled {
 | 
			
		||||
				token.Status = common.TokenStatusExpired
 | 
			
		||||
				err := token.SelectUpdate()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return nil, errors.New("该令牌已过期")
 | 
			
		||||
		}
 | 
			
		||||
		if !token.UnlimitedQuota && token.RemainQuota <= 0 {
 | 
			
		||||
			token.Status = common.TokenStatusExhausted
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
			if !common.RedisEnabled {
 | 
			
		||||
				// in this case, we can make sure the token is exhausted
 | 
			
		||||
				token.Status = common.TokenStatusExhausted
 | 
			
		||||
				err := token.SelectUpdate()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return nil, errors.New("该令牌额度已用尽")
 | 
			
		||||
		}
 | 
			
		||||
		go func() {
 | 
			
		||||
			token.AccessedTime = common.GetTimestamp()
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("failed to update token" + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		return token, nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, errors.New("无效的令牌")
 | 
			
		||||
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return increaseTokenQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func increaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"remain_quota": gorm.Expr("remain_quota + ?", quota),
 | 
			
		||||
			"used_quota":   gorm.Expr("used_quota - ?", quota),
 | 
			
		||||
			"remain_quota":  gorm.Expr("remain_quota + ?", quota),
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota - ?", quota),
 | 
			
		||||
			"accessed_time": common.GetTimestamp(),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	return err
 | 
			
		||||
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return decreaseTokenQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"remain_quota": gorm.Expr("remain_quota - ?", quota),
 | 
			
		||||
			"used_quota":   gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
			"remain_quota":  gorm.Expr("remain_quota - ?", quota),
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
			"accessed_time": common.GetTimestamp(),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	return err
 | 
			
		||||
 
 | 
			
		||||
@@ -235,17 +235,16 @@ func IsAdmin(userId int) bool {
 | 
			
		||||
	return user.Role >= common.RoleAdminUser
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsUserEnabled(userId int) bool {
 | 
			
		||||
func IsUserEnabled(userId int) (bool, error) {
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
		return false, errors.New("user id is empty")
 | 
			
		||||
	}
 | 
			
		||||
	var user User
 | 
			
		||||
	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("no such user " + err.Error())
 | 
			
		||||
		return false
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return user.Status == common.UserStatusEnabled
 | 
			
		||||
	return user.Status == common.UserStatusEnabled, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ValidateAccessToken(token string) (user *User) {
 | 
			
		||||
@@ -284,6 +283,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return increaseUserQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func increaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
@@ -292,6 +299,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return decreaseUserQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
@@ -302,10 +317,18 @@ func GetRootUserEmail() (email string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 | 
			
		||||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
			"request_count": gorm.Expr("request_count + ?", 1),
 | 
			
		||||
			"request_count": gorm.Expr("request_count + ?", count),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,75 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	BatchUpdateTypeUserQuota = iota
 | 
			
		||||
	BatchUpdateTypeTokenQuota
 | 
			
		||||
	BatchUpdateTypeUsedQuotaAndRequestCount
 | 
			
		||||
	BatchUpdateTypeChannelUsedQuota
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var batchUpdateStores []map[int]int
 | 
			
		||||
var batchUpdateLocks []sync.Mutex
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
			
		||||
		batchUpdateStores = append(batchUpdateStores, make(map[int]int))
 | 
			
		||||
		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitBatchUpdater() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
 | 
			
		||||
			batchUpdate()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addNewRecord(type_ int, id int, value int) {
 | 
			
		||||
	batchUpdateLocks[type_].Lock()
 | 
			
		||||
	defer batchUpdateLocks[type_].Unlock()
 | 
			
		||||
	if _, ok := batchUpdateStores[type_][id]; !ok {
 | 
			
		||||
		batchUpdateStores[type_][id] = value
 | 
			
		||||
	} else {
 | 
			
		||||
		batchUpdateStores[type_][id] += value
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func batchUpdate() {
 | 
			
		||||
	common.SysLog("batch update started")
 | 
			
		||||
	for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
			
		||||
		batchUpdateLocks[i].Lock()
 | 
			
		||||
		store := batchUpdateStores[i]
 | 
			
		||||
		batchUpdateStores[i] = make(map[int]int)
 | 
			
		||||
		batchUpdateLocks[i].Unlock()
 | 
			
		||||
 | 
			
		||||
		for key, value := range store {
 | 
			
		||||
			switch i {
 | 
			
		||||
			case BatchUpdateTypeUserQuota:
 | 
			
		||||
				err := increaseUserQuota(key, value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to batch update user quota: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			case BatchUpdateTypeTokenQuota:
 | 
			
		||||
				err := increaseTokenQuota(key, value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to batch update token quota: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			case BatchUpdateTypeUsedQuotaAndRequestCount:
 | 
			
		||||
				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
 | 
			
		||||
			case BatchUpdateTypeChannelUsedQuota:
 | 
			
		||||
				updateChannelUsedQuota(key, value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("batch update finished")
 | 
			
		||||
}
 | 
			
		||||
@@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/audio/transcriptions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/translations", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/files", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/files", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -324,7 +324,7 @@ const LogsTable = () => {
 | 
			
		||||
              .map((log, idx) => {
 | 
			
		||||
                if (log.deleted) return <></>;
 | 
			
		||||
                return (
 | 
			
		||||
                  <Table.Row key={log.created_at}>
 | 
			
		||||
                  <Table.Row key={log.id}>
 | 
			
		||||
                    <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
 | 
			
		||||
                    {
 | 
			
		||||
                      isAdminUser && (
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,11 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
 | 
			
		||||
  { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
 | 
			
		||||
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
			
		||||
  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
			
		||||
  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
			
		||||
  { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
 | 
			
		||||
  { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
 | 
			
		||||
  { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { useParams, useNavigate } from 'react-router-dom';
 | 
			
		||||
import { useNavigate, useParams } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 | 
			
		||||
import { CHANNEL_OPTIONS } from '../../constants';
 | 
			
		||||
 | 
			
		||||
@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
 | 
			
		||||
  'gpt-4-32k-0314': 'gpt-4-32k'
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
function type2secretPrompt(type) {
 | 
			
		||||
  // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
 | 
			
		||||
  switch (type) {
 | 
			
		||||
    case 15:
 | 
			
		||||
      return '按照如下格式输入:APIKey|SecretKey';
 | 
			
		||||
    case 18:
 | 
			
		||||
      return '按照如下格式输入:APPID|APISecret|APIKey';
 | 
			
		||||
    case 22:
 | 
			
		||||
      return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
 | 
			
		||||
    default:
 | 
			
		||||
      return '请输入渠道对应的鉴权密钥';
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const EditChannel = () => {
 | 
			
		||||
  const params = useParams();
 | 
			
		||||
  const navigate = useNavigate();
 | 
			
		||||
@@ -19,7 +33,7 @@ const EditChannel = () => {
 | 
			
		||||
  const handleCancel = () => {
 | 
			
		||||
    navigate('/channel');
 | 
			
		||||
  };
 | 
			
		||||
  
 | 
			
		||||
 | 
			
		||||
  const originInputs = {
 | 
			
		||||
    name: '',
 | 
			
		||||
    type: 1,
 | 
			
		||||
@@ -53,7 +67,7 @@ const EditChannel = () => {
 | 
			
		||||
          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 17:
 | 
			
		||||
          localModels = ['qwen-v1', 'qwen-plus-v1'];
 | 
			
		||||
          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 16:
 | 
			
		||||
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
			
		||||
@@ -61,6 +75,9 @@ const EditChannel = () => {
 | 
			
		||||
        case 18:
 | 
			
		||||
          localModels = ['SparkDesk'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 19:
 | 
			
		||||
          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
 | 
			
		||||
          break;
 | 
			
		||||
      }
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, models: localModels }));
 | 
			
		||||
    }
 | 
			
		||||
@@ -190,6 +207,24 @@ const EditChannel = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const addCustomModel = () => {
 | 
			
		||||
    if (customModel.trim() === '') return;
 | 
			
		||||
    if (inputs.models.includes(customModel)) return;
 | 
			
		||||
    let localModels = [...inputs.models];
 | 
			
		||||
    localModels.push(customModel);
 | 
			
		||||
    let localModelOptions = [];
 | 
			
		||||
    localModelOptions.push({
 | 
			
		||||
      key: customModel,
 | 
			
		||||
      text: customModel,
 | 
			
		||||
      value: customModel
 | 
			
		||||
    });
 | 
			
		||||
    setModelOptions(modelOptions => {
 | 
			
		||||
      return [...modelOptions, ...localModelOptions];
 | 
			
		||||
    });
 | 
			
		||||
    setCustomModel('');
 | 
			
		||||
    handleInputChange(null, { name: 'models', value: localModels });
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <>
 | 
			
		||||
      <Segment loading={loading}>
 | 
			
		||||
@@ -292,6 +327,20 @@ const EditChannel = () => {
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type === 21 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='知识库 ID'
 | 
			
		||||
                  name='other'
 | 
			
		||||
                  placeholder={'请输入知识库 ID,例如:123456'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.other}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Dropdown
 | 
			
		||||
              label='模型'
 | 
			
		||||
@@ -319,29 +368,19 @@ const EditChannel = () => {
 | 
			
		||||
            }}>清除所有模型</Button>
 | 
			
		||||
            <Input
 | 
			
		||||
              action={
 | 
			
		||||
                <Button type={'button'} onClick={() => {
 | 
			
		||||
                  if (customModel.trim() === '') return;
 | 
			
		||||
                  if (inputs.models.includes(customModel)) return;
 | 
			
		||||
                  let localModels = [...inputs.models];
 | 
			
		||||
                  localModels.push(customModel);
 | 
			
		||||
                  let localModelOptions = [];
 | 
			
		||||
                  localModelOptions.push({
 | 
			
		||||
                    key: customModel,
 | 
			
		||||
                    text: customModel,
 | 
			
		||||
                    value: customModel
 | 
			
		||||
                  });
 | 
			
		||||
                  setModelOptions(modelOptions => {
 | 
			
		||||
                    return [...modelOptions, ...localModelOptions];
 | 
			
		||||
                  });
 | 
			
		||||
                  setCustomModel('');
 | 
			
		||||
                  handleInputChange(null, { name: 'models', value: localModels });
 | 
			
		||||
                }}>填入</Button>
 | 
			
		||||
                <Button type={'button'} onClick={addCustomModel}>填入</Button>
 | 
			
		||||
              }
 | 
			
		||||
              placeholder='输入自定义模型名称'
 | 
			
		||||
              value={customModel}
 | 
			
		||||
              onChange={(e, { value }) => {
 | 
			
		||||
                setCustomModel(value);
 | 
			
		||||
              }}
 | 
			
		||||
              onKeyDown={(e) => {
 | 
			
		||||
                if (e.key === 'Enter') {
 | 
			
		||||
                  addCustomModel();
 | 
			
		||||
                  e.preventDefault();
 | 
			
		||||
                }
 | 
			
		||||
              }}
 | 
			
		||||
            />
 | 
			
		||||
          </div>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
@@ -372,7 +411,7 @@ const EditChannel = () => {
 | 
			
		||||
                label='密钥'
 | 
			
		||||
                name='key'
 | 
			
		||||
                required
 | 
			
		||||
                placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
 | 
			
		||||
                placeholder={type2secretPrompt(inputs.type)}
 | 
			
		||||
                onChange={handleInputChange}
 | 
			
		||||
                value={inputs.key}
 | 
			
		||||
                autoComplete='new-password'
 | 
			
		||||
@@ -390,7 +429,7 @@ const EditChannel = () => {
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type !== 3 && inputs.type !== 8 && (
 | 
			
		||||
            inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='代理'
 | 
			
		||||
@@ -403,6 +442,20 @@ const EditChannel = () => {
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type === 22 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='私有部署地址'
 | 
			
		||||
                  name='base_url'
 | 
			
		||||
                  placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.base_url}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Button onClick={handleCancel}>取消</Button>
 | 
			
		||||
          <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user