feat: Optimize model list.

This commit is contained in:
MartialBE
2024-01-01 14:36:58 +08:00
committed by Buer
parent bf5ba315ee
commit 7ef4a7db59
9 changed files with 327 additions and 138 deletions

View File

@@ -38,37 +38,35 @@ type OpenAIModels struct {
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var modelOwnedBy map[int]string
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
keys := make([]string, 0, len(common.ModelRatio))
for k := range common.ModelRatio {
keys = append(keys, k)
}
sort.Strings(keys)
for _, modelId := range keys {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: nil,
Permission: nil,
Root: nil,
Parent: nil,
})
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
modelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
}
}
func ListModels(c *gin.Context) {
groupName := c.GetString("group")
if groupName == "" {
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil {
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
return
}
groupName = user.Group
}
models, err := model.CacheGetGroupModels(groupName)
if err != nil {
@@ -83,13 +81,18 @@ func ListModels(c *gin.Context) {
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: nil,
OwnedBy: getModelOwnedBy(modelId),
Permission: nil,
Root: nil,
Parent: nil,
})
}
// 根据 OwnedBy 排序
sort.Slice(groupOpenAIModels, func(i, j int) bool {
return *groupOpenAIModels[i].OwnedBy < *groupOpenAIModels[j].OwnedBy
})
c.JSON(200, gin.H{
"object": "list",
"data": groupOpenAIModels,
@@ -97,6 +100,23 @@ func ListModels(c *gin.Context) {
}
func ListModelsForAdmin(c *gin.Context) {
openAIModels := make([]OpenAIModels, 0, len(common.ModelTypes))
for modelId := range common.ModelRatio {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: getModelOwnedBy(modelId),
Permission: nil,
Root: nil,
Parent: nil,
})
}
// 根据 OwnedBy 排序
sort.Slice(openAIModels, func(i, j int) bool {
return *openAIModels[i].OwnedBy < *openAIModels[j].OwnedBy
})
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
@@ -105,8 +125,17 @@ func ListModelsForAdmin(c *gin.Context) {
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
ownedByName := getModelOwnedBy(modelId)
if ownedByName != nil {
c.JSON(200, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: ownedByName,
Permission: nil,
Root: nil,
Parent: nil,
})
} else {
openAIError := types.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
@@ -119,3 +148,13 @@ func RetrieveModel(c *gin.Context) {
})
}
}
func getModelOwnedBy(modelId string) (ownedBy *string) {
if modelType, ok := common.ModelTypes[modelId]; ok {
if ownedByName, ok := modelOwnedBy[modelType.Type]; ok {
return &ownedByName
}
}
return
}