mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 13:53:41 +08:00 
			
		
		
		
	Compare commits
	
		
			20 Commits
		
	
	
		
			v0.6.0-alp
			...
			v0.6.1-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | de18d6fe16 | ||
|  | 1d0b7fb5ae | ||
|  | f9490bb72e | ||
|  | 76467285e8 | ||
|  | df1fd9aa81 | ||
|  | 614c2e0442 | ||
|  | eac6a0b9aa | ||
|  | b747cdbc6f | ||
|  | 6b27d6659a | ||
|  | dc5b781191 | ||
|  | c880b4a9a3 | ||
|  | 565ea58e68 | ||
|  | f141a37a9e | ||
|  | 5b78886ad3 | ||
|  | 87c7c4f0e6 | ||
|  | 4c4a873890 | ||
|  | 0664bdfda1 | ||
|  | 32387d9c20 | ||
|  | bd888f2eb7 | ||
|  | cece77e533 | 
							
								
								
									
										4
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -23,7 +23,7 @@ jobs: | |||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
|       - name: Build Frontend (theme default) |       - name: Build Frontend | ||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
| @@ -38,7 +38,7 @@ jobs: | |||||||
|       - name: Build Backend (amd64) |       - name: Build Backend (amd64) | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api |           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
|       - name: Build Backend (arm64) |       - name: Build Backend (arm64) | ||||||
|         run: | |         run: | | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -23,7 +23,7 @@ jobs: | |||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
|       - name: Build Frontend (theme default) |       - name: Build Frontend | ||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
| @@ -38,7 +38,7 @@ jobs: | |||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos |           go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -26,7 +26,7 @@ jobs: | |||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
|       - name: Build Frontend (theme default) |       - name: Build Frontend | ||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
| @@ -41,7 +41,7 @@ jobs: | |||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe |           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ ADD go.mod go.sum ./ | |||||||
| RUN go mod download | RUN go mod download | ||||||
| COPY . . | COPY . . | ||||||
| COPY --from=builder /web/build ./web/build | COPY --from=builder /web/build ./web/build | ||||||
| RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
| FROM alpine | FROM alpine | ||||||
|  |  | ||||||
|   | |||||||
| @@ -74,8 +74,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|    + [x] [360 智脑](https://ai.360.cn) |    + [x] [360 智脑](https://ai.360.cn) | ||||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) |    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||||
|    + [x] [Moonshot AI](https://platform.moonshot.cn/) |    + [x] [Moonshot AI](https://platform.moonshot.cn/) | ||||||
|  |    + [x] [百川大模型](https://platform.baichuan-ai.com) | ||||||
|    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) |    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) | ||||||
|    + [ ] [MINIMAX](https://api.minimax.chat/) (WIP) |    + [x] [MINIMAX](https://api.minimax.chat/) | ||||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
|   | |||||||
| @@ -64,6 +64,8 @@ const ( | |||||||
| 	ChannelTypeTencent        = 23 | 	ChannelTypeTencent        = 23 | ||||||
| 	ChannelTypeGemini         = 24 | 	ChannelTypeGemini         = 24 | ||||||
| 	ChannelTypeMoonshot       = 25 | 	ChannelTypeMoonshot       = 25 | ||||||
|  | 	ChannelTypeBaichuan       = 26 | ||||||
|  | 	ChannelTypeMinimax        = 27 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| @@ -93,6 +95,8 @@ var ChannelBaseURLs = []string{ | |||||||
| 	"https://hunyuan.cloud.tencent.com",         // 23 | 	"https://hunyuan.cloud.tencent.com",         // 23 | ||||||
| 	"https://generativelanguage.googleapis.com", // 24 | 	"https://generativelanguage.googleapis.com", // 24 | ||||||
| 	"https://api.moonshot.cn",                   // 25 | 	"https://api.moonshot.cn",                   // 25 | ||||||
|  | 	"https://api.baichuan-ai.com",               // 26 | ||||||
|  | 	"https://api.minimax.chat",                  // 27 | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|   | |||||||
| @@ -8,12 +8,24 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | const KeyRequestBody = "key_request_body" | ||||||
|  |  | ||||||
|  | func GetRequestBody(c *gin.Context) ([]byte, error) { | ||||||
|  | 	requestBody, _ := c.Get(KeyRequestBody) | ||||||
|  | 	if requestBody != nil { | ||||||
|  | 		return requestBody.([]byte), nil | ||||||
|  | 	} | ||||||
| 	requestBody, err := io.ReadAll(c.Request.Body) | 	requestBody, err := io.ReadAll(c.Request.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	err = c.Request.Body.Close() | 	_ = c.Request.Body.Close() | ||||||
|  | 	c.Set(KeyRequestBody, requestBody) | ||||||
|  | 	return requestBody.([]byte), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||||
|  | 	requestBody, err := GetRequestBody(c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -137,6 +137,7 @@ func GetUUID() string { | |||||||
| } | } | ||||||
|  |  | ||||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||||
|  | const keyNumbers = "0123456789" | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| @@ -168,6 +169,15 @@ func GetRandomString(length int) string { | |||||||
| 	return string(key) | 	return string(key) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetRandomNumberString(length int) string { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	key := make([]byte, length) | ||||||
|  | 	for i := 0; i < length; i++ { | ||||||
|  | 		key[i] = keyNumbers[rand.Intn(len(keyNumbers))] | ||||||
|  | 	} | ||||||
|  | 	return string(key) | ||||||
|  | } | ||||||
|  |  | ||||||
| func GetTimestamp() int64 { | func GetTimestamp() int64 { | ||||||
| 	return time.Now().Unix() | 	return time.Now().Unix() | ||||||
| } | } | ||||||
|   | |||||||
| @@ -13,6 +13,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|  | 	loggerDEBUG = "DEBUG" | ||||||
| 	loggerINFO  = "INFO" | 	loggerINFO  = "INFO" | ||||||
| 	loggerWarn  = "WARN" | 	loggerWarn  = "WARN" | ||||||
| 	loggerError = "ERR" | 	loggerError = "ERR" | ||||||
| @@ -55,6 +56,10 @@ func SysError(s string) { | |||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Debug(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerDEBUG, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
| func Info(ctx context.Context, msg string) { | func Info(ctx context.Context, msg string) { | ||||||
| 	logHelper(ctx, loggerINFO, msg) | 	logHelper(ctx, loggerINFO, msg) | ||||||
| } | } | ||||||
| @@ -67,6 +72,10 @@ func Error(ctx context.Context, msg string) { | |||||||
| 	logHelper(ctx, loggerError, msg) | 	logHelper(ctx, loggerError, msg) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Debugf(ctx context.Context, format string, a ...any) { | ||||||
|  | 	Debug(ctx, fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
| func Infof(ctx context.Context, format string, a ...any) { | func Infof(ctx context.Context, format string, a ...any) { | ||||||
| 	Info(ctx, fmt.Sprintf(format, a...)) | 	Info(ctx, fmt.Sprintf(format, a...)) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,29 +7,6 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var DalleSizeRatios = map[string]map[string]float64{ |  | ||||||
| 	"dall-e-2": { |  | ||||||
| 		"256x256":   1, |  | ||||||
| 		"512x512":   1.125, |  | ||||||
| 		"1024x1024": 1.25, |  | ||||||
| 	}, |  | ||||||
| 	"dall-e-3": { |  | ||||||
| 		"1024x1024": 1, |  | ||||||
| 		"1024x1792": 2, |  | ||||||
| 		"1792x1024": 2, |  | ||||||
| 	}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DalleGenerationImageAmounts = map[string][2]int{ |  | ||||||
| 	"dall-e-2": {1, 10}, |  | ||||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DalleImagePromptLengthLimitations = map[string]int{ |  | ||||||
| 	"dall-e-2": 1000, |  | ||||||
| 	"dall-e-3": 4000, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	USD2RMB = 7 | 	USD2RMB = 7 | ||||||
| 	USD     = 500 // $0.002 = 1 -> $1 = 500 | 	USD     = 500 // $0.002 = 1 -> $1 = 500 | ||||||
| @@ -102,6 +79,10 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"PaLM-2":            1, | 	"PaLM-2":            1, | ||||||
| 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
| 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
|  | 	// https://open.bigmodel.cn/pricing | ||||||
|  | 	"glm-4":                     0.1 * RMB, | ||||||
|  | 	"glm-4v":                    0.1 * RMB, | ||||||
|  | 	"glm-3-turbo":               0.005 * RMB, | ||||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||||
| @@ -127,6 +108,23 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"moonshot-v1-8k":   0.012 * RMB, | 	"moonshot-v1-8k":   0.012 * RMB, | ||||||
| 	"moonshot-v1-32k":  0.024 * RMB, | 	"moonshot-v1-32k":  0.024 * RMB, | ||||||
| 	"moonshot-v1-128k": 0.06 * RMB, | 	"moonshot-v1-128k": 0.06 * RMB, | ||||||
|  | 	// https://platform.baichuan-ai.com/price | ||||||
|  | 	"Baichuan2-Turbo":      0.008 * RMB, | ||||||
|  | 	"Baichuan2-Turbo-192k": 0.016 * RMB, | ||||||
|  | 	"Baichuan2-53B":        0.02 * RMB, | ||||||
|  | 	// https://api.minimax.chat/document/price | ||||||
|  | 	"abab6-chat":    0.1 * RMB, | ||||||
|  | 	"abab5.5-chat":  0.015 * RMB, | ||||||
|  | 	"abab5.5s-chat": 0.005 * RMB, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DefaultModelRatio map[string]float64 | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	DefaultModelRatio = make(map[string]float64) | ||||||
|  | 	for k, v := range ModelRatio { | ||||||
|  | 		DefaultModelRatio[k] = v | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
| @@ -147,6 +145,9 @@ func GetModelRatio(name string) float64 { | |||||||
| 		name = strings.TrimSuffix(name, "-internet") | 		name = strings.TrimSuffix(name, "-internet") | ||||||
| 	} | 	} | ||||||
| 	ratio, ok := ModelRatio[name] | 	ratio, ok := ModelRatio[name] | ||||||
|  | 	if !ok { | ||||||
|  | 		ratio, ok = DefaultModelRatio[name] | ||||||
|  | 	} | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		logger.SysError("model ratio not found: " + name) | 		logger.SysError("model ratio not found: " + name) | ||||||
| 		return 30 | 		return 30 | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" | 	"github.com/songquanpeng/one-api/relay/helper" | ||||||
| @@ -58,6 +60,9 @@ func init() { | |||||||
| 	}) | 	}) | ||||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
| 	for i := 0; i < constant.APITypeDummy; i++ { | 	for i := 0; i < constant.APITypeDummy; i++ { | ||||||
|  | 		if i == constant.APITypeAIProxyLibrary { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
| 		adaptor := helper.GetAdaptor(i) | 		adaptor := helper.GetAdaptor(i) | ||||||
| 		channelName := adaptor.GetChannelName() | 		channelName := adaptor.GetChannelName() | ||||||
| 		modelNames := adaptor.GetModelList() | 		modelNames := adaptor.GetModelList() | ||||||
| @@ -95,6 +100,28 @@ func init() { | |||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  | 	for _, modelName := range baichuan.ModelList { | ||||||
|  | 		openAIModels = append(openAIModels, OpenAIModels{ | ||||||
|  | 			Id:         modelName, | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1626777600, | ||||||
|  | 			OwnedBy:    "baichuan", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       modelName, | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	for _, modelName := range minimax.ModelList { | ||||||
|  | 		openAIModels = append(openAIModels, OpenAIModels{ | ||||||
|  | 			Id:         modelName, | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1626777600, | ||||||
|  | 			OwnedBy:    "minimax", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       modelName, | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | 	openAIModelsMap = make(map[string]OpenAIModels) | ||||||
| 	for _, model := range openAIModels { | 	for _, model := range openAIModels { | ||||||
| 		openAIModelsMap[model.Id] = model | 		openAIModelsMap[model.Id] = model | ||||||
|   | |||||||
| @@ -1,23 +1,27 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/middleware" | ||||||
|  | 	dbmodel "github.com/songquanpeng/one-api/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/controller" | 	"github.com/songquanpeng/one-api/relay/controller" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strconv" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/chat | // https://platform.openai.com/docs/api-reference/chat | ||||||
|  |  | ||||||
| func Relay(c *gin.Context) { | func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { | ||||||
| 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) |  | ||||||
| 	var err *model.ErrorWithStatusCode | 	var err *model.ErrorWithStatusCode | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case constant.RelayModeImagesGenerations: | 	case constant.RelayModeImagesGenerations: | ||||||
| @@ -31,34 +35,91 @@ func Relay(c *gin.Context) { | |||||||
| 	default: | 	default: | ||||||
| 		err = controller.RelayTextHelper(c) | 		err = controller.RelayTextHelper(c) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Relay(c *gin.Context) { | ||||||
|  | 	ctx := c.Request.Context() | ||||||
|  | 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) | ||||||
|  | 	if config.DebugEnabled { | ||||||
|  | 		requestBody, _ := common.GetRequestBody(c) | ||||||
|  | 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||||
|  | 	} | ||||||
|  | 	bizErr := relay(c, relayMode) | ||||||
|  | 	if bizErr == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
|  | 	lastFailedChannelId := channelId | ||||||
|  | 	channelName := c.GetString("channel_name") | ||||||
|  | 	group := c.GetString("group") | ||||||
|  | 	originalModel := c.GetString("original_model") | ||||||
|  | 	go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||||
| 	requestId := c.GetString(logger.RequestIdKey) | 	requestId := c.GetString(logger.RequestIdKey) | ||||||
| 		retryTimesStr := c.Query("retry") | 	retryTimes := config.RetryTimes | ||||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | 	if !shouldRetry(c, bizErr.StatusCode) { | ||||||
| 		if retryTimesStr == "" { | 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) | ||||||
| 			retryTimes = config.RetryTimes | 		retryTimes = 0 | ||||||
| 	} | 	} | ||||||
| 		if retryTimes > 0 { | 	for i := retryTimes; i > 0; i-- { | ||||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) | ||||||
| 		} else { | 		if err != nil { | ||||||
| 			if err.StatusCode == http.StatusTooManyRequests { | 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) | ||||||
| 				err.Error.Message = "当前分组上游负载已饱和,请稍后再试" | 			break | ||||||
| 		} | 		} | ||||||
| 			err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) | 		logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) | ||||||
| 			c.JSON(err.StatusCode, gin.H{ | 		if channel.Id == lastFailedChannelId { | ||||||
| 				"error": err.Error, | 			continue | ||||||
|  | 		} | ||||||
|  | 		middleware.SetupContextForSelectedChannel(c, channel, originalModel) | ||||||
|  | 		requestBody, err := common.GetRequestBody(c) | ||||||
|  | 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
|  | 		bizErr = relay(c, relayMode) | ||||||
|  | 		if bizErr == nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		channelId := c.GetInt("channel_id") | ||||||
|  | 		lastFailedChannelId = channelId | ||||||
|  | 		channelName := c.GetString("channel_name") | ||||||
|  | 		go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||||
|  | 	} | ||||||
|  | 	if bizErr != nil { | ||||||
|  | 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||||
|  | 			bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
|  | 		} | ||||||
|  | 		bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) | ||||||
|  | 		c.JSON(bizErr.StatusCode, gin.H{ | ||||||
|  | 			"error": bizErr.Error, | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 		channelId := c.GetInt("channel_id") | } | ||||||
| 		logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) |  | ||||||
|  | func shouldRetry(c *gin.Context, statusCode int) bool { | ||||||
|  | 	if _, ok := c.Get("specific_channel_id"); ok { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusTooManyRequests { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if statusCode/100 == 5 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusBadRequest { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode/100 == 2 { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { | ||||||
|  | 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | ||||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||||
| 			channelId := c.GetInt("channel_id") |  | ||||||
| 			channelName := c.GetString("channel_name") |  | ||||||
| 		disableChannel(channelId, channelName, err.Message) | 		disableChannel(channelId, channelName, err.Message) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| } |  | ||||||
|  |  | ||||||
| func RelayNotImplemented(c *gin.Context) { | func RelayNotImplemented(c *gin.Context) { | ||||||
| 	err := model.Error{ | 	err := model.Error{ | ||||||
|   | |||||||
| @@ -456,6 +456,7 @@ | |||||||
|   "已绑定的邮箱账户": "Email Account Bound", |   "已绑定的邮箱账户": "Email Account Bound", | ||||||
|   "用户信息更新成功!": "User information updated successfully!", |   "用户信息更新成功!": "User information updated successfully!", | ||||||
|   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", |   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", | ||||||
|  |   "模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f", | ||||||
|   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", |   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", | ||||||
|   "用户名称": "User Name", |   "用户名称": "User Name", | ||||||
|   "令牌名称": "Token Name", |   "令牌名称": "Token Name", | ||||||
|   | |||||||
| @@ -108,7 +108,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 		c.Set("token_name", token.Name) | 		c.Set("token_name", token.Name) | ||||||
| 		if len(parts) > 1 { | 		if len(parts) > 1 { | ||||||
| 			if model.IsAdmin(token.UserId) { | 			if model.IsAdmin(token.UserId) { | ||||||
| 				c.Set("channelId", parts[1]) | 				c.Set("specific_channel_id", parts[1]) | ||||||
| 			} else { | 			} else { | ||||||
| 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||||
| 				return | 				return | ||||||
|   | |||||||
| @@ -21,8 +21,9 @@ func Distribute() func(c *gin.Context) { | |||||||
| 		userId := c.GetInt("id") | 		userId := c.GetInt("id") | ||||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||||
| 		c.Set("group", userGroup) | 		c.Set("group", userGroup) | ||||||
|  | 		var requestModel string | ||||||
| 		var channel *model.Channel | 		var channel *model.Channel | ||||||
| 		channelId, ok := c.Get("channelId") | 		channelId, ok := c.Get("specific_channel_id") | ||||||
| 		if ok { | 		if ok { | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) | 			id, err := strconv.Atoi(channelId.(string)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -66,6 +67,7 @@ func Distribute() func(c *gin.Context) { | |||||||
| 					modelRequest.Model = "whisper-1" | 					modelRequest.Model = "whisper-1" | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | 			requestModel = modelRequest.Model | ||||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||||
| @@ -77,10 +79,17 @@ func Distribute() func(c *gin.Context) { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 		SetupContextForSelectedChannel(c, channel, requestModel) | ||||||
|  | 		c.Next() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { | ||||||
| 	c.Set("channel", channel.Type) | 	c.Set("channel", channel.Type) | ||||||
| 	c.Set("channel_id", channel.Id) | 	c.Set("channel_id", channel.Id) | ||||||
| 	c.Set("channel_name", channel.Name) | 	c.Set("channel_name", channel.Name) | ||||||
| 	c.Set("model_mapping", channel.GetModelMapping()) | 	c.Set("model_mapping", channel.GetModelMapping()) | ||||||
|  | 	c.Set("original_model", modelName) // for retry | ||||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||||
| 	c.Set("base_url", channel.GetBaseURL()) | 	c.Set("base_url", channel.GetBaseURL()) | ||||||
| 	// this is for backward compatibility | 	// this is for backward compatibility | ||||||
| @@ -100,6 +109,4 @@ func Distribute() func(c *gin.Context) { | |||||||
| 	for k, v := range cfg { | 	for k, v := range cfg { | ||||||
| 		c.Set(common.ConfigKeyPrefix+k, v) | 		c.Set(common.ConfigKeyPrefix+k, v) | ||||||
| 	} | 	} | ||||||
| 		c.Next() |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ import ( | |||||||
|  |  | ||||||
| func RequestId() func(c *gin.Context) { | func RequestId() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		id := helper.GetTimeString() + helper.GetRandomString(8) | 		id := helper.GetTimeString() + helper.GetRandomNumberString(8) | ||||||
| 		c.Set(logger.RequestIdKey, id) | 		c.Set(logger.RequestIdKey, id) | ||||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | ||||||
| 		c.Request = c.Request.WithContext(ctx) | 		c.Request = c.Request.WithContext(ctx) | ||||||
|   | |||||||
| @@ -94,7 +94,7 @@ func CacheUpdateUserQuota(id int) error { | |||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	quota, err := GetUserQuota(id) | 	quota, err := CacheGetUserQuota(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -53,7 +53,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon | |||||||
| 		FinishReason: "stop", | 		FinishReason: "stop", | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := openai.TextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      helper.GetUUID(), | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Choices: []openai.TextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| @@ -66,7 +66,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion | |||||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||||
| 	choice.FinishReason = &constant.StopFinishReason | 	choice.FinishReason = &constant.StopFinishReason | ||||||
| 	return &openai.ChatCompletionsStreamResponse{ | 	return &openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      helper.GetUUID(), | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   "", | 		Model:   "", | ||||||
| @@ -78,7 +78,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena | |||||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = response.Content | 	choice.Delta.Content = response.Content | ||||||
| 	return &openai.ChatCompletionsStreamResponse{ | 	return &openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      helper.GetUUID(), | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   response.Model, | 		Model:   response.Model, | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								relay/channel/baichuan/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/channel/baichuan/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | package baichuan | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"Baichuan2-Turbo", | ||||||
|  | 	"Baichuan2-Turbo-192k", | ||||||
|  | 	"Baichuan-Text-Embedding", | ||||||
|  | } | ||||||
							
								
								
									
										7
									
								
								relay/channel/minimax/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/channel/minimax/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | package minimax | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"abab5.5s-chat", | ||||||
|  | 	"abab5.5-chat", | ||||||
|  | 	"abab6-chat", | ||||||
|  | } | ||||||
							
								
								
									
										14
									
								
								relay/channel/minimax/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								relay/channel/minimax/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | package minimax | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
|  | 	if meta.Mode == constant.RelayModeChatCompletions { | ||||||
|  | 		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil | ||||||
|  | 	} | ||||||
|  | 	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) | ||||||
|  | } | ||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| @@ -24,7 +26,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
| 	if meta.ChannelType == common.ChannelTypeAzure { | 	switch meta.ChannelType { | ||||||
|  | 	case common.ChannelTypeAzure: | ||||||
| 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||||
| 		requestURL := strings.Split(meta.RequestURLPath, "?")[0] | 		requestURL := strings.Split(meta.RequestURLPath, "?")[0] | ||||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | ||||||
| @@ -38,9 +41,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | |||||||
|  |  | ||||||
| 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||||
| 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil | 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil | ||||||
| 	} | 	case common.ChannelTypeMinimax: | ||||||
|  | 		return minimax.GetRequestURL(meta) | ||||||
|  | 	default: | ||||||
| 		return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | 		return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||||
| 	channel.SetupCommonRequestHeader(c, req, meta) | 	channel.SetupCommonRequestHeader(c, req, meta) | ||||||
| @@ -70,7 +76,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | |||||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		var responseText string | 		var responseText string | ||||||
| 		err, responseText = StreamHandler(c, resp, meta.Mode) | 		err, responseText, _ = StreamHandler(c, resp, meta.Mode) | ||||||
| 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||||
| 	} else { | 	} else { | ||||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||||
| @@ -84,6 +90,10 @@ func (a *Adaptor) GetModelList() []string { | |||||||
| 		return ai360.ModelList | 		return ai360.ModelList | ||||||
| 	case common.ChannelTypeMoonshot: | 	case common.ChannelTypeMoonshot: | ||||||
| 		return moonshot.ModelList | 		return moonshot.ModelList | ||||||
|  | 	case common.ChannelTypeBaichuan: | ||||||
|  | 		return baichuan.ModelList | ||||||
|  | 	case common.ChannelTypeMinimax: | ||||||
|  | 		return minimax.ModelList | ||||||
| 	default: | 	default: | ||||||
| 		return ModelList | 		return ModelList | ||||||
| 	} | 	} | ||||||
| @@ -97,6 +107,10 @@ func (a *Adaptor) GetChannelName() string { | |||||||
| 		return "360" | 		return "360" | ||||||
| 	case common.ChannelTypeMoonshot: | 	case common.ChannelTypeMoonshot: | ||||||
| 		return "moonshot" | 		return "moonshot" | ||||||
|  | 	case common.ChannelTypeBaichuan: | ||||||
|  | 		return "baichuan" | ||||||
|  | 	case common.ChannelTypeMinimax: | ||||||
|  | 		return "minimax" | ||||||
| 	default: | 	default: | ||||||
| 		return "openai" | 		return "openai" | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -14,7 +14,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { | ||||||
| 	responseText := "" | 	responseText := "" | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| @@ -31,6 +31,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | |||||||
| 	}) | 	}) | ||||||
| 	dataChan := make(chan string) | 	dataChan := make(chan string) | ||||||
| 	stopChan := make(chan bool) | 	stopChan := make(chan bool) | ||||||
|  | 	var usage *model.Usage | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for scanner.Scan() { | 		for scanner.Scan() { | ||||||
| 			data := scanner.Text() | 			data := scanner.Text() | ||||||
| @@ -54,6 +55,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | |||||||
| 					for _, choice := range streamResponse.Choices { | 					for _, choice := range streamResponse.Choices { | ||||||
| 						responseText += choice.Delta.Content | 						responseText += choice.Delta.Content | ||||||
| 					} | 					} | ||||||
|  | 					if streamResponse.Usage != nil { | ||||||
|  | 						usage = streamResponse.Usage | ||||||
|  | 					} | ||||||
| 				case constant.RelayModeCompletions: | 				case constant.RelayModeCompletions: | ||||||
| 					var streamResponse CompletionsStreamResponse | 					var streamResponse CompletionsStreamResponse | ||||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||||
| @@ -86,9 +90,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText, usage | ||||||
| } | } | ||||||
|  |  | ||||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|   | |||||||
| @@ -118,8 +118,10 @@ type ImageResponse struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponseChoice struct { | type ChatCompletionsStreamResponseChoice struct { | ||||||
|  | 	Index int `json:"index"` | ||||||
| 	Delta struct { | 	Delta struct { | ||||||
| 		Content string `json:"content"` | 		Content string `json:"content"` | ||||||
|  | 		Role    string `json:"role,omitempty"` | ||||||
| 	} `json:"delta"` | 	} `json:"delta"` | ||||||
| 	FinishReason *string `json:"finish_reason,omitempty"` | 	FinishReason *string `json:"finish_reason,omitempty"` | ||||||
| } | } | ||||||
| @@ -130,6 +132,7 @@ type ChatCompletionsStreamResponse struct { | |||||||
| 	Created int64                                 `json:"created"` | 	Created int64                                 `json:"created"` | ||||||
| 	Model   string                                `json:"model"` | 	Model   string                                `json:"model"` | ||||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||||
|  | 	Usage   *model.Usage                          `json:"usage"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type CompletionsStreamResponse struct { | type CompletionsStreamResponse struct { | ||||||
|   | |||||||
| @@ -81,6 +81,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | |||||||
|  |  | ||||||
| func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	response := openai.ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   "tencent-hunyuan", | 		Model:   "tencent-hunyuan", | ||||||
|   | |||||||
| @@ -70,6 +70,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | |||||||
| 		FinishReason: constant.StopFinishReason, | 		FinishReason: constant.StopFinishReason, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := openai.TextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Choices: []openai.TextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| @@ -92,6 +93,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl | |||||||
| 		choice.FinishReason = &constant.StopFinishReason | 		choice.FinishReason = &constant.StopFinishReason | ||||||
| 	} | 	} | ||||||
| 	response := openai.ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   "SparkDesk", | 		Model:   "SparkDesk", | ||||||
|   | |||||||
| @@ -5,20 +5,35 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Adaptor struct { | type Adaptor struct { | ||||||
|  | 	APIVersion string | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) SetVersionByModeName(modelName string) { | ||||||
|  | 	if strings.HasPrefix(modelName, "glm-") { | ||||||
|  | 		a.APIVersion = "v4" | ||||||
|  | 	} else { | ||||||
|  | 		a.APIVersion = "v3" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
|  | 	a.SetVersionByModeName(meta.ActualModelName) | ||||||
|  | 	if a.APIVersion == "v4" { | ||||||
|  | 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil | ||||||
|  | 	} | ||||||
| 	method := "invoke" | 	method := "invoke" | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		method = "sse-invoke" | 		method = "sse-invoke" | ||||||
| @@ -37,6 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | |||||||
| 	if request == nil { | 	if request == nil { | ||||||
| 		return nil, errors.New("request is nil") | 		return nil, errors.New("request is nil") | ||||||
| 	} | 	} | ||||||
|  | 	if request.TopP >= 1 { | ||||||
|  | 		request.TopP = 0.99 | ||||||
|  | 	} | ||||||
|  | 	a.SetVersionByModeName(request.Model) | ||||||
|  | 	if a.APIVersion == "v4" { | ||||||
|  | 		return request, nil | ||||||
|  | 	} | ||||||
| 	return ConvertRequest(*request), nil | 	return ConvertRequest(*request), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -44,7 +66,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | |||||||
| 	return channel.DoRequestHelper(a, c, meta, requestBody) | 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
|  | 	if meta.IsStream { | ||||||
|  | 		err, _, usage = openai.StreamHandler(c, resp, meta.Mode) | ||||||
|  | 	} else { | ||||||
|  | 		err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
|  | 	if a.APIVersion == "v4" { | ||||||
|  | 		return a.DoResponseV4(c, resp, meta) | ||||||
|  | 	} | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		err, usage = StreamHandler(c, resp) | 		err, usage = StreamHandler(c, resp) | ||||||
| 	} else { | 	} else { | ||||||
|   | |||||||
| @@ -2,4 +2,5 @@ package zhipu | |||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | ||||||
|  | 	"glm-4", "glm-4v", "glm-3-turbo", | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										24
									
								
								relay/constant/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								relay/constant/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | package constant | ||||||
|  |  | ||||||
|  | var DalleSizeRatios = map[string]map[string]float64{ | ||||||
|  | 	"dall-e-2": { | ||||||
|  | 		"256x256":   1, | ||||||
|  | 		"512x512":   1.125, | ||||||
|  | 		"1024x1024": 1.25, | ||||||
|  | 	}, | ||||||
|  | 	"dall-e-3": { | ||||||
|  | 		"1024x1024": 1, | ||||||
|  | 		"1024x1792": 2, | ||||||
|  | 		"1792x1024": 2, | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DalleGenerationImageAmounts = map[string][2]int{ | ||||||
|  | 	"dall-e-2": {1, 10}, | ||||||
|  | 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DalleImagePromptLengthLimitations = map[string]int{ | ||||||
|  | 	"dall-e-2": 1000, | ||||||
|  | 	"dall-e-3": 4000, | ||||||
|  | } | ||||||
| @@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener | |||||||
| 	return textRequest, nil | 	return textRequest, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { | ||||||
|  | 	imageRequest := &openai.ImageRequest{} | ||||||
|  | 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.N == 0 { | ||||||
|  | 		imageRequest.N = 1 | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.Size == "" { | ||||||
|  | 		imageRequest.Size = "1024x1024" | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.Model == "" { | ||||||
|  | 		imageRequest.Model = "dall-e-2" | ||||||
|  | 	} | ||||||
|  | 	return imageRequest, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { | ||||||
|  | 	// model validation | ||||||
|  | 	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||||
|  | 	if !hasValidSize { | ||||||
|  | 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||||
|  | 	} | ||||||
|  | 	// check prompt length | ||||||
|  | 	if imageRequest.Prompt == "" { | ||||||
|  | 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||||
|  | 	} | ||||||
|  | 	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { | ||||||
|  | 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||||
|  | 	} | ||||||
|  | 	// Number of generated images validation | ||||||
|  | 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||||
|  | 		// channel not azure | ||||||
|  | 		if meta.ChannelType != common.ChannelTypeAzure { | ||||||
|  | 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { | ||||||
|  | 	if imageRequest == nil { | ||||||
|  | 		return 0, errors.New("imageRequest is nil") | ||||||
|  | 	} | ||||||
|  | 	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||||
|  | 	if !hasValidSize { | ||||||
|  | 		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||||
|  | 		if imageRequest.Size == "1024x1024" { | ||||||
|  | 			imageCostRatio *= 2 | ||||||
|  | 		} else { | ||||||
|  | 			imageCostRatio *= 1.5 | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return imageCostRatio, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case constant.RelayModeChatCompletions: | 	case constant.RelayModeChatCompletions: | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -20,120 +21,65 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func isWithinRange(element string, value int) bool { | func isWithinRange(element string, value int) bool { | ||||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { | 	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	min := common.DalleGenerationImageAmounts[element][0] | 	min := constant.DalleGenerationImageAmounts[element][0] | ||||||
| 	max := common.DalleGenerationImageAmounts[element][1] | 	max := constant.DalleGenerationImageAmounts[element][1] | ||||||
|  |  | ||||||
| 	return value >= min && value <= max | 	return value >= min && value <= max | ||||||
| } | } | ||||||
|  |  | ||||||
| func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||||
| 	imageModel := "dall-e-2" | 	ctx := c.Request.Context() | ||||||
| 	imageSize := "1024x1024" | 	meta := util.GetRelayMeta(c) | ||||||
|  | 	imageRequest, err := getImageRequest(c, meta.Mode) | ||||||
| 	tokenId := c.GetInt("token_id") |  | ||||||
| 	channelType := c.GetInt("channel") |  | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	userId := c.GetInt("id") |  | ||||||
| 	group := c.GetString("group") |  | ||||||
|  |  | ||||||
| 	var imageRequest openai.ImageRequest |  | ||||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | 		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) | ||||||
| 	} | 		return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) | ||||||
|  |  | ||||||
| 	if imageRequest.N == 0 { |  | ||||||
| 		imageRequest.N = 1 |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Size validation |  | ||||||
| 	if imageRequest.Size != "" { |  | ||||||
| 		imageSize = imageRequest.Size |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Model validation |  | ||||||
| 	if imageRequest.Model != "" { |  | ||||||
| 		imageModel = imageRequest.Model |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] |  | ||||||
|  |  | ||||||
| 	// Check if model is supported |  | ||||||
| 	if hasValidSize { |  | ||||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { |  | ||||||
| 			if imageSize == "1024x1024" { |  | ||||||
| 				imageCostRatio *= 2 |  | ||||||
| 			} else { |  | ||||||
| 				imageCostRatio *= 1.5 |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Prompt validation |  | ||||||
| 	if imageRequest.Prompt == "" { |  | ||||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check prompt length |  | ||||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { |  | ||||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Number of generated images validation |  | ||||||
| 	if !isWithinRange(imageModel, imageRequest.N) { |  | ||||||
| 		// channel not azure |  | ||||||
| 		if channelType != common.ChannelTypeAzure { |  | ||||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// map model name | 	// map model name | ||||||
| 	modelMapping := c.GetString("model_mapping") | 	var isModelMapped bool | ||||||
| 	isModelMapped := false | 	meta.OriginModelName = imageRequest.Model | ||||||
| 	if modelMapping != "" { | 	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) | ||||||
| 		modelMap := make(map[string]string) | 	meta.ActualModelName = imageRequest.Model | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) |  | ||||||
|  | 	// model validation | ||||||
|  | 	bizErr := validateImageRequest(imageRequest, meta) | ||||||
|  | 	if bizErr != nil { | ||||||
|  | 		return bizErr | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	imageCostRatio, err := getImageCostRatio(imageRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 		if modelMap[imageModel] != "" { |  | ||||||
| 			imageModel = modelMap[imageModel] |  | ||||||
| 			isModelMapped = true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] |  | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| 	if c.GetString("base_url") != "" { | 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) | ||||||
| 		baseURL = c.GetString("base_url") | 	if meta.ChannelType == common.ChannelTypeAzure { | ||||||
| 	} |  | ||||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) |  | ||||||
| 	if channelType == common.ChannelTypeAzure { |  | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||||
| 		apiVersion := util.GetAzureAPIVersion(c) | 		apiVersion := util.GetAzureAPIVersion(c) | ||||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var requestBody io.Reader | 	var requestBody io.Reader | ||||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | 	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body | ||||||
| 		jsonStr, err := json.Marshal(imageRequest) | 		jsonStr, err := json.Marshal(imageRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	} else { | 	} else { | ||||||
| 		requestBody = c.Request.Body | 		requestBody = c.Request.Body | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	modelRatio := common.GetModelRatio(imageModel) | 	modelRatio := common.GetModelRatio(imageRequest.Model) | ||||||
| 	groupRatio := common.GetGroupRatio(group) | 	groupRatio := common.GetGroupRatio(meta.Group) | ||||||
| 	ratio := modelRatio * groupRatio | 	ratio := modelRatio * groupRatio | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | ||||||
|  |  | ||||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||||
|  |  | ||||||
| @@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	token := c.Request.Header.Get("Authorization") | 	token := c.Request.Header.Get("Authorization") | ||||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | 	if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication | ||||||
| 		token = strings.TrimPrefix(token, "Bearer ") | 		token = strings.TrimPrefix(token, "Bearer ") | ||||||
| 		req.Header.Set("api-key", token) | 		req.Header.Set("api-key", token) | ||||||
| 	} else { | 	} else { | ||||||
| @@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	var textResponse openai.ImageResponse | 	var imageResponse openai.ImageResponse | ||||||
|  |  | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| 		if resp.StatusCode != http.StatusOK { | 		if resp.StatusCode != http.StatusOK { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | 		err := model.PostConsumeTokenQuota(meta.TokenId, quota) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("error consuming token remain quota: " + err.Error()) | 			logger.SysError("error consuming token remain quota: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		err = model.CacheUpdateUserQuota(userId) | 		err = model.CacheUpdateUserQuota(meta.UserId) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | 			logger.SysError("error update user quota cache: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		if quota != 0 { | 		if quota != 0 { | ||||||
| 			tokenName := c.GetString("token_name") | 			tokenName := c.GetString("token_name") | ||||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | 			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) | ||||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | 			model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 		} | 		} | ||||||
| @@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &textResponse) | 	err = json.Unmarshal(responseBody, &imageResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -39,6 +39,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 	ratio := modelRatio * groupRatio | 	ratio := modelRatio * groupRatio | ||||||
| 	// pre-consume quota | 	// pre-consume quota | ||||||
| 	promptTokens := getPromptTokens(textRequest, meta.Mode) | 	promptTokens := getPromptTokens(textRequest, meta.Mode) | ||||||
|  | 	meta.PromptTokens = promptTokens | ||||||
| 	preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) | 	preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) | ||||||
| 	if bizErr != nil { | 	if bizErr != nil { | ||||||
| 		logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr) | 		logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr) | ||||||
| @@ -54,7 +55,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 	var requestBody io.Reader | 	var requestBody io.Reader | ||||||
| 	if meta.APIType == constant.APITypeOpenAI { | 	if meta.APIType == constant.APITypeOpenAI { | ||||||
| 		// no need to convert request for openai | 		// no need to convert request for openai | ||||||
| 		if isModelMapped { | 		shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan | ||||||
|  | 		if shouldResetRequestBody { | ||||||
| 			jsonStr, err := json.Marshal(textRequest) | 			jsonStr, err := json.Marshal(textRequest) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | 				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||||
|   | |||||||
| @@ -71,6 +71,18 @@ export const CHANNEL_OPTIONS = { | |||||||
|     value: 23, |     value: 23, | ||||||
|     color: 'default' |     color: 'default' | ||||||
|   }, |   }, | ||||||
|  |   26: { | ||||||
|  |     key: 26, | ||||||
|  |     text: '百川大模型', | ||||||
|  |     value: 26, | ||||||
|  |     color: 'default' | ||||||
|  |   }, | ||||||
|  |   27: { | ||||||
|  |     key: 27, | ||||||
|  |     text: 'MiniMax', | ||||||
|  |     value: 27, | ||||||
|  |     color: 'default' | ||||||
|  |   }, | ||||||
|   8: { |   8: { | ||||||
|     key: 8, |     key: 8, | ||||||
|     text: '自定义渠道', |     text: '自定义渠道', | ||||||
|   | |||||||
| @@ -67,7 +67,7 @@ const typeConfig = { | |||||||
|   }, |   }, | ||||||
|   16: { |   16: { | ||||||
|     input: { |     input: { | ||||||
|       models: ["chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"], |       models: ["glm-4", "glm-4v", "glm-3-turbo", "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"], | ||||||
|     }, |     }, | ||||||
|     modelGroup: "zhipu", |     modelGroup: "zhipu", | ||||||
|   }, |   }, | ||||||
| @@ -145,6 +145,24 @@ const typeConfig = { | |||||||
|     }, |     }, | ||||||
|     modelGroup: "google gemini", |     modelGroup: "google gemini", | ||||||
|   }, |   }, | ||||||
|  |   25: { | ||||||
|  |     input: { | ||||||
|  |       models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'], | ||||||
|  |     }, | ||||||
|  |     modelGroup: "moonshot", | ||||||
|  |   }, | ||||||
|  |   26: { | ||||||
|  |     input: { | ||||||
|  |       models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'], | ||||||
|  |     }, | ||||||
|  |     modelGroup: "baichuan", | ||||||
|  |   }, | ||||||
|  |   27: { | ||||||
|  |     input: { | ||||||
|  |       models: ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'], | ||||||
|  |     }, | ||||||
|  |     modelGroup: "minimax", | ||||||
|  |   }, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| export { defaultConfig, typeConfig }; | export { defaultConfig, typeConfig }; | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								web/build.sh
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								web/build.sh
									
									
									
									
									
								
							| @@ -1,13 +1,13 @@ | |||||||
| #!/bin/sh | #!/bin/sh | ||||||
|  |  | ||||||
| version=$(cat VERSION) | version=$(cat VERSION) | ||||||
| themes=$(cat THEMES) | pwd | ||||||
| IFS=$'\n' |  | ||||||
|  |  | ||||||
| for theme in $themes; do | while IFS= read -r theme; do | ||||||
|     echo "Building theme: $theme" |     echo "Building theme: $theme" | ||||||
|     cd $theme |     rm -r build/$theme | ||||||
|  |     cd "$theme" | ||||||
|     npm install |     npm install | ||||||
|     DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$version npm run build |     DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$version npm run build | ||||||
|     cd .. |     cd .. | ||||||
| done | done < THEMES | ||||||
|   | |||||||
| @@ -11,6 +11,8 @@ export const CHANNEL_OPTIONS = [ | |||||||
|   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, |   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||||
|   { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, |   { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, | ||||||
|   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, |   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||||
|  |   { key: 26, text: '百川大模型', value: 26, color: 'orange' }, | ||||||
|  |   { key: 27, text: 'MiniMax', value: 27, color: 'red' }, | ||||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, |   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, |   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|   | |||||||
| @@ -79,7 +79,7 @@ const EditChannel = () => { | |||||||
|           localModels = [...localModels, ...withInternetVersion]; |           localModels = [...localModels, ...withInternetVersion]; | ||||||
|           break; |           break; | ||||||
|         case 16: |         case 16: | ||||||
|           localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; |           localModels = ["glm-4", "glm-4v", "glm-3-turbo",'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||||
|           break; |           break; | ||||||
|         case 18: |         case 18: | ||||||
|           localModels = [ |           localModels = [ | ||||||
| @@ -102,6 +102,12 @@ const EditChannel = () => { | |||||||
|         case 25: |         case 25: | ||||||
|           localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']; |           localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']; | ||||||
|           break; |           break; | ||||||
|  |         case 26: | ||||||
|  |           localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding']; | ||||||
|  |           break; | ||||||
|  |         case 27: | ||||||
|  |           localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat']; | ||||||
|  |           break; | ||||||
|       } |       } | ||||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); |       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user