mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 09:16:36 +08:00
- Add a comprehensive model alias mapping system in relay/model/alias.go
- Support standard model names (gpt-4o, claude-3-sonnet) across different channels
- Modify AddAbilities() to create aliases for both standard and channel-specific names
- Update text processing to resolve aliases before channel-specific mapping
- Enhance the billing system to support alias-based model ratio lookup
- Add comprehensive tests for alias resolution and reverse mapping
- Support major providers: OpenRouter, Anthropic, Gemini, Groq

This allows users to use consistent model names (e.g., 'gpt-4o') regardless of the channel provider, with automatic mapping to channel-specific names (e.g., 'openai/gpt-4o' for OpenRouter).
172 lines
4.1 KiB
Go
172 lines
4.1 KiB
Go
package model
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestResolveModelAlias(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
standardName string
|
|
channelType int
|
|
expected string
|
|
}{
|
|
{
|
|
name: "OpenRouter gpt-4o alias",
|
|
standardName: "gpt-4o",
|
|
channelType: channeltype.OpenRouter,
|
|
expected: "openai/gpt-4o",
|
|
},
|
|
{
|
|
name: "OpenAI gpt-4o direct",
|
|
standardName: "gpt-4o",
|
|
channelType: channeltype.OpenAI,
|
|
expected: "gpt-4o",
|
|
},
|
|
{
|
|
name: "Anthropic Claude alias",
|
|
standardName: "claude-3-sonnet",
|
|
channelType: channeltype.Anthropic,
|
|
expected: "claude-3-sonnet-20240229",
|
|
},
|
|
{
|
|
name: "Non-existent alias",
|
|
standardName: "non-existent-model",
|
|
channelType: channeltype.OpenRouter,
|
|
expected: "non-existent-model",
|
|
},
|
|
{
|
|
name: "Unsupported channel type",
|
|
standardName: "gpt-4o",
|
|
channelType: 999,
|
|
expected: "gpt-4o",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := ResolveModelAlias(tt.standardName, tt.channelType)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterModelAlias(t *testing.T) {
|
|
originalMap := ModelAliasMap
|
|
defer func() { ModelAliasMap = originalMap }()
|
|
|
|
// Clear the map for testing
|
|
ModelAliasMap = make(map[string]map[int]string)
|
|
|
|
RegisterModelAlias("test-model", "actual-test-model", channeltype.OpenRouter)
|
|
|
|
assert.Contains(t, ModelAliasMap, "test-model")
|
|
assert.Equal(t, "actual-test-model", ModelAliasMap["test-model"][channeltype.OpenRouter])
|
|
}
|
|
|
|
func TestGetStandardModelName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
actualName string
|
|
channelType int
|
|
expected string
|
|
}{
|
|
{
|
|
name: "OpenRouter reverse lookup",
|
|
actualName: "openai/gpt-4o",
|
|
channelType: channeltype.OpenRouter,
|
|
expected: "gpt-4o",
|
|
},
|
|
{
|
|
name: "Anthropic reverse lookup",
|
|
actualName: "claude-3-sonnet-20240229",
|
|
channelType: channeltype.Anthropic,
|
|
expected: "claude-3-sonnet",
|
|
},
|
|
{
|
|
name: "No reverse mapping",
|
|
actualName: "unknown-model",
|
|
channelType: channeltype.OpenRouter,
|
|
expected: "unknown-model",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetStandardModelName(tt.actualName, tt.channelType)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsAliasSupported(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
standardName string
|
|
channelType int
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "Supported alias",
|
|
standardName: "gpt-4o",
|
|
channelType: channeltype.OpenRouter,
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "Unsupported channel",
|
|
standardName: "gpt-4o",
|
|
channelType: 999,
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "Unsupported model",
|
|
standardName: "non-existent",
|
|
channelType: channeltype.OpenRouter,
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := IsAliasSupported(tt.standardName, tt.channelType)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetSupportedChannelTypes(t *testing.T) {
|
|
channelTypes := GetSupportedChannelTypes("gpt-4o")
|
|
|
|
assert.Contains(t, channelTypes, channeltype.OpenRouter)
|
|
assert.Contains(t, channelTypes, channeltype.OpenAI)
|
|
assert.NotContains(t, channelTypes, 999) // non-existent channel type
|
|
|
|
emptyTypes := GetSupportedChannelTypes("non-existent-model")
|
|
assert.Empty(t, emptyTypes)
|
|
}
|
|
|
|
func TestGetAllModelAliases(t *testing.T) {
|
|
aliases := GetAllModelAliases()
|
|
|
|
assert.NotNil(t, aliases)
|
|
assert.Contains(t, aliases, "gpt-4o")
|
|
assert.Contains(t, aliases["gpt-4o"], channeltype.OpenRouter)
|
|
assert.Equal(t, "openai/gpt-4o", aliases["gpt-4o"][channeltype.OpenRouter])
|
|
}
|
|
|
|
// Benchmark tests
|
|
func BenchmarkResolveModelAlias(b *testing.B) {
|
|
for i := 0; i < b.N; i++ {
|
|
ResolveModelAlias("gpt-4o", channeltype.OpenRouter)
|
|
}
|
|
}
|
|
|
|
func BenchmarkGetStandardModelName(b *testing.B) {
|
|
for i := 0; i < b.N; i++ {
|
|
GetStandardModelName("openai/gpt-4o", channeltype.OpenRouter)
|
|
}
|
|
}
|