This commit is contained in:
longkeyy 2025-08-26 18:46:50 +00:00 committed by GitHub
commit b40d55ee3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 586 additions and 2 deletions

View File

@ -54,8 +54,12 @@ func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
models_ = utils.DeDuplication(models_)
groups_ := strings.Split(channel.Group, ",")
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
// Expand models to include aliases for standard names
expandedModels := expandModelsWithAliases(models_, channel.Type)
abilities := make([]Ability, 0, len(expandedModels))
for _, model := range expandedModels {
for _, group := range groups_ {
ability := Ability{
Group: group,
@ -70,6 +74,73 @@ func (channel *Channel) AddAbilities() error {
return DB.Create(&abilities).Error
}
// expandModelsWithAliases expands model list to include standard names for channel-specific models
func expandModelsWithAliases(models []string, channelType int) []string {
expandedModels := make([]string, 0)
modelSet := make(map[string]bool)
for _, model := range models {
model = strings.TrimSpace(model)
if model == "" {
continue
}
// Add original model
if !modelSet[model] {
expandedModels = append(expandedModels, model)
modelSet[model] = true
}
// Try to find standard name for this channel-specific model
standardName := getStandardModelNameForChannel(model, channelType)
if standardName != model && !modelSet[standardName] {
expandedModels = append(expandedModels, standardName)
modelSet[standardName] = true
}
}
return expandedModels
}
// getStandardModelNameForChannel returns the standard name for a channel-specific model
func getStandardModelNameForChannel(actualName string, channelType int) string {
// Import alias mapping (we'll create a lightweight version here to avoid circular imports)
aliasMap := getModelAliasesForChannelType(channelType)
for standard, actual := range aliasMap {
if actual == actualName {
return standard
}
}
return actualName
}
// getModelAliasesForChannelType returns model aliases for specific channel type
// This is a lightweight version to avoid importing the full alias module
func getModelAliasesForChannelType(channelType int) map[string]string {
switch channelType {
case 24: // OpenRouter
return map[string]string{
"gpt-4o": "openai/gpt-4o",
"gpt-4o-mini": "openai/gpt-4o-mini",
"gpt-4": "openai/gpt-4",
"gpt-4-turbo": "openai/gpt-4-turbo",
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
"o1": "openai/o1",
"o1-mini": "openai/o1-mini",
"o1-preview": "openai/o1-preview",
"claude-3-haiku": "anthropic/claude-3-haiku",
"claude-3-sonnet": "anthropic/claude-3-sonnet",
"claude-3-opus": "anthropic/claude-3-opus",
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
"claude-3.5-haiku": "anthropic/claude-3.5-haiku",
}
default:
return map[string]string{}
}
}
func (channel *Channel) DeleteAbilities() error {
return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
}

View File

@ -692,6 +692,8 @@ func GetModelRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
// Try channel-specific model ratio first
model := fmt.Sprintf("%s(%d)", name, channelType)
if ratio, ok := ModelRatio[model]; ok {
return ratio
@ -699,16 +701,81 @@ func GetModelRatio(name string, channelType int) float64 {
if ratio, ok := DefaultModelRatio[model]; ok {
return ratio
}
// Try direct model name
if ratio, ok := ModelRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[name]; ok {
return ratio
}
// Try to find standard model name for alias lookup
standardName := getStandardModelNameForBilling(name, channelType)
if standardName != name {
// Try standard model name
if ratio, ok := ModelRatio[standardName]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[standardName]; ok {
return ratio
}
// Try standard model with channel type
standardModel := fmt.Sprintf("%s(%d)", standardName, channelType)
if ratio, ok := ModelRatio[standardModel]; ok {
return ratio
}
if ratio, ok := DefaultModelRatio[standardModel]; ok {
return ratio
}
}
logger.SysError("model ratio not found: " + name)
return 30
}
// getStandardModelNameForBilling returns the standard model name for billing lookup
func getStandardModelNameForBilling(actualName string, channelType int) string {
// Reverse alias mapping for billing
aliasMap := getBillingAliasMap(channelType)
for standard, actual := range aliasMap {
if actual == actualName {
return standard
}
}
return actualName
}
// getBillingAliasMap returns alias mapping for billing purposes
func getBillingAliasMap(channelType int) map[string]string {
switch channelType {
case 24: // OpenRouter
return map[string]string{
"gpt-4o": "openai/gpt-4o",
"gpt-4o-mini": "openai/gpt-4o-mini",
"gpt-4": "openai/gpt-4",
"gpt-4-turbo": "openai/gpt-4-turbo",
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
"claude-3-haiku": "anthropic/claude-3-haiku",
"claude-3-sonnet": "anthropic/claude-3-sonnet",
"claude-3-opus": "anthropic/claude-3-opus",
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
}
case 18: // Anthropic
return map[string]string{
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-opus": "claude-3-opus-20240229",
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
}
default:
return map[string]string{}
}
}
func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {

View File

@ -151,6 +151,64 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo
return modelName, false
}
// resolveModelAlias resolves standard model names to channel-specific names
func resolveModelAlias(modelName string, channelType int) string {
// Lightweight alias resolution to avoid circular imports
aliasMap := getChannelModelAliases(channelType)
if actualName, exists := aliasMap[modelName]; exists {
return actualName
}
return modelName
}
// getChannelModelAliases returns model aliases for specific channel type
func getChannelModelAliases(channelType int) map[string]string {
switch channelType {
case 24: // OpenRouter
return map[string]string{
"gpt-4o": "openai/gpt-4o",
"gpt-4o-mini": "openai/gpt-4o-mini",
"gpt-4": "openai/gpt-4",
"gpt-4-turbo": "openai/gpt-4-turbo",
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
"gpt-3.5-turbo-0125": "openai/gpt-3.5-turbo-0125",
"o1": "openai/o1",
"o1-mini": "openai/o1-mini",
"o1-preview": "openai/o1-preview",
"claude-3-haiku": "anthropic/claude-3-haiku",
"claude-3-sonnet": "anthropic/claude-3-sonnet",
"claude-3-opus": "anthropic/claude-3-opus",
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
"claude-3.5-haiku": "anthropic/claude-3.5-haiku",
}
case 18: // Anthropic
return map[string]string{
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-opus": "claude-3-opus-20240229",
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
"claude-3.5-haiku": "claude-3-5-haiku-20241022",
}
case 28: // Gemini
return map[string]string{
"gemini-pro": "gemini-pro",
"gemini-pro-1.5": "gemini-1.5-pro-latest",
"gemini-flash-1.5": "gemini-1.5-flash-latest",
}
case 33: // Groq
return map[string]string{
"llama-3-8b-instruct": "llama3-8b-8192",
"llama-3-70b-instruct": "llama3-70b-8192",
"llama-3.1-8b-instruct": "llama-3.1-8b-instant",
"llama-3.1-70b-instruct": "llama-3.1-70b-versatile",
}
default:
return map[string]string{}
}
}
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
if resp == nil {
if meta.ChannelType == channeltype.AwsClaude {

View File

@ -35,6 +35,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// map model name
meta.OriginModelName = textRequest.Model
// First resolve model alias (standard name to channel-specific name)
textRequest.Model = resolveModelAlias(textRequest.Model, meta.ChannelType)
// Then apply channel-specific model mapping (if configured)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// set system prompt if not empty

212
relay/model/alias.go Normal file
View File

@ -0,0 +1,212 @@
package model
import (
"fmt"
"sync"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channeltype"
)
// ModelAliasMap maps standard model names to channel-specific names
// Format: standardName -> channelType -> actualName
var ModelAliasMap = map[string]map[int]string{
// OpenAI GPT models
"gpt-3.5-turbo": {
channeltype.OpenRouter: "openai/gpt-3.5-turbo",
channeltype.OpenAI: "gpt-3.5-turbo",
},
"gpt-3.5-turbo-0125": {
channeltype.OpenRouter: "openai/gpt-3.5-turbo-0125",
channeltype.OpenAI: "gpt-3.5-turbo-0125",
},
"gpt-4": {
channeltype.OpenRouter: "openai/gpt-4",
channeltype.OpenAI: "gpt-4",
},
"gpt-4-turbo": {
channeltype.OpenRouter: "openai/gpt-4-turbo",
channeltype.OpenAI: "gpt-4-turbo",
},
"gpt-4o": {
channeltype.OpenRouter: "openai/gpt-4o",
channeltype.OpenAI: "gpt-4o",
},
"gpt-4o-mini": {
channeltype.OpenRouter: "openai/gpt-4o-mini",
channeltype.OpenAI: "gpt-4o-mini",
},
"o1": {
channeltype.OpenRouter: "openai/o1",
channeltype.OpenAI: "o1",
},
"o1-mini": {
channeltype.OpenRouter: "openai/o1-mini",
channeltype.OpenAI: "o1-mini",
},
"o1-preview": {
channeltype.OpenRouter: "openai/o1-preview",
channeltype.OpenAI: "o1-preview",
},
// Anthropic Claude models
"claude-3-haiku": {
channeltype.OpenRouter: "anthropic/claude-3-haiku",
channeltype.Anthropic: "claude-3-haiku-20240307",
},
"claude-3-sonnet": {
channeltype.OpenRouter: "anthropic/claude-3-sonnet",
channeltype.Anthropic: "claude-3-sonnet-20240229",
},
"claude-3-opus": {
channeltype.OpenRouter: "anthropic/claude-3-opus",
channeltype.Anthropic: "claude-3-opus-20240229",
},
"claude-3.5-sonnet": {
channeltype.OpenRouter: "anthropic/claude-3.5-sonnet",
channeltype.Anthropic: "claude-3-5-sonnet-20241022",
},
"claude-3.5-haiku": {
channeltype.OpenRouter: "anthropic/claude-3.5-haiku",
channeltype.Anthropic: "claude-3-5-haiku-20241022",
},
// Google models
"gemini-pro": {
channeltype.OpenRouter: "google/gemini-pro",
channeltype.Gemini: "gemini-pro",
},
"gemini-pro-1.5": {
channeltype.OpenRouter: "google/gemini-pro-1.5",
channeltype.Gemini: "gemini-1.5-pro-latest",
},
"gemini-flash-1.5": {
channeltype.OpenRouter: "google/gemini-flash-1.5",
channeltype.Gemini: "gemini-1.5-flash-latest",
},
// Meta Llama models
"llama-3-8b-instruct": {
channeltype.OpenRouter: "meta-llama/llama-3-8b-instruct",
channeltype.Groq: "llama3-8b-8192",
},
"llama-3-70b-instruct": {
channeltype.OpenRouter: "meta-llama/llama-3-70b-instruct",
channeltype.Groq: "llama3-70b-8192",
},
"llama-3.1-8b-instruct": {
channeltype.OpenRouter: "meta-llama/llama-3.1-8b-instruct",
channeltype.Groq: "llama-3.1-8b-instant",
},
"llama-3.1-70b-instruct": {
channeltype.OpenRouter: "meta-llama/llama-3.1-70b-instruct",
channeltype.Groq: "llama-3.1-70b-versatile",
},
// Mistral models
"mistral-7b-instruct": {
channeltype.OpenRouter: "mistralai/mistral-7b-instruct",
channeltype.Mistral: "mistral-small-latest",
},
"mixtral-8x7b-instruct": {
channeltype.OpenRouter: "mistralai/mixtral-8x7b-instruct",
channeltype.Mistral: "mixtral-8x7b-instruct-v0.1",
},
}
var aliasLock sync.RWMutex
// ResolveModelAlias resolves a standard model name to channel-specific name
func ResolveModelAlias(standardName string, channelType int) string {
aliasLock.RLock()
defer aliasLock.RUnlock()
if channelMap, exists := ModelAliasMap[standardName]; exists {
if actualName, exists := channelMap[channelType]; exists {
return actualName
}
}
// If no alias found, return original name
return standardName
}
// RegisterModelAlias registers a new model alias
func RegisterModelAlias(standardName, actualName string, channelType int) {
aliasLock.Lock()
defer aliasLock.Unlock()
if ModelAliasMap[standardName] == nil {
ModelAliasMap[standardName] = make(map[int]string)
}
ModelAliasMap[standardName][channelType] = actualName
logger.SysLog(fmt.Sprintf("registered alias: %s -> %s (channel type %d)", standardName, actualName, channelType))
}
// GetAllModelAliases returns all registered aliases
func GetAllModelAliases() map[string]map[int]string {
aliasLock.RLock()
defer aliasLock.RUnlock()
// Return a copy to prevent external modification
result := make(map[string]map[int]string)
for standard, channelMap := range ModelAliasMap {
result[standard] = make(map[int]string)
for channelType, actual := range channelMap {
result[standard][channelType] = actual
}
}
return result
}
// GetStandardModelName returns the standard name for a channel-specific model
func GetStandardModelName(actualName string, channelType int) string {
aliasLock.RLock()
defer aliasLock.RUnlock()
for standard, channelMap := range ModelAliasMap {
if channelMap[channelType] == actualName {
return standard
}
}
// If no reverse mapping found, return original name
return actualName
}
// IsAliasSupported checks if a standard model name has aliases for given channel type
func IsAliasSupported(standardName string, channelType int) bool {
aliasLock.RLock()
defer aliasLock.RUnlock()
if channelMap, exists := ModelAliasMap[standardName]; exists {
_, exists := channelMap[channelType]
return exists
}
return false
}
// GetSupportedChannelTypes returns all channel types that support the given standard model
func GetSupportedChannelTypes(standardName string) []int {
aliasLock.RLock()
defer aliasLock.RUnlock()
var channelTypes []int
if channelMap, exists := ModelAliasMap[standardName]; exists {
for channelType := range channelMap {
channelTypes = append(channelTypes, channelType)
}
}
return channelTypes
}
// LoadModelAliasesFromConfig loads aliases from configuration (placeholder for future implementation)
func LoadModelAliasesFromConfig() error {
// TODO: Implement loading from database or configuration file
logger.SysLog("loaded model aliases from static configuration")
return nil
}

171
relay/model/alias_test.go Normal file
View File

@ -0,0 +1,171 @@
package model
import (
"testing"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/stretchr/testify/assert"
)
func TestResolveModelAlias(t *testing.T) {
tests := []struct {
name string
standardName string
channelType int
expected string
}{
{
name: "OpenRouter gpt-4o alias",
standardName: "gpt-4o",
channelType: channeltype.OpenRouter,
expected: "openai/gpt-4o",
},
{
name: "OpenAI gpt-4o direct",
standardName: "gpt-4o",
channelType: channeltype.OpenAI,
expected: "gpt-4o",
},
{
name: "Anthropic Claude alias",
standardName: "claude-3-sonnet",
channelType: channeltype.Anthropic,
expected: "claude-3-sonnet-20240229",
},
{
name: "Non-existent alias",
standardName: "non-existent-model",
channelType: channeltype.OpenRouter,
expected: "non-existent-model",
},
{
name: "Unsupported channel type",
standardName: "gpt-4o",
channelType: 999,
expected: "gpt-4o",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ResolveModelAlias(tt.standardName, tt.channelType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRegisterModelAlias(t *testing.T) {
originalMap := ModelAliasMap
defer func() { ModelAliasMap = originalMap }()
// Clear the map for testing
ModelAliasMap = make(map[string]map[int]string)
RegisterModelAlias("test-model", "actual-test-model", channeltype.OpenRouter)
assert.Contains(t, ModelAliasMap, "test-model")
assert.Equal(t, "actual-test-model", ModelAliasMap["test-model"][channeltype.OpenRouter])
}
func TestGetStandardModelName(t *testing.T) {
tests := []struct {
name string
actualName string
channelType int
expected string
}{
{
name: "OpenRouter reverse lookup",
actualName: "openai/gpt-4o",
channelType: channeltype.OpenRouter,
expected: "gpt-4o",
},
{
name: "Anthropic reverse lookup",
actualName: "claude-3-sonnet-20240229",
channelType: channeltype.Anthropic,
expected: "claude-3-sonnet",
},
{
name: "No reverse mapping",
actualName: "unknown-model",
channelType: channeltype.OpenRouter,
expected: "unknown-model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetStandardModelName(tt.actualName, tt.channelType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsAliasSupported(t *testing.T) {
tests := []struct {
name string
standardName string
channelType int
expected bool
}{
{
name: "Supported alias",
standardName: "gpt-4o",
channelType: channeltype.OpenRouter,
expected: true,
},
{
name: "Unsupported channel",
standardName: "gpt-4o",
channelType: 999,
expected: false,
},
{
name: "Unsupported model",
standardName: "non-existent",
channelType: channeltype.OpenRouter,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsAliasSupported(tt.standardName, tt.channelType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetSupportedChannelTypes(t *testing.T) {
channelTypes := GetSupportedChannelTypes("gpt-4o")
assert.Contains(t, channelTypes, channeltype.OpenRouter)
assert.Contains(t, channelTypes, channeltype.OpenAI)
assert.NotContains(t, channelTypes, 999) // non-existent channel type
emptyTypes := GetSupportedChannelTypes("non-existent-model")
assert.Empty(t, emptyTypes)
}
func TestGetAllModelAliases(t *testing.T) {
aliases := GetAllModelAliases()
assert.NotNil(t, aliases)
assert.Contains(t, aliases, "gpt-4o")
assert.Contains(t, aliases["gpt-4o"], channeltype.OpenRouter)
assert.Equal(t, "openai/gpt-4o", aliases["gpt-4o"][channeltype.OpenRouter])
}
// Benchmark tests
func BenchmarkResolveModelAlias(b *testing.B) {
for i := 0; i < b.N; i++ {
ResolveModelAlias("gpt-4o", channeltype.OpenRouter)
}
}
func BenchmarkGetStandardModelName(b *testing.B) {
for i := 0; i < b.N; i++ {
GetStandardModelName("openai/gpt-4o", channeltype.OpenRouter)
}
}