one-api/relay/model/alias_test.go
longkeyy 18fc5fc0be feat: implement model alias system for unified model names
- Add comprehensive model alias mapping system in relay/model/alias.go
- Support standard model names (gpt-4o, claude-3-sonnet) across different channels
- Modify AddAbilities() to create aliases for both standard and channel-specific names
- Update text processing to resolve aliases before channel-specific mapping
- Enhance billing system to support alias-based model ratio lookup
- Add comprehensive tests for alias resolution and reverse mapping
- Support major providers: OpenRouter, Anthropic, Gemini, Groq

This allows users to use consistent model names (e.g., 'gpt-4o')
regardless of channel provider, with automatic mapping to
channel-specific names (e.g., 'openai/gpt-4o' for OpenRouter).
2025-08-27 02:24:38 +08:00

172 lines
4.1 KiB
Go

package model
import (
"testing"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/stretchr/testify/assert"
)
// TestResolveModelAlias checks that a standard model name is translated to the
// channel-specific name registered for a channel type, and that an unknown
// model name or an unknown channel type falls back to the name passed in.
func TestResolveModelAlias(t *testing.T) {
	// unknownChannelType is a value not expected to be a registered channel type.
	const unknownChannelType = 999

	cases := []struct {
		desc    string
		model   string
		channel int
		want    string
	}{
		{desc: "OpenRouter gpt-4o alias", model: "gpt-4o", channel: channeltype.OpenRouter, want: "openai/gpt-4o"},
		{desc: "OpenAI gpt-4o direct", model: "gpt-4o", channel: channeltype.OpenAI, want: "gpt-4o"},
		{desc: "Anthropic Claude alias", model: "claude-3-sonnet", channel: channeltype.Anthropic, want: "claude-3-sonnet-20240229"},
		{desc: "Non-existent alias", model: "non-existent-model", channel: channeltype.OpenRouter, want: "non-existent-model"},
		{desc: "Unsupported channel type", model: "gpt-4o", channel: unknownChannelType, want: "gpt-4o"},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			assert.Equal(t, tc.want, ResolveModelAlias(tc.model, tc.channel))
		})
	}
}
// TestRegisterModelAlias registers one alias into an empty ModelAliasMap and
// checks that it is stored under its standard name and channel type.
//
// The package-level map is replaced for the duration of the test and the
// original map is put back afterwards, so other tests still see the aliases
// that were present before this test ran.
// NOTE(review): only ModelAliasMap is restored; if RegisterModelAlias also
// updates other package state (e.g. a reverse index), confirm that it does not
// leak into later tests.
func TestRegisterModelAlias(t *testing.T) {
	saved := ModelAliasMap
	t.Cleanup(func() { ModelAliasMap = saved })

	// Start from an empty map so the alias registered below is the only entry.
	ModelAliasMap = map[string]map[int]string{}

	RegisterModelAlias("test-model", "actual-test-model", channeltype.OpenRouter)

	assert.Contains(t, ModelAliasMap, "test-model")
	assert.Equal(t, "actual-test-model", ModelAliasMap["test-model"][channeltype.OpenRouter])
}
// TestGetStandardModelName checks the reverse direction of alias resolution:
// a channel-specific model name maps back to its standard name for that
// channel type, and a name without a reverse mapping is returned unchanged.
func TestGetStandardModelName(t *testing.T) {
	type reverseCase struct {
		desc    string
		actual  string
		channel int
		want    string
	}

	cases := []reverseCase{
		{"OpenRouter reverse lookup", "openai/gpt-4o", channeltype.OpenRouter, "gpt-4o"},
		{"Anthropic reverse lookup", "claude-3-sonnet-20240229", channeltype.Anthropic, "claude-3-sonnet"},
		{"No reverse mapping", "unknown-model", channeltype.OpenRouter, "unknown-model"},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			got := GetStandardModelName(c.actual, c.channel)
			assert.Equal(t, c.want, got)
		})
	}
}
// TestIsAliasSupported checks that IsAliasSupported reports true for a
// (standard name, channel type) pair that has an alias, and false when either
// the channel type or the model name is unknown.
func TestIsAliasSupported(t *testing.T) {
	// unknownChannelType is a value not expected to be a registered channel type.
	const unknownChannelType = 999

	type supportCase struct {
		desc    string
		model   string
		channel int
		want    bool
	}

	cases := []supportCase{
		{"Supported alias", "gpt-4o", channeltype.OpenRouter, true},
		{"Unsupported channel", "gpt-4o", unknownChannelType, false},
		{"Unsupported model", "non-existent", channeltype.OpenRouter, false},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			assert.Equal(t, c.want, IsAliasSupported(c.model, c.channel))
		})
	}
}
// TestGetSupportedChannelTypes checks the channel types listed for a standard
// model name. For gpt-4o, the list must include OpenRouter and OpenAI and must
// not include an unregistered type. An unknown model must yield no channel types.
func TestGetSupportedChannelTypes(t *testing.T) {
	// unknownChannelType is a value not expected to be a registered channel type.
	const unknownChannelType = 999

	gpt4oChannels := GetSupportedChannelTypes("gpt-4o")
	for _, want := range []int{channeltype.OpenRouter, channeltype.OpenAI} {
		assert.Contains(t, gpt4oChannels, want)
	}
	assert.NotContains(t, gpt4oChannels, unknownChannelType)

	assert.Empty(t, GetSupportedChannelTypes("non-existent-model"))
}
// TestGetAllModelAliases checks that the alias table is returned and contains
// the gpt-4o → openai/gpt-4o mapping for the OpenRouter channel type.
func TestGetAllModelAliases(t *testing.T) {
	all := GetAllModelAliases()
	assert.NotNil(t, all)

	assert.Contains(t, all, "gpt-4o")
	gpt4o := all["gpt-4o"]
	assert.Contains(t, gpt4o, channeltype.OpenRouter)
	assert.Equal(t, "openai/gpt-4o", gpt4o[channeltype.OpenRouter])
}
// Benchmark tests
// resolveModelAliasSink receives the last result of BenchmarkResolveModelAlias
// so the compiler cannot treat the benchmarked call as dead code.
var resolveModelAliasSink string

// BenchmarkResolveModelAlias measures resolving the standard name "gpt-4o" for
// the OpenRouter channel type, reporting allocations per call.
func BenchmarkResolveModelAlias(b *testing.B) {
	b.ReportAllocs()
	var out string
	for i := 0; i < b.N; i++ {
		out = ResolveModelAlias("gpt-4o", channeltype.OpenRouter)
	}
	// Publish the result so the loop body has an observable effect.
	resolveModelAliasSink = out
}
// getStandardModelNameSink receives the last result of
// BenchmarkGetStandardModelName so the compiler cannot treat the benchmarked
// call as dead code.
var getStandardModelNameSink string

// BenchmarkGetStandardModelName measures the reverse lookup of the OpenRouter
// name "openai/gpt-4o" back to its standard name, reporting allocations per call.
func BenchmarkGetStandardModelName(b *testing.B) {
	b.ReportAllocs()
	var out string
	for i := 0; i < b.N; i++ {
		out = GetStandardModelName("openai/gpt-4o", channeltype.OpenRouter)
	}
	// Publish the result so the loop body has an observable effect.
	getStandardModelNameSink = out
}