From 18fc5fc0bee5f5d164f6e94a6211326a2cc425d5 Mon Sep 17 00:00:00 2001 From: longkeyy Date: Wed, 27 Aug 2025 02:24:38 +0800 Subject: [PATCH] feat: implement model alias system for unified model names - Add comprehensive model alias mapping system in relay/model/alias.go - Support standard model names (gpt-4o, claude-3-sonnet) across different channels - Modify AddAbilities() to create aliases for both standard and channel-specific names - Update text processing to resolve aliases before channel-specific mapping - Enhance billing system to support alias-based model ratio lookup - Add comprehensive tests for alias resolution and reverse mapping - Support major providers: OpenRouter, Anthropic, Gemini, Groq This allows users to use consistent model names (e.g., 'gpt-4o') regardless of channel provider, with automatic mapping to channel-specific names (e.g., 'openai/gpt-4o' for OpenRouter). --- model/ability.go | 75 ++++++++++++- relay/billing/ratio/model.go | 67 +++++++++++ relay/controller/helper.go | 58 ++++++++++ relay/controller/text.go | 5 + relay/model/alias.go | 212 +++++++++++++++++++++++++++++++++++ relay/model/alias_test.go | 171 ++++++++++++++++++++++++++++ 6 files changed, 586 insertions(+), 2 deletions(-) create mode 100644 relay/model/alias.go create mode 100644 relay/model/alias_test.go diff --git a/model/ability.go b/model/ability.go index 5cfb9949..cad720e4 100644 --- a/model/ability.go +++ b/model/ability.go @@ -54,8 +54,12 @@ func (channel *Channel) AddAbilities() error { models_ := strings.Split(channel.Models, ",") models_ = utils.DeDuplication(models_) groups_ := strings.Split(channel.Group, ",") - abilities := make([]Ability, 0, len(models_)) - for _, model := range models_ { + + // Expand models to include aliases for standard names + expandedModels := expandModelsWithAliases(models_, channel.Type) + + abilities := make([]Ability, 0, len(expandedModels)) + for _, model := range expandedModels { for _, group := range groups_ { ability := Ability{ Group: group, @@ -70,6 +74,73 @@ func (channel *Channel) AddAbilities() error { return DB.Create(&abilities).Error } +// expandModelsWithAliases expands model list to include standard names for channel-specific models +func expandModelsWithAliases(models []string, channelType int) []string { + expandedModels := make([]string, 0) + modelSet := make(map[string]bool) + + for _, model := range models { + model = strings.TrimSpace(model) + if model == "" { + continue + } + + // Add original model + if !modelSet[model] { + expandedModels = append(expandedModels, model) + modelSet[model] = true + } + + // Try to find standard name for this channel-specific model + standardName := getStandardModelNameForChannel(model, channelType) + if standardName != model && !modelSet[standardName] { + expandedModels = append(expandedModels, standardName) + modelSet[standardName] = true + } + } + + return expandedModels +} + +// getStandardModelNameForChannel returns the standard name for a channel-specific model +func getStandardModelNameForChannel(actualName string, channelType int) string { + // Import alias mapping (we'll create a lightweight version here to avoid circular imports) + aliasMap := getModelAliasesForChannelType(channelType) + + for standard, actual := range aliasMap { + if actual == actualName { + return standard + } + } + + return actualName +} + +// getModelAliasesForChannelType returns model aliases for specific channel type +// This is a lightweight version to avoid importing the full alias module +func getModelAliasesForChannelType(channelType int) map[string]string { + switch channelType { + case 24: // OpenRouter + return map[string]string{ + "gpt-4o": "openai/gpt-4o", + "gpt-4o-mini": "openai/gpt-4o-mini", + "gpt-4": "openai/gpt-4", + "gpt-4-turbo": "openai/gpt-4-turbo", + "gpt-3.5-turbo": "openai/gpt-3.5-turbo", + "o1": "openai/o1", + "o1-mini": "openai/o1-mini", + "o1-preview": "openai/o1-preview", + "claude-3-haiku": "anthropic/claude-3-haiku", + "claude-3-sonnet": "anthropic/claude-3-sonnet", + "claude-3-opus": "anthropic/claude-3-opus", + "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet", + "claude-3.5-haiku": "anthropic/claude-3.5-haiku", + } + default: + return map[string]string{} + } +} + func (channel *Channel) DeleteAbilities() error { return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index e8b3b615..7633325e 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -692,6 +692,8 @@ func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } + + // Try channel-specific model ratio first model := fmt.Sprintf("%s(%d)", name, channelType) if ratio, ok := ModelRatio[model]; ok { return ratio @@ -699,16 +701,81 @@ func GetModelRatio(name string, channelType int) float64 { if ratio, ok := DefaultModelRatio[model]; ok { return ratio } + + // Try direct model name if ratio, ok := ModelRatio[name]; ok { return ratio } if ratio, ok := DefaultModelRatio[name]; ok { return ratio } + + // Try to find standard model name for alias lookup + standardName := getStandardModelNameForBilling(name, channelType) + if standardName != name { + // Try standard model name + if ratio, ok := ModelRatio[standardName]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[standardName]; ok { + return ratio + } + + // Try standard model with channel type + standardModel := fmt.Sprintf("%s(%d)", standardName, channelType) + if ratio, ok := ModelRatio[standardModel]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[standardModel]; ok { + return ratio + } + } + logger.SysError("model ratio not found: " + name) return 30 } +// getStandardModelNameForBilling returns the standard model name for billing lookup +func getStandardModelNameForBilling(actualName string, channelType int) string { + // Reverse alias mapping for billing + aliasMap := getBillingAliasMap(channelType) + + for standard, actual := range aliasMap { + if actual == actualName { + return standard + } + } + + return actualName +} + +// getBillingAliasMap returns alias mapping for billing purposes +func getBillingAliasMap(channelType int) map[string]string { + switch channelType { + case 24: // OpenRouter + return map[string]string{ + "gpt-4o": "openai/gpt-4o", + "gpt-4o-mini": "openai/gpt-4o-mini", + "gpt-4": "openai/gpt-4", + "gpt-4-turbo": "openai/gpt-4-turbo", + "gpt-3.5-turbo": "openai/gpt-3.5-turbo", + "claude-3-haiku": "anthropic/claude-3-haiku", + "claude-3-sonnet": "anthropic/claude-3-sonnet", + "claude-3-opus": "anthropic/claude-3-opus", + "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet", + } + case 18: // Anthropic + return map[string]string{ + "claude-3-haiku": "claude-3-haiku-20240307", + "claude-3-sonnet": "claude-3-sonnet-20240229", + "claude-3-opus": "claude-3-opus-20240229", + "claude-3.5-sonnet": "claude-3-5-sonnet-20241022", + } + default: + return map[string]string{} + } +} + func CompletionRatio2JSONString() string { jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 5b6f023f..d42abe4f 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -151,6 +151,64 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo return modelName, false } +// resolveModelAlias resolves standard model names to channel-specific names +func resolveModelAlias(modelName string, channelType int) string { + // Lightweight alias resolution to avoid circular imports + aliasMap := getChannelModelAliases(channelType) + + if actualName, exists := aliasMap[modelName]; exists { + return actualName + } + + return modelName +} + +// getChannelModelAliases returns model aliases for specific channel type +func getChannelModelAliases(channelType int) map[string]string { + switch channelType { + case 24: // OpenRouter + return map[string]string{ + "gpt-4o": "openai/gpt-4o", + "gpt-4o-mini": "openai/gpt-4o-mini", + "gpt-4": "openai/gpt-4", + "gpt-4-turbo": "openai/gpt-4-turbo", + "gpt-3.5-turbo": "openai/gpt-3.5-turbo", + "gpt-3.5-turbo-0125": "openai/gpt-3.5-turbo-0125", + "o1": "openai/o1", + "o1-mini": "openai/o1-mini", + "o1-preview": "openai/o1-preview", + "claude-3-haiku": "anthropic/claude-3-haiku", + "claude-3-sonnet": "anthropic/claude-3-sonnet", + "claude-3-opus": "anthropic/claude-3-opus", + "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet", + "claude-3.5-haiku": "anthropic/claude-3.5-haiku", + } + case 18: // Anthropic + return map[string]string{ + "claude-3-haiku": "claude-3-haiku-20240307", + "claude-3-sonnet": "claude-3-sonnet-20240229", + "claude-3-opus": "claude-3-opus-20240229", + "claude-3.5-sonnet": "claude-3-5-sonnet-20241022", + "claude-3.5-haiku": "claude-3-5-haiku-20241022", + } + case 28: // Gemini + return map[string]string{ + "gemini-pro": "gemini-pro", + "gemini-pro-1.5": "gemini-1.5-pro-latest", + "gemini-flash-1.5": "gemini-1.5-flash-latest", + } + case 33: // Groq + return map[string]string{ + "llama-3-8b-instruct": "llama3-8b-8192", + "llama-3-70b-instruct": "llama3-70b-8192", + "llama-3.1-8b-instruct": "llama-3.1-8b-instant", + "llama-3.1-70b-instruct": "llama-3.1-70b-versatile", + } + default: + return map[string]string{} + } +} + func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { if resp == nil { if meta.ChannelType == channeltype.AwsClaude { diff --git a/relay/controller/text.go b/relay/controller/text.go index f912498a..8ceaa04f 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -35,6 +35,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // map model name meta.OriginModelName = textRequest.Model + + // First resolve model alias (standard name to channel-specific name) + textRequest.Model = resolveModelAlias(textRequest.Model, meta.ChannelType) + + // Then apply channel-specific model mapping (if configured) textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // set system prompt if not empty diff --git a/relay/model/alias.go b/relay/model/alias.go new file mode 100644 index 00000000..f4172dbb --- /dev/null +++ b/relay/model/alias.go @@ -0,0 +1,212 @@ +package model + +import ( + "fmt" + "sync" + + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/channeltype" +) + +// ModelAliasMap maps standard model names to channel-specific names +// Format: standardName -> channelType -> actualName +var ModelAliasMap = map[string]map[int]string{ + // OpenAI GPT models + "gpt-3.5-turbo": { + channeltype.OpenRouter: "openai/gpt-3.5-turbo", + channeltype.OpenAI: "gpt-3.5-turbo", + }, + "gpt-3.5-turbo-0125": { + channeltype.OpenRouter: "openai/gpt-3.5-turbo-0125", + channeltype.OpenAI: "gpt-3.5-turbo-0125", + }, + "gpt-4": { + channeltype.OpenRouter: "openai/gpt-4", + channeltype.OpenAI: "gpt-4", + }, + "gpt-4-turbo": { + channeltype.OpenRouter: "openai/gpt-4-turbo", + channeltype.OpenAI: "gpt-4-turbo", + }, + "gpt-4o": { + channeltype.OpenRouter: "openai/gpt-4o", + channeltype.OpenAI: "gpt-4o", + }, + "gpt-4o-mini": { + channeltype.OpenRouter: "openai/gpt-4o-mini", + channeltype.OpenAI: "gpt-4o-mini", + }, + "o1": { + channeltype.OpenRouter: "openai/o1", + channeltype.OpenAI: "o1", + }, + "o1-mini": { + channeltype.OpenRouter: "openai/o1-mini", + channeltype.OpenAI: "o1-mini", + }, + "o1-preview": { + channeltype.OpenRouter: "openai/o1-preview", + channeltype.OpenAI: "o1-preview", + }, + + // Anthropic Claude models + "claude-3-haiku": { + channeltype.OpenRouter: "anthropic/claude-3-haiku", + channeltype.Anthropic: "claude-3-haiku-20240307", + }, + "claude-3-sonnet": { + channeltype.OpenRouter: "anthropic/claude-3-sonnet", + channeltype.Anthropic: "claude-3-sonnet-20240229", + }, + "claude-3-opus": { + channeltype.OpenRouter: "anthropic/claude-3-opus", + channeltype.Anthropic: "claude-3-opus-20240229", + }, + "claude-3.5-sonnet": { + channeltype.OpenRouter: "anthropic/claude-3.5-sonnet", + channeltype.Anthropic: "claude-3-5-sonnet-20241022", + }, + "claude-3.5-haiku": { + channeltype.OpenRouter: "anthropic/claude-3.5-haiku", + channeltype.Anthropic: "claude-3-5-haiku-20241022", + }, + + // Google models + "gemini-pro": { + channeltype.OpenRouter: "google/gemini-pro", + channeltype.Gemini: "gemini-pro", + }, + "gemini-pro-1.5": { + channeltype.OpenRouter: "google/gemini-pro-1.5", + channeltype.Gemini: "gemini-1.5-pro-latest", + }, + "gemini-flash-1.5": { + channeltype.OpenRouter: "google/gemini-flash-1.5", + channeltype.Gemini: "gemini-1.5-flash-latest", + }, + + // Meta Llama models + "llama-3-8b-instruct": { + channeltype.OpenRouter: "meta-llama/llama-3-8b-instruct", + channeltype.Groq: "llama3-8b-8192", + }, + "llama-3-70b-instruct": { + channeltype.OpenRouter: "meta-llama/llama-3-70b-instruct", + channeltype.Groq: "llama3-70b-8192", + }, + "llama-3.1-8b-instruct": { + channeltype.OpenRouter: "meta-llama/llama-3.1-8b-instruct", + channeltype.Groq: "llama-3.1-8b-instant", + }, + "llama-3.1-70b-instruct": { + channeltype.OpenRouter: "meta-llama/llama-3.1-70b-instruct", + channeltype.Groq: "llama-3.1-70b-versatile", + }, + + // Mistral models + "mistral-7b-instruct": { + channeltype.OpenRouter: "mistralai/mistral-7b-instruct", + channeltype.Mistral: "mistral-small-latest", + }, + "mixtral-8x7b-instruct": { + channeltype.OpenRouter: "mistralai/mixtral-8x7b-instruct", + channeltype.Mistral: "mixtral-8x7b-instruct-v0.1", + }, +} + +var aliasLock sync.RWMutex + +// ResolveModelAlias resolves a standard model name to channel-specific name +func ResolveModelAlias(standardName string, channelType int) string { + aliasLock.RLock() + defer aliasLock.RUnlock() + + if channelMap, exists := ModelAliasMap[standardName]; exists { + if actualName, exists := channelMap[channelType]; exists { + return actualName + } + } + + // If no alias found, return original name + return standardName +} + +// RegisterModelAlias registers a new model alias +func RegisterModelAlias(standardName, actualName string, channelType int) { + aliasLock.Lock() + defer aliasLock.Unlock() + + if ModelAliasMap[standardName] == nil { + ModelAliasMap[standardName] = make(map[int]string) + } + + ModelAliasMap[standardName][channelType] = actualName + logger.SysLog(fmt.Sprintf("registered alias: %s -> %s (channel type %d)", standardName, actualName, channelType)) +} + +// GetAllModelAliases returns all registered aliases +func GetAllModelAliases() map[string]map[int]string { + aliasLock.RLock() + defer aliasLock.RUnlock() + + // Return a copy to prevent external modification + result := make(map[string]map[int]string) + for standard, channelMap := range ModelAliasMap { + result[standard] = make(map[int]string) + for channelType, actual := range channelMap { + result[standard][channelType] = actual + } + } + + return result +} + +// GetStandardModelName returns the standard name for a channel-specific model +func GetStandardModelName(actualName string, channelType int) string { + aliasLock.RLock() + defer aliasLock.RUnlock() + + for standard, channelMap := range ModelAliasMap { + if channelMap[channelType] == actualName { + return standard + } + } + + // If no reverse mapping found, return original name + return actualName +} + +// IsAliasSupported checks if a standard model name has aliases for given channel type +func IsAliasSupported(standardName string, channelType int) bool { + aliasLock.RLock() + defer aliasLock.RUnlock() + + if channelMap, exists := ModelAliasMap[standardName]; exists { + _, exists := channelMap[channelType] + return exists + } + + return false +} + +// GetSupportedChannelTypes returns all channel types that support the given standard model +func GetSupportedChannelTypes(standardName string) []int { + aliasLock.RLock() + defer aliasLock.RUnlock() + + var channelTypes []int + if channelMap, exists := ModelAliasMap[standardName]; exists { + for channelType := range channelMap { + channelTypes = append(channelTypes, channelType) + } + } + + return channelTypes +} + +// LoadModelAliasesFromConfig loads aliases from configuration (placeholder for future implementation) +func LoadModelAliasesFromConfig() error { + // TODO: Implement loading from database or configuration file + logger.SysLog("loaded model aliases from static configuration") + return nil +} diff --git a/relay/model/alias_test.go b/relay/model/alias_test.go new file mode 100644 index 00000000..e1b65252 --- /dev/null +++ b/relay/model/alias_test.go @@ -0,0 +1,171 @@ +package model + +import ( + "testing" + + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/stretchr/testify/assert" +) + +func TestResolveModelAlias(t *testing.T) { + tests := []struct { + name string + standardName string + channelType int + expected string + }{ + { + name: "OpenRouter gpt-4o alias", + standardName: "gpt-4o", + channelType: channeltype.OpenRouter, + expected: "openai/gpt-4o", + }, + { + name: "OpenAI gpt-4o direct", + standardName: "gpt-4o", + channelType: channeltype.OpenAI, + expected: "gpt-4o", + }, + { + name: "Anthropic Claude alias", + standardName: "claude-3-sonnet", + channelType: channeltype.Anthropic, + expected: "claude-3-sonnet-20240229", + }, + { + name: "Non-existent alias", + standardName: "non-existent-model", + channelType: channeltype.OpenRouter, + expected: "non-existent-model", + }, + { + name: "Unsupported channel type", + standardName: "gpt-4o", + channelType: 999, + expected: "gpt-4o", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ResolveModelAlias(tt.standardName, tt.channelType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRegisterModelAlias(t *testing.T) { + originalMap := ModelAliasMap + defer func() { ModelAliasMap = originalMap }() + + // Clear the map for testing + ModelAliasMap = make(map[string]map[int]string) + + RegisterModelAlias("test-model", "actual-test-model", channeltype.OpenRouter) + + assert.Contains(t, ModelAliasMap, "test-model") + assert.Equal(t, "actual-test-model", ModelAliasMap["test-model"][channeltype.OpenRouter]) +} + +func TestGetStandardModelName(t *testing.T) { + tests := []struct { + name string + actualName string + channelType int + expected string + }{ + { + name: "OpenRouter reverse lookup", + actualName: "openai/gpt-4o", + channelType: channeltype.OpenRouter, + expected: "gpt-4o", + }, + { + name: "Anthropic reverse lookup", + actualName: "claude-3-sonnet-20240229", + channelType: channeltype.Anthropic, + expected: "claude-3-sonnet", + }, + { + name: "No reverse mapping", + actualName: "unknown-model", + channelType: channeltype.OpenRouter, + expected: "unknown-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetStandardModelName(tt.actualName, tt.channelType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsAliasSupported(t *testing.T) { + tests := []struct { + name string + standardName string + channelType int + expected bool + }{ + { + name: "Supported alias", + standardName: "gpt-4o", + channelType: channeltype.OpenRouter, + expected: true, + }, + { + name: "Unsupported channel", + standardName: "gpt-4o", + channelType: 999, + expected: false, + }, + { + name: "Unsupported model", + standardName: "non-existent", + channelType: channeltype.OpenRouter, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsAliasSupported(tt.standardName, tt.channelType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetSupportedChannelTypes(t *testing.T) { + channelTypes := GetSupportedChannelTypes("gpt-4o") + + assert.Contains(t, channelTypes, channeltype.OpenRouter) + assert.Contains(t, channelTypes, channeltype.OpenAI) + assert.NotContains(t, channelTypes, 999) // non-existent channel type + + emptyTypes := GetSupportedChannelTypes("non-existent-model") + assert.Empty(t, emptyTypes) +} + +func TestGetAllModelAliases(t *testing.T) { + aliases := GetAllModelAliases() + + assert.NotNil(t, aliases) + assert.Contains(t, aliases, "gpt-4o") + assert.Contains(t, aliases["gpt-4o"], channeltype.OpenRouter) + assert.Equal(t, "openai/gpt-4o", aliases["gpt-4o"][channeltype.OpenRouter]) +} + +// Benchmark tests +func BenchmarkResolveModelAlias(b *testing.B) { + for i := 0; i < b.N; i++ { + ResolveModelAlias("gpt-4o", channeltype.OpenRouter) + } +} + +func BenchmarkGetStandardModelName(b *testing.B) { + for i := 0; i < b.N; i++ { + GetStandardModelName("openai/gpt-4o", channeltype.OpenRouter) + } +}