feat: Add support for retrieving model list from providers (#188)

*  feat: Add support for retrieving model list from providers

* 🔖 chore: Custom channel automatically get the model
This commit is contained in:
Buer
2024-05-16 15:21:13 +08:00
committed by GitHub
parent ef63fbfd31
commit 7263582b9b
20 changed files with 444 additions and 31 deletions

View File

@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/",
ModelList: "/models",
}
}

45
providers/gemini/model.go Normal file
View File

@@ -0,0 +1,45 @@
package gemini
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
)
func (p *GeminiProvider) GetModelList() ([]string, error) {
params := url.Values{}
params.Add("page_size", "1000")
queryString := params.Encode()
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
version := "v1beta"
fullRequestURL := fmt.Sprintf("%s/%s%s?%s", baseURL, version, p.Config.ModelList, queryString)
headers := p.GetRequestHeaders()
req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return nil, errors.New("new_request_failed")
}
response := &ModelListResponse{}
_, errWithCode := p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errors.New(errWithCode.Message)
}
var modelList []string
for _, model := range response.Models {
for _, modelType := range model.SupportedGenerationMethods {
if modelType == "generateContent" {
modelName := strings.TrimPrefix(model.Name, "models/")
modelList = append(modelList, modelName)
break
}
}
}
return modelList, nil
}

View File

@@ -218,3 +218,12 @@ func ConvertRole(roleName string) string {
return types.ChatMessageRoleUser
}
}
type ModelListResponse struct {
Models []ModelDetails `json:"models"`
}
type ModelDetails struct {
Name string `json:"name"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
}