mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 01:06:37 +08:00
Merge 18fc5fc0be
into 8df4a2670b
This commit is contained in:
commit
b40d55ee3f
@ -54,8 +54,12 @@ func (channel *Channel) AddAbilities() error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
models_ = utils.DeDuplication(models_)
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
|
||||
// Expand models to include aliases for standard names
|
||||
expandedModels := expandModelsWithAliases(models_, channel.Type)
|
||||
|
||||
abilities := make([]Ability, 0, len(expandedModels))
|
||||
for _, model := range expandedModels {
|
||||
for _, group := range groups_ {
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
@ -70,6 +74,73 @@ func (channel *Channel) AddAbilities() error {
|
||||
return DB.Create(&abilities).Error
|
||||
}
|
||||
|
||||
// expandModelsWithAliases expands model list to include standard names for channel-specific models
|
||||
func expandModelsWithAliases(models []string, channelType int) []string {
|
||||
expandedModels := make([]string, 0)
|
||||
modelSet := make(map[string]bool)
|
||||
|
||||
for _, model := range models {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add original model
|
||||
if !modelSet[model] {
|
||||
expandedModels = append(expandedModels, model)
|
||||
modelSet[model] = true
|
||||
}
|
||||
|
||||
// Try to find standard name for this channel-specific model
|
||||
standardName := getStandardModelNameForChannel(model, channelType)
|
||||
if standardName != model && !modelSet[standardName] {
|
||||
expandedModels = append(expandedModels, standardName)
|
||||
modelSet[standardName] = true
|
||||
}
|
||||
}
|
||||
|
||||
return expandedModels
|
||||
}
|
||||
|
||||
// getStandardModelNameForChannel returns the standard name for a channel-specific model
|
||||
func getStandardModelNameForChannel(actualName string, channelType int) string {
|
||||
// Import alias mapping (we'll create a lightweight version here to avoid circular imports)
|
||||
aliasMap := getModelAliasesForChannelType(channelType)
|
||||
|
||||
for standard, actual := range aliasMap {
|
||||
if actual == actualName {
|
||||
return standard
|
||||
}
|
||||
}
|
||||
|
||||
return actualName
|
||||
}
|
||||
|
||||
// getModelAliasesForChannelType returns model aliases for specific channel type
|
||||
// This is a lightweight version to avoid importing the full alias module
|
||||
func getModelAliasesForChannelType(channelType int) map[string]string {
|
||||
switch channelType {
|
||||
case 24: // OpenRouter
|
||||
return map[string]string{
|
||||
"gpt-4o": "openai/gpt-4o",
|
||||
"gpt-4o-mini": "openai/gpt-4o-mini",
|
||||
"gpt-4": "openai/gpt-4",
|
||||
"gpt-4-turbo": "openai/gpt-4-turbo",
|
||||
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
|
||||
"o1": "openai/o1",
|
||||
"o1-mini": "openai/o1-mini",
|
||||
"o1-preview": "openai/o1-preview",
|
||||
"claude-3-haiku": "anthropic/claude-3-haiku",
|
||||
"claude-3-sonnet": "anthropic/claude-3-sonnet",
|
||||
"claude-3-opus": "anthropic/claude-3-opus",
|
||||
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
|
||||
"claude-3.5-haiku": "anthropic/claude-3.5-haiku",
|
||||
}
|
||||
default:
|
||||
return map[string]string{}
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) DeleteAbilities() error {
|
||||
return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
|
||||
}
|
||||
|
@ -692,6 +692,8 @@ func GetModelRatio(name string, channelType int) float64 {
|
||||
if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
|
||||
name = strings.TrimSuffix(name, "-internet")
|
||||
}
|
||||
|
||||
// Try channel-specific model ratio first
|
||||
model := fmt.Sprintf("%s(%d)", name, channelType)
|
||||
if ratio, ok := ModelRatio[model]; ok {
|
||||
return ratio
|
||||
@ -699,16 +701,81 @@ func GetModelRatio(name string, channelType int) float64 {
|
||||
if ratio, ok := DefaultModelRatio[model]; ok {
|
||||
return ratio
|
||||
}
|
||||
|
||||
// Try direct model name
|
||||
if ratio, ok := ModelRatio[name]; ok {
|
||||
return ratio
|
||||
}
|
||||
if ratio, ok := DefaultModelRatio[name]; ok {
|
||||
return ratio
|
||||
}
|
||||
|
||||
// Try to find standard model name for alias lookup
|
||||
standardName := getStandardModelNameForBilling(name, channelType)
|
||||
if standardName != name {
|
||||
// Try standard model name
|
||||
if ratio, ok := ModelRatio[standardName]; ok {
|
||||
return ratio
|
||||
}
|
||||
if ratio, ok := DefaultModelRatio[standardName]; ok {
|
||||
return ratio
|
||||
}
|
||||
|
||||
// Try standard model with channel type
|
||||
standardModel := fmt.Sprintf("%s(%d)", standardName, channelType)
|
||||
if ratio, ok := ModelRatio[standardModel]; ok {
|
||||
return ratio
|
||||
}
|
||||
if ratio, ok := DefaultModelRatio[standardModel]; ok {
|
||||
return ratio
|
||||
}
|
||||
}
|
||||
|
||||
logger.SysError("model ratio not found: " + name)
|
||||
return 30
|
||||
}
|
||||
|
||||
// getStandardModelNameForBilling returns the standard model name for billing lookup
|
||||
func getStandardModelNameForBilling(actualName string, channelType int) string {
|
||||
// Reverse alias mapping for billing
|
||||
aliasMap := getBillingAliasMap(channelType)
|
||||
|
||||
for standard, actual := range aliasMap {
|
||||
if actual == actualName {
|
||||
return standard
|
||||
}
|
||||
}
|
||||
|
||||
return actualName
|
||||
}
|
||||
|
||||
// getBillingAliasMap returns alias mapping for billing purposes
|
||||
func getBillingAliasMap(channelType int) map[string]string {
|
||||
switch channelType {
|
||||
case 24: // OpenRouter
|
||||
return map[string]string{
|
||||
"gpt-4o": "openai/gpt-4o",
|
||||
"gpt-4o-mini": "openai/gpt-4o-mini",
|
||||
"gpt-4": "openai/gpt-4",
|
||||
"gpt-4-turbo": "openai/gpt-4-turbo",
|
||||
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
|
||||
"claude-3-haiku": "anthropic/claude-3-haiku",
|
||||
"claude-3-sonnet": "anthropic/claude-3-sonnet",
|
||||
"claude-3-opus": "anthropic/claude-3-opus",
|
||||
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
|
||||
}
|
||||
case 18: // Anthropic
|
||||
return map[string]string{
|
||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-opus": "claude-3-opus-20240229",
|
||||
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
|
||||
}
|
||||
default:
|
||||
return map[string]string{}
|
||||
}
|
||||
}
|
||||
|
||||
func CompletionRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(CompletionRatio)
|
||||
if err != nil {
|
||||
|
@ -151,6 +151,64 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo
|
||||
return modelName, false
|
||||
}
|
||||
|
||||
// resolveModelAlias resolves standard model names to channel-specific names
|
||||
func resolveModelAlias(modelName string, channelType int) string {
|
||||
// Lightweight alias resolution to avoid circular imports
|
||||
aliasMap := getChannelModelAliases(channelType)
|
||||
|
||||
if actualName, exists := aliasMap[modelName]; exists {
|
||||
return actualName
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
// getChannelModelAliases returns model aliases for specific channel type
|
||||
func getChannelModelAliases(channelType int) map[string]string {
|
||||
switch channelType {
|
||||
case 24: // OpenRouter
|
||||
return map[string]string{
|
||||
"gpt-4o": "openai/gpt-4o",
|
||||
"gpt-4o-mini": "openai/gpt-4o-mini",
|
||||
"gpt-4": "openai/gpt-4",
|
||||
"gpt-4-turbo": "openai/gpt-4-turbo",
|
||||
"gpt-3.5-turbo": "openai/gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0125": "openai/gpt-3.5-turbo-0125",
|
||||
"o1": "openai/o1",
|
||||
"o1-mini": "openai/o1-mini",
|
||||
"o1-preview": "openai/o1-preview",
|
||||
"claude-3-haiku": "anthropic/claude-3-haiku",
|
||||
"claude-3-sonnet": "anthropic/claude-3-sonnet",
|
||||
"claude-3-opus": "anthropic/claude-3-opus",
|
||||
"claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
|
||||
"claude-3.5-haiku": "anthropic/claude-3.5-haiku",
|
||||
}
|
||||
case 18: // Anthropic
|
||||
return map[string]string{
|
||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-opus": "claude-3-opus-20240229",
|
||||
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
|
||||
"claude-3.5-haiku": "claude-3-5-haiku-20241022",
|
||||
}
|
||||
case 28: // Gemini
|
||||
return map[string]string{
|
||||
"gemini-pro": "gemini-pro",
|
||||
"gemini-pro-1.5": "gemini-1.5-pro-latest",
|
||||
"gemini-flash-1.5": "gemini-1.5-flash-latest",
|
||||
}
|
||||
case 33: // Groq
|
||||
return map[string]string{
|
||||
"llama-3-8b-instruct": "llama3-8b-8192",
|
||||
"llama-3-70b-instruct": "llama3-70b-8192",
|
||||
"llama-3.1-8b-instruct": "llama-3.1-8b-instant",
|
||||
"llama-3.1-70b-instruct": "llama-3.1-70b-versatile",
|
||||
}
|
||||
default:
|
||||
return map[string]string{}
|
||||
}
|
||||
}
|
||||
|
||||
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
|
||||
if resp == nil {
|
||||
if meta.ChannelType == channeltype.AwsClaude {
|
||||
|
@ -35,6 +35,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
|
||||
// map model name
|
||||
meta.OriginModelName = textRequest.Model
|
||||
|
||||
// First resolve model alias (standard name to channel-specific name)
|
||||
textRequest.Model = resolveModelAlias(textRequest.Model, meta.ChannelType)
|
||||
|
||||
// Then apply channel-specific model mapping (if configured)
|
||||
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = textRequest.Model
|
||||
// set system prompt if not empty
|
||||
|
212
relay/model/alias.go
Normal file
212
relay/model/alias.go
Normal file
@ -0,0 +1,212 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
)
|
||||
|
||||
// ModelAliasMap maps standard model names to channel-specific names
|
||||
// Format: standardName -> channelType -> actualName
|
||||
var ModelAliasMap = map[string]map[int]string{
|
||||
// OpenAI GPT models
|
||||
"gpt-3.5-turbo": {
|
||||
channeltype.OpenRouter: "openai/gpt-3.5-turbo",
|
||||
channeltype.OpenAI: "gpt-3.5-turbo",
|
||||
},
|
||||
"gpt-3.5-turbo-0125": {
|
||||
channeltype.OpenRouter: "openai/gpt-3.5-turbo-0125",
|
||||
channeltype.OpenAI: "gpt-3.5-turbo-0125",
|
||||
},
|
||||
"gpt-4": {
|
||||
channeltype.OpenRouter: "openai/gpt-4",
|
||||
channeltype.OpenAI: "gpt-4",
|
||||
},
|
||||
"gpt-4-turbo": {
|
||||
channeltype.OpenRouter: "openai/gpt-4-turbo",
|
||||
channeltype.OpenAI: "gpt-4-turbo",
|
||||
},
|
||||
"gpt-4o": {
|
||||
channeltype.OpenRouter: "openai/gpt-4o",
|
||||
channeltype.OpenAI: "gpt-4o",
|
||||
},
|
||||
"gpt-4o-mini": {
|
||||
channeltype.OpenRouter: "openai/gpt-4o-mini",
|
||||
channeltype.OpenAI: "gpt-4o-mini",
|
||||
},
|
||||
"o1": {
|
||||
channeltype.OpenRouter: "openai/o1",
|
||||
channeltype.OpenAI: "o1",
|
||||
},
|
||||
"o1-mini": {
|
||||
channeltype.OpenRouter: "openai/o1-mini",
|
||||
channeltype.OpenAI: "o1-mini",
|
||||
},
|
||||
"o1-preview": {
|
||||
channeltype.OpenRouter: "openai/o1-preview",
|
||||
channeltype.OpenAI: "o1-preview",
|
||||
},
|
||||
|
||||
// Anthropic Claude models
|
||||
"claude-3-haiku": {
|
||||
channeltype.OpenRouter: "anthropic/claude-3-haiku",
|
||||
channeltype.Anthropic: "claude-3-haiku-20240307",
|
||||
},
|
||||
"claude-3-sonnet": {
|
||||
channeltype.OpenRouter: "anthropic/claude-3-sonnet",
|
||||
channeltype.Anthropic: "claude-3-sonnet-20240229",
|
||||
},
|
||||
"claude-3-opus": {
|
||||
channeltype.OpenRouter: "anthropic/claude-3-opus",
|
||||
channeltype.Anthropic: "claude-3-opus-20240229",
|
||||
},
|
||||
"claude-3.5-sonnet": {
|
||||
channeltype.OpenRouter: "anthropic/claude-3.5-sonnet",
|
||||
channeltype.Anthropic: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
"claude-3.5-haiku": {
|
||||
channeltype.OpenRouter: "anthropic/claude-3.5-haiku",
|
||||
channeltype.Anthropic: "claude-3-5-haiku-20241022",
|
||||
},
|
||||
|
||||
// Google models
|
||||
"gemini-pro": {
|
||||
channeltype.OpenRouter: "google/gemini-pro",
|
||||
channeltype.Gemini: "gemini-pro",
|
||||
},
|
||||
"gemini-pro-1.5": {
|
||||
channeltype.OpenRouter: "google/gemini-pro-1.5",
|
||||
channeltype.Gemini: "gemini-1.5-pro-latest",
|
||||
},
|
||||
"gemini-flash-1.5": {
|
||||
channeltype.OpenRouter: "google/gemini-flash-1.5",
|
||||
channeltype.Gemini: "gemini-1.5-flash-latest",
|
||||
},
|
||||
|
||||
// Meta Llama models
|
||||
"llama-3-8b-instruct": {
|
||||
channeltype.OpenRouter: "meta-llama/llama-3-8b-instruct",
|
||||
channeltype.Groq: "llama3-8b-8192",
|
||||
},
|
||||
"llama-3-70b-instruct": {
|
||||
channeltype.OpenRouter: "meta-llama/llama-3-70b-instruct",
|
||||
channeltype.Groq: "llama3-70b-8192",
|
||||
},
|
||||
"llama-3.1-8b-instruct": {
|
||||
channeltype.OpenRouter: "meta-llama/llama-3.1-8b-instruct",
|
||||
channeltype.Groq: "llama-3.1-8b-instant",
|
||||
},
|
||||
"llama-3.1-70b-instruct": {
|
||||
channeltype.OpenRouter: "meta-llama/llama-3.1-70b-instruct",
|
||||
channeltype.Groq: "llama-3.1-70b-versatile",
|
||||
},
|
||||
|
||||
// Mistral models
|
||||
"mistral-7b-instruct": {
|
||||
channeltype.OpenRouter: "mistralai/mistral-7b-instruct",
|
||||
channeltype.Mistral: "mistral-small-latest",
|
||||
},
|
||||
"mixtral-8x7b-instruct": {
|
||||
channeltype.OpenRouter: "mistralai/mixtral-8x7b-instruct",
|
||||
channeltype.Mistral: "mixtral-8x7b-instruct-v0.1",
|
||||
},
|
||||
}
|
||||
|
||||
var aliasLock sync.RWMutex
|
||||
|
||||
// ResolveModelAlias resolves a standard model name to channel-specific name
|
||||
func ResolveModelAlias(standardName string, channelType int) string {
|
||||
aliasLock.RLock()
|
||||
defer aliasLock.RUnlock()
|
||||
|
||||
if channelMap, exists := ModelAliasMap[standardName]; exists {
|
||||
if actualName, exists := channelMap[channelType]; exists {
|
||||
return actualName
|
||||
}
|
||||
}
|
||||
|
||||
// If no alias found, return original name
|
||||
return standardName
|
||||
}
|
||||
|
||||
// RegisterModelAlias registers a new model alias
|
||||
func RegisterModelAlias(standardName, actualName string, channelType int) {
|
||||
aliasLock.Lock()
|
||||
defer aliasLock.Unlock()
|
||||
|
||||
if ModelAliasMap[standardName] == nil {
|
||||
ModelAliasMap[standardName] = make(map[int]string)
|
||||
}
|
||||
|
||||
ModelAliasMap[standardName][channelType] = actualName
|
||||
logger.SysLog(fmt.Sprintf("registered alias: %s -> %s (channel type %d)", standardName, actualName, channelType))
|
||||
}
|
||||
|
||||
// GetAllModelAliases returns all registered aliases
|
||||
func GetAllModelAliases() map[string]map[int]string {
|
||||
aliasLock.RLock()
|
||||
defer aliasLock.RUnlock()
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make(map[string]map[int]string)
|
||||
for standard, channelMap := range ModelAliasMap {
|
||||
result[standard] = make(map[int]string)
|
||||
for channelType, actual := range channelMap {
|
||||
result[standard][channelType] = actual
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetStandardModelName returns the standard name for a channel-specific model
|
||||
func GetStandardModelName(actualName string, channelType int) string {
|
||||
aliasLock.RLock()
|
||||
defer aliasLock.RUnlock()
|
||||
|
||||
for standard, channelMap := range ModelAliasMap {
|
||||
if channelMap[channelType] == actualName {
|
||||
return standard
|
||||
}
|
||||
}
|
||||
|
||||
// If no reverse mapping found, return original name
|
||||
return actualName
|
||||
}
|
||||
|
||||
// IsAliasSupported checks if a standard model name has aliases for given channel type
|
||||
func IsAliasSupported(standardName string, channelType int) bool {
|
||||
aliasLock.RLock()
|
||||
defer aliasLock.RUnlock()
|
||||
|
||||
if channelMap, exists := ModelAliasMap[standardName]; exists {
|
||||
_, exists := channelMap[channelType]
|
||||
return exists
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSupportedChannelTypes returns all channel types that support the given standard model
|
||||
func GetSupportedChannelTypes(standardName string) []int {
|
||||
aliasLock.RLock()
|
||||
defer aliasLock.RUnlock()
|
||||
|
||||
var channelTypes []int
|
||||
if channelMap, exists := ModelAliasMap[standardName]; exists {
|
||||
for channelType := range channelMap {
|
||||
channelTypes = append(channelTypes, channelType)
|
||||
}
|
||||
}
|
||||
|
||||
return channelTypes
|
||||
}
|
||||
|
||||
// LoadModelAliasesFromConfig loads aliases from configuration (placeholder for future implementation)
|
||||
func LoadModelAliasesFromConfig() error {
|
||||
// TODO: Implement loading from database or configuration file
|
||||
logger.SysLog("loaded model aliases from static configuration")
|
||||
return nil
|
||||
}
|
171
relay/model/alias_test.go
Normal file
171
relay/model/alias_test.go
Normal file
@ -0,0 +1,171 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResolveModelAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
standardName string
|
||||
channelType int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "OpenRouter gpt-4o alias",
|
||||
standardName: "gpt-4o",
|
||||
channelType: channeltype.OpenRouter,
|
||||
expected: "openai/gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "OpenAI gpt-4o direct",
|
||||
standardName: "gpt-4o",
|
||||
channelType: channeltype.OpenAI,
|
||||
expected: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "Anthropic Claude alias",
|
||||
standardName: "claude-3-sonnet",
|
||||
channelType: channeltype.Anthropic,
|
||||
expected: "claude-3-sonnet-20240229",
|
||||
},
|
||||
{
|
||||
name: "Non-existent alias",
|
||||
standardName: "non-existent-model",
|
||||
channelType: channeltype.OpenRouter,
|
||||
expected: "non-existent-model",
|
||||
},
|
||||
{
|
||||
name: "Unsupported channel type",
|
||||
standardName: "gpt-4o",
|
||||
channelType: 999,
|
||||
expected: "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ResolveModelAlias(tt.standardName, tt.channelType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterModelAlias(t *testing.T) {
|
||||
originalMap := ModelAliasMap
|
||||
defer func() { ModelAliasMap = originalMap }()
|
||||
|
||||
// Clear the map for testing
|
||||
ModelAliasMap = make(map[string]map[int]string)
|
||||
|
||||
RegisterModelAlias("test-model", "actual-test-model", channeltype.OpenRouter)
|
||||
|
||||
assert.Contains(t, ModelAliasMap, "test-model")
|
||||
assert.Equal(t, "actual-test-model", ModelAliasMap["test-model"][channeltype.OpenRouter])
|
||||
}
|
||||
|
||||
func TestGetStandardModelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
actualName string
|
||||
channelType int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "OpenRouter reverse lookup",
|
||||
actualName: "openai/gpt-4o",
|
||||
channelType: channeltype.OpenRouter,
|
||||
expected: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "Anthropic reverse lookup",
|
||||
actualName: "claude-3-sonnet-20240229",
|
||||
channelType: channeltype.Anthropic,
|
||||
expected: "claude-3-sonnet",
|
||||
},
|
||||
{
|
||||
name: "No reverse mapping",
|
||||
actualName: "unknown-model",
|
||||
channelType: channeltype.OpenRouter,
|
||||
expected: "unknown-model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetStandardModelName(tt.actualName, tt.channelType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAliasSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
standardName string
|
||||
channelType int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Supported alias",
|
||||
standardName: "gpt-4o",
|
||||
channelType: channeltype.OpenRouter,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Unsupported channel",
|
||||
standardName: "gpt-4o",
|
||||
channelType: 999,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Unsupported model",
|
||||
standardName: "non-existent",
|
||||
channelType: channeltype.OpenRouter,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsAliasSupported(tt.standardName, tt.channelType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSupportedChannelTypes(t *testing.T) {
|
||||
channelTypes := GetSupportedChannelTypes("gpt-4o")
|
||||
|
||||
assert.Contains(t, channelTypes, channeltype.OpenRouter)
|
||||
assert.Contains(t, channelTypes, channeltype.OpenAI)
|
||||
assert.NotContains(t, channelTypes, 999) // non-existent channel type
|
||||
|
||||
emptyTypes := GetSupportedChannelTypes("non-existent-model")
|
||||
assert.Empty(t, emptyTypes)
|
||||
}
|
||||
|
||||
func TestGetAllModelAliases(t *testing.T) {
|
||||
aliases := GetAllModelAliases()
|
||||
|
||||
assert.NotNil(t, aliases)
|
||||
assert.Contains(t, aliases, "gpt-4o")
|
||||
assert.Contains(t, aliases["gpt-4o"], channeltype.OpenRouter)
|
||||
assert.Equal(t, "openai/gpt-4o", aliases["gpt-4o"][channeltype.OpenRouter])
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkResolveModelAlias(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ResolveModelAlias("gpt-4o", channeltype.OpenRouter)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetStandardModelName(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetStandardModelName("openai/gpt-4o", channeltype.OpenRouter)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user