mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 13:53:41 +08:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			v0.6.1-alp
			...
			v0.6.1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 4fb22ad4ce | ||
|  | 95cfb8e8c9 | ||
|  | c6ace985c2 | ||
|  | 10a926b8f3 | ||
|  | 2df877a352 | ||
|  | 9d8967f7d3 | ||
|  | b35f3523d3 | ||
|  | 82e916b5ff | ||
|  | de18d6fe16 | 
| @@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) | ||||
|    + [x] [Mistral 系列模型](https://mistral.ai/) | ||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|   | ||||
| @@ -66,6 +66,7 @@ const ( | ||||
| 	ChannelTypeMoonshot       = 25 | ||||
| 	ChannelTypeBaichuan       = 26 | ||||
| 	ChannelTypeMinimax        = 27 | ||||
| 	ChannelTypeMistral        = 28 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| @@ -97,6 +98,7 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api.moonshot.cn",                   // 25 | ||||
| 	"https://api.baichuan-ai.com",               // 26 | ||||
| 	"https://api.minimax.chat",                  // 27 | ||||
| 	"https://api.mistral.ai",                    // 28 | ||||
| } | ||||
|  | ||||
| const ( | ||||
|   | ||||
| @@ -7,29 +7,6 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var DalleSizeRatios = map[string]map[string]float64{ | ||||
| 	"dall-e-2": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1.125, | ||||
| 		"1024x1024": 1.25, | ||||
| 	}, | ||||
| 	"dall-e-3": { | ||||
| 		"1024x1024": 1, | ||||
| 		"1024x1792": 2, | ||||
| 		"1792x1024": 2, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var DalleGenerationImageAmounts = map[string][2]int{ | ||||
| 	"dall-e-2": {1, 10}, | ||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||
| } | ||||
|  | ||||
| var DalleImagePromptLengthLimitations = map[string]int{ | ||||
| 	"dall-e-2": 1000, | ||||
| 	"dall-e-3": 4000, | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	USD2RMB = 7 | ||||
| 	USD     = 500 // $0.002 = 1 -> $1 = 500 | ||||
| @@ -40,7 +17,6 @@ const ( | ||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||
| // https://openai.com/pricing | ||||
| // TODO: when a new api is enabled, check the pricing here | ||||
| // 1 === $0.002 / 1K tokens | ||||
| // 1 === ¥0.014 / 1k tokens | ||||
| var ModelRatio = map[string]float64{ | ||||
| @@ -139,15 +115,29 @@ var ModelRatio = map[string]float64{ | ||||
| 	"abab6-chat":    0.1 * RMB, | ||||
| 	"abab5.5-chat":  0.015 * RMB, | ||||
| 	"abab5.5s-chat": 0.005 * RMB, | ||||
| 	// https://docs.mistral.ai/platform/pricing/ | ||||
| 	"open-mistral-7b":       0.25 / 1000 * USD, | ||||
| 	"open-mixtral-8x7b":     0.7 / 1000 * USD, | ||||
| 	"mistral-small-latest":  2.0 / 1000 * USD, | ||||
| 	"mistral-medium-latest": 2.7 / 1000 * USD, | ||||
| 	"mistral-large-latest":  8.0 / 1000 * USD, | ||||
| 	"mistral-embed":         0.1 / 1000 * USD, | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{} | ||||
|  | ||||
| var DefaultModelRatio map[string]float64 | ||||
| var DefaultCompletionRatio map[string]float64 | ||||
|  | ||||
| func init() { | ||||
| 	DefaultModelRatio = make(map[string]float64) | ||||
| 	for k, v := range ModelRatio { | ||||
| 		DefaultModelRatio[k] = v | ||||
| 	} | ||||
| 	DefaultCompletionRatio = make(map[string]float64) | ||||
| 	for k, v := range CompletionRatio { | ||||
| 		DefaultCompletionRatio[k] = v | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
| @@ -178,8 +168,6 @@ func GetModelRatio(name string) float64 { | ||||
| 	return ratio | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{} | ||||
|  | ||||
| func CompletionRatio2JSONString() string { | ||||
| 	jsonBytes, err := json.Marshal(CompletionRatio) | ||||
| 	if err != nil { | ||||
| @@ -197,6 +185,9 @@ func GetCompletionRatio(name string) float64 { | ||||
| 	if ratio, ok := CompletionRatio[name]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	if ratio, ok := DefaultCompletionRatio[name]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-3.5") { | ||||
| 		if strings.HasSuffix(name, "0125") { | ||||
| 			// https://openai.com/blog/new-embedding-models-and-api-updates | ||||
| @@ -229,5 +220,8 @@ func GetCompletionRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "claude-2") { | ||||
| 		return 2.965517 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "mistral-") { | ||||
| 		return 3 | ||||
| 	} | ||||
| 	return 1 | ||||
| } | ||||
|   | ||||
							
								
								
									
										8
									
								
								common/random.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								common/random.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| package common | ||||
|  | ||||
| import "math/rand" | ||||
|  | ||||
| // RandRange returns a random integer in the half-open range [min, max); max must be greater than min, otherwise rand.Intn panics. | ||||
| func RandRange(min, max int) int { | ||||
| 	return min + rand.Intn(max-min) | ||||
| } | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| @@ -18,6 +19,7 @@ import ( | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| @@ -51,6 +53,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 	c.Request.Header.Set("Content-Type", "application/json") | ||||
| 	c.Set("channel", channel.Type) | ||||
| 	c.Set("base_url", channel.GetBaseURL()) | ||||
| 	middleware.SetupContextForSelectedChannel(c, channel, "") | ||||
| 	meta := util.GetRelayMeta(c) | ||||
| 	apiType := constant.ChannelType2APIType(channel.Type) | ||||
| 	adaptor := helper.GetAdaptor(apiType) | ||||
| @@ -59,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
| 	modelName := adaptor.GetModelList()[0] | ||||
| 	if !strings.Contains(channel.Models, modelName) { | ||||
| 		modelNames := strings.Split(channel.Models, ",") | ||||
| 		if len(modelNames) > 0 { | ||||
| 			modelName = modelNames[0] | ||||
| 		} | ||||
| 	} | ||||
| 	request := buildTestRequest() | ||||
| 	request.Model = modelName | ||||
| 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| @@ -122,6 +123,17 @@ func init() { | ||||
| 			Parent:     nil, | ||||
| 		}) | ||||
| 	} | ||||
| 	for _, modelName := range mistral.ModelList { | ||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | ||||
| 			Id:         modelName, | ||||
| 			Object:     "model", | ||||
| 			Created:    1626777600, | ||||
| 			OwnedBy:    "mistralai", | ||||
| 			Permission: permission, | ||||
| 			Root:       modelName, | ||||
| 			Parent:     nil, | ||||
| 		}) | ||||
| 	} | ||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | ||||
| 	for _, model := range openAIModels { | ||||
| 		openAIModelsMap[model.Id] = model | ||||
|   | ||||
| @@ -62,7 +62,7 @@ func Relay(c *gin.Context) { | ||||
| 		retryTimes = 0 | ||||
| 	} | ||||
| 	for i := retryTimes; i > 0; i-- { | ||||
| 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) | ||||
| 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) | ||||
| 			break | ||||
|   | ||||
| @@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 				} | ||||
| 			} | ||||
| 			requestModel = modelRequest.Model | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				if channel != nil { | ||||
|   | ||||
| @@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | ||||
| 	if !config.MemoryCacheEnabled { | ||||
| 		return GetRandomSatisfiedChannel(group, model) | ||||
| 	} | ||||
| @@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | ||||
| 		} | ||||
| 	} | ||||
| 	idx := rand.Intn(endIdx) | ||||
| 	if ignoreFirstPriority { | ||||
| 		if endIdx < len(channels) { // which means there is more than one priority | ||||
| 			idx = common.RandRange(endIdx, len(channels)) | ||||
| 		} | ||||
| 	} | ||||
| 	return channels[idx], nil | ||||
| } | ||||
|   | ||||
| @@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		enableSearch = true | ||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||
| 	} | ||||
| 	if request.TopP >= 1 { | ||||
| 		request.TopP = 0.9999 | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Model: aliModel, | ||||
| 		Input: Input{ | ||||
| @@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			EnableSearch:      enableSearch, | ||||
| 			IncrementalOutput: request.Stream, | ||||
| 			Seed:              uint64(request.Seed), | ||||
| 			MaxTokens:         request.MaxTokens, | ||||
| 			Temperature:       request.Temperature, | ||||
| 			TopP:              request.TopP, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -16,6 +16,8 @@ type Parameters struct { | ||||
| 	Seed              uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int     `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64 `json:"temperature,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
|   | ||||
| @@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 	case "Embedding-V1": | ||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 	default: | ||||
|                fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName | ||||
| 	} | ||||
| 	var accessToken string | ||||
| 	var err error | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| package gemini | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-pro", | ||||
| 	"gemini-pro-vision", | ||||
| 	"gemini-pro", "gemini-1.0-pro-001", | ||||
| 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", | ||||
| } | ||||
|   | ||||
							
								
								
									
										10
									
								
								relay/channel/mistral/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/channel/mistral/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package mistral | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"open-mistral-7b", | ||||
| 	"open-mixtral-8x7b", | ||||
| 	"mistral-small-latest", | ||||
| 	"mistral-medium-latest", | ||||
| 	"mistral-large-latest", | ||||
| 	"mistral-embed", | ||||
| } | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| @@ -94,6 +95,8 @@ func (a *Adaptor) GetModelList() []string { | ||||
| 		return baichuan.ModelList | ||||
| 	case common.ChannelTypeMinimax: | ||||
| 		return minimax.ModelList | ||||
| 	case common.ChannelTypeMistral: | ||||
| 		return mistral.ModelList | ||||
| 	default: | ||||
| 		return ModelList | ||||
| 	} | ||||
| @@ -111,6 +114,8 @@ func (a *Adaptor) GetChannelName() string { | ||||
| 		return "baichuan" | ||||
| 	case common.ChannelTypeMinimax: | ||||
| 		return "minimax" | ||||
| 	case common.ChannelTypeMistral: | ||||
| 		return "mistralai" | ||||
| 	default: | ||||
| 		return "openai" | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										24
									
								
								relay/constant/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								relay/constant/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| package constant | ||||
|  | ||||
| var DalleSizeRatios = map[string]map[string]float64{ | ||||
| 	"dall-e-2": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1.125, | ||||
| 		"1024x1024": 1.25, | ||||
| 	}, | ||||
| 	"dall-e-3": { | ||||
| 		"1024x1024": 1, | ||||
| 		"1024x1792": 2, | ||||
| 		"1792x1024": 2, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var DalleGenerationImageAmounts = map[string][2]int{ | ||||
| 	"dall-e-2": {1, 10}, | ||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||
| } | ||||
|  | ||||
| var DalleImagePromptLengthLimitations = map[string]int{ | ||||
| 	"dall-e-2": 1000, | ||||
| 	"dall-e-3": 4000, | ||||
| } | ||||
| @@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener | ||||
| 	return textRequest, nil | ||||
| } | ||||
|  | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { | ||||
| 	imageRequest := &openai.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
| 	return imageRequest, nil | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// size validation: the requested size must be supported by the selected model | ||||
| 	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||
| 	if !hasValidSize { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// prompt validation: it must be non-empty and within the model's length limit | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||
| 		// an out-of-range n is rejected only for non-Azure channels | ||||
| 		if meta.ChannelType != common.ChannelTypeAzure { | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { | ||||
| 	if imageRequest == nil { | ||||
| 		return 0, errors.New("imageRequest is nil") | ||||
| 	} | ||||
| 	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||
| 	if !hasValidSize { | ||||
| 		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) | ||||
| 	} | ||||
| 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||
| 		if imageRequest.Size == "1024x1024" { | ||||
| 			imageCostRatio *= 2 | ||||
| 		} else { | ||||
| 			imageCostRatio *= 1.5 | ||||
| 		} | ||||
| 	} | ||||
| 	return imageCostRatio, nil | ||||
| } | ||||
|  | ||||
| func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeChatCompletions: | ||||
| @@ -113,10 +172,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||
| 	} | ||||
| 	if quota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 		model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 		model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| 	} | ||||
| 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| } | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| @@ -20,120 +21,65 @@ import ( | ||||
| ) | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { | ||||
| 	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { | ||||
| 		return false | ||||
| 	} | ||||
| 	min := common.DalleGenerationImageAmounts[element][0] | ||||
| 	max := common.DalleGenerationImageAmounts[element][1] | ||||
| 	min := constant.DalleGenerationImageAmounts[element][0] | ||||
| 	max := constant.DalleGenerationImageAmounts[element][1] | ||||
|  | ||||
| 	return value >= min && value <= max | ||||
| } | ||||
|  | ||||
| func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
| 	imageModel := "dall-e-2" | ||||
| 	imageSize := "1024x1024" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest openai.ImageRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 	ctx := c.Request.Context() | ||||
| 	meta := util.GetRelayMeta(c) | ||||
| 	imageRequest, err := getImageRequest(c, meta.Mode) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
|  | ||||
| 	// Size validation | ||||
| 	if imageRequest.Size != "" { | ||||
| 		imageSize = imageRequest.Size | ||||
| 	} | ||||
|  | ||||
| 	// Model validation | ||||
| 	if imageRequest.Model != "" { | ||||
| 		imageModel = imageRequest.Model | ||||
| 	} | ||||
|  | ||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] | ||||
|  | ||||
| 	// Check if model is supported | ||||
| 	if hasValidSize { | ||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { | ||||
| 			if imageSize == "1024x1024" { | ||||
| 				imageCostRatio *= 2 | ||||
| 			} else { | ||||
| 				imageCostRatio *= 1.5 | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Check prompt length | ||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageModel, imageRequest.N) { | ||||
| 		// channel not azure | ||||
| 		if channelType != common.ChannelTypeAzure { | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) | ||||
| 		return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[imageModel] != "" { | ||||
| 			imageModel = modelMap[imageModel] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	var isModelMapped bool | ||||
| 	meta.OriginModelName = imageRequest.Model | ||||
| 	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) | ||||
| 	meta.ActualModelName = imageRequest.Model | ||||
|  | ||||
| 	// validate size, prompt and n against the mapped model | ||||
| 	bizErr := validateImageRequest(imageRequest, meta) | ||||
| 	if bizErr != nil { | ||||
| 		return bizErr | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
|  | ||||
| 	imageCostRatio, err := getImageCostRatio(imageRequest) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) | ||||
| 	if meta.ChannelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||
| 		apiVersion := util.GetAzureAPIVersion(c) | ||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) | ||||
| 	} | ||||
|  | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | ||||
| 	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
|  | ||||
| 	modelRatio := common.GetModelRatio(imageModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	modelRatio := common.GetModelRatio(imageRequest.Model) | ||||
| 	groupRatio := common.GetGroupRatio(meta.Group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | ||||
|  | ||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||
|  | ||||
| @@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	token := c.Request.Header.Get("Authorization") | ||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | ||||
| 	if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication | ||||
| 		token = strings.TrimPrefix(token, "Bearer ") | ||||
| 		req.Header.Set("api-key", token) | ||||
| 	} else { | ||||
| @@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var textResponse openai.ImageResponse | ||||
| 	var imageResponse openai.ImageResponse | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			return | ||||
| 		} | ||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 		err := model.PostConsumeTokenQuota(meta.TokenId, quota) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 		} | ||||
| 		err = model.CacheUpdateUserQuota(userId) | ||||
| 		err = model.CacheUpdateUserQuota(meta.UserId) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString("token_name") | ||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 		} | ||||
| @@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	err = json.Unmarshal(responseBody, &imageResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|   | ||||
| @@ -29,6 +29,12 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 24, | ||||
|     color: 'orange' | ||||
|   }, | ||||
|   28: { | ||||
|     key: 28, | ||||
|     text: 'Mistral AI', | ||||
|     value: 28, | ||||
|     color: 'orange' | ||||
|   }, | ||||
|   15: { | ||||
|     key: 15, | ||||
|     text: '百度文心千帆', | ||||
|   | ||||
| @@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||
|   { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, | ||||
|   { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, | ||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||
|   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user