mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 05:33:42 +08:00
✨ feat: Add support for retrieving model list from providers (#188)
* ✨ feat: Add support for retrieving model list from providers * 🔖 chore: Custom channel automatically get the model
This commit is contained in:
@@ -25,6 +25,7 @@ type ProviderConfig struct {
|
||||
ImagesGenerations string
|
||||
ImagesEdit string
|
||||
ImagesVariations string
|
||||
ModelList string
|
||||
}
|
||||
|
||||
type BaseProvider struct {
|
||||
|
||||
@@ -99,6 +99,16 @@ type ImageVariationsInterface interface {
|
||||
CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// type RelayInterface interface {
|
||||
// ProviderInterface
|
||||
// CreateRelay() (*http.Response, *types.OpenAIErrorWithStatusCode)
|
||||
// }
|
||||
|
||||
type ModelListInterface interface {
|
||||
ProviderInterface
|
||||
GetModelList() ([]string, error)
|
||||
}
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
Balance() (float64, error)
|
||||
|
||||
@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.cohere.ai/v1",
|
||||
ChatCompletions: "/chat",
|
||||
ModelList: "/models",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
35
providers/cohere/model.go
Normal file
35
providers/cohere/model.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (p *CohereProvider) GetModelList() ([]string, error) {
|
||||
params := url.Values{}
|
||||
params.Add("page_size", "1000")
|
||||
params.Add("endpoint", "chat")
|
||||
queryString := params.Encode()
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Config.ModelList) + "?" + queryString
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, errors.New("new_request_failed")
|
||||
}
|
||||
|
||||
response := &ModelListResponse{}
|
||||
_, errWithCode := p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errors.New(errWithCode.Message)
|
||||
}
|
||||
|
||||
var modelList []string
|
||||
for _, model := range response.Models {
|
||||
modelList = append(modelList, model.Name)
|
||||
}
|
||||
|
||||
return modelList, nil
|
||||
}
|
||||
@@ -15,28 +15,34 @@ type CohereConnector struct {
|
||||
}
|
||||
|
||||
type CohereRequest struct {
|
||||
Message string `json:"message"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Preamble string `json:"preamble,omitempty"`
|
||||
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
|
||||
ConversationId string `json:"conversation_id,omitempty"`
|
||||
PromptTruncation string `json:"prompt_truncation,omitempty"`
|
||||
Connectors []CohereConnector `json:"connectors,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxInputTokens int `json:"max_input_tokens,omitempty"`
|
||||
K int `json:"k,omitempty"`
|
||||
P float64 `json:"p,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
StopSequences any `json:"stop_sequences,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Tools []*types.ChatCompletionFunction `json:"tools,omitempty"`
|
||||
ToolResults any `json:"tool_results,omitempty"`
|
||||
// SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Preamble string `json:"preamble,omitempty"`
|
||||
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
|
||||
ConversationId string `json:"conversation_id,omitempty"`
|
||||
PromptTruncation string `json:"prompt_truncation,omitempty"`
|
||||
Connectors []CohereConnector `json:"connectors,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxInputTokens int `json:"max_input_tokens,omitempty"`
|
||||
K int `json:"k,omitempty"`
|
||||
P float64 `json:"p,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
StopSequences any `json:"stop_sequences,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Tools []*types.ChatCompletionFunction `json:"tools,omitempty"`
|
||||
ToolResults any `json:"tool_results,omitempty"`
|
||||
SearchQueriesOnly *bool `json:"search_queries_only,omitempty"`
|
||||
Documents []ChatDocument `json:"documents,omitempty"`
|
||||
CitationQuality *string `json:"citation_quality,omitempty"`
|
||||
RawPrompting *bool `json:"raw_prompting,omitempty"`
|
||||
ReturnPrompt *bool `json:"return_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ChatDocument is a free-form document exchanged with the chat endpoint;
// keys and values are provider-defined strings.
type ChatDocument = map[string]string

// APIVersion reports the Cohere API version attached to a response.
type APIVersion struct {
	Version string `json:"version"`
}
|
||||
@@ -60,16 +66,46 @@ type CohereToolCall struct {
|
||||
}
|
||||
|
||||
type CohereResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ResponseID string `json:"response_id,omitempty"`
|
||||
GenerationID string `json:"generation_id,omitempty"`
|
||||
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
|
||||
Meta Meta `json:"meta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ResponseID string `json:"response_id,omitempty"`
|
||||
Citations []*ChatCitation `json:"citations,omitempty"`
|
||||
Documents []ChatDocument `json:"documents,omitempty"`
|
||||
IsSearchRequired *bool `json:"is_search_required,omitempty"`
|
||||
SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty"`
|
||||
SearchResults []*ChatSearchResult `json:"search_results,omitempty"`
|
||||
GenerationID string `json:"generation_id,omitempty"`
|
||||
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
|
||||
Meta Meta `json:"meta,omitempty"`
|
||||
CohereError
|
||||
}
|
||||
|
||||
// ChatCitation marks a span of generated text (Start/End offsets into Text)
// that is backed by one or more retrieved documents.
type ChatCitation struct {
	Start       int      `json:"start"`
	End         int      `json:"end"`
	Text        string   `json:"text"`
	DocumentIds []string `json:"document_ids,omitempty"`
}

// ChatSearchQuery is a search query the model generated for a connector.
type ChatSearchQuery struct {
	Text         string `json:"text"`
	GenerationId string `json:"generation_id"`
}

// ChatSearchResult ties a generated search query to the connector that ran
// it and the documents it produced.
type ChatSearchResult struct {
	SearchQuery       *ChatSearchQuery           `json:"search_query,omitempty" url:"search_query,omitempty"`
	Connector         *ChatSearchResultConnector `json:"connector,omitempty" url:"connector,omitempty"`
	DocumentIds       []string                   `json:"document_ids,omitempty" url:"document_ids,omitempty"`
	ErrorMessage      *string                    `json:"error_message,omitempty" url:"error_message,omitempty"`
	ContinueOnFailure *bool                      `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"`
}

// ChatSearchResultConnector identifies the connector used for a search.
type ChatSearchResultConnector struct {
	Id string `json:"id" url:"id"`
}

// CohereError is the error payload embedded in Cohere API responses.
type CohereError struct {
	Message string `json:"message,omitempty"`
}
|
||||
@@ -83,3 +119,77 @@ type CohereStreamResponse struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// RerankRequest is the payload for Cohere's rerank endpoint.
type RerankRequest struct {
	Model           *string                       `json:"model,omitempty"`
	Query           string                        `json:"query" url:"query"`
	Documents       []*RerankRequestDocumentsItem `json:"documents,omitempty"`
	TopN            *int                          `json:"top_n,omitempty"`
	RankFields      []string                      `json:"rank_fields,omitempty"`
	ReturnDocuments *bool                         `json:"return_documents,omitempty"`
	MaxChunksPerDoc *int                          `json:"max_chunks_per_doc,omitempty"`
}

// RerankRequestDocumentsItem holds either a plain string document or a
// structured text document; presumably only one side is populated at a
// time (union-style) — TODO confirm against the marshalling code.
type RerankRequestDocumentsItem struct {
	String                         string
	RerankRequestDocumentsItemText *RerankDocumentsItemText
}

// RerankDocumentsItemText is the structured {"text": ...} document form.
type RerankDocumentsItemText struct {
	Text string `json:"text"`
}

// RerankResponse is the rerank endpoint response body.
type RerankResponse struct {
	Id      *string                      `json:"id,omitempty"`
	Results []*RerankResponseResultsItem `json:"results,omitempty"`
	Meta    *Meta                        `json:"meta,omitempty"`
}

// RerankResponseResultsItem scores one input document; Index refers to the
// document's position in the request.
type RerankResponseResultsItem struct {
	Document       *RerankDocumentsItemText `json:"document,omitempty"`
	Index          int                      `json:"index"`
	RelevanceScore float64                  `json:"relevance_score"`
}
|
||||
|
||||
// EmbedRequest is the payload for Cohere's embed endpoint.
// Texts is typed `any` so callers can pass either a single string or a
// list of strings.
type EmbedRequest struct {
	Texts          any      `json:"texts,omitempty"`
	Model          *string  `json:"model,omitempty"`
	InputType      *string  `json:"input_type,omitempty"`
	EmbeddingTypes []string `json:"embedding_types,omitempty"`
	Truncate       *string  `json:"truncate,omitempty"`
}

// EmbedResponse is the generic embed response envelope; ResponseType
// indicates which concrete layout Embeddings decodes into (see
// EmbedFloatsResponse and EmbedByTypeResponse).
type EmbedResponse struct {
	ResponseType string `json:"response_type"`
	Embeddings   any    `json:"embeddings"`
}

// EmbedFloatsResponse is the embed response when embeddings are returned
// as plain float vectors.
type EmbedFloatsResponse struct {
	Id         string      `json:"id"`
	Embeddings [][]float64 `json:"embeddings,omitempty"`
	Texts      []string    `json:"texts,omitempty"`
	Meta       *Meta       `json:"meta,omitempty"`
}

// EmbedByTypeResponse is the embed response when multiple embedding
// encodings are requested via EmbedRequest.EmbeddingTypes.
type EmbedByTypeResponse struct {
	Id         string                         `json:"id"`
	Embeddings *EmbedByTypeResponseEmbeddings `json:"embeddings,omitempty"`
	Texts      []string                       `json:"texts,omitempty"`
	Meta       *Meta                          `json:"meta,omitempty"`
}

// EmbedByTypeResponseEmbeddings groups embeddings by numeric encoding.
type EmbedByTypeResponseEmbeddings struct {
	Float   [][]float64 `json:"float,omitempty"`
	Int8    [][]int     `json:"int8,omitempty"`
	Uint8   [][]int     `json:"uint8,omitempty"`
	Binary  [][]int     `json:"binary,omitempty"`
	Ubinary [][]int     `json:"ubinary,omitempty"`
}
|
||||
|
||||
// ModelListResponse is the Cohere models-listing response body, consumed
// by GetModelList.
type ModelListResponse struct {
	Models []ModelDetails `json:"models"`
}

// ModelDetails describes one Cohere model; Endpoints lists the API
// endpoints (e.g. "chat") the model supports.
type ModelDetails struct {
	Name      string   `json:"name"`
	Endpoints []string `json:"endpoints"`
}
|
||||
|
||||
@@ -28,6 +28,7 @@ func getDeepseekConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
ModelList: "/v1/models",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/",
|
||||
ModelList: "/models",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
45
providers/gemini/model.go
Normal file
45
providers/gemini/model.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (p *GeminiProvider) GetModelList() ([]string, error) {
|
||||
params := url.Values{}
|
||||
params.Add("page_size", "1000")
|
||||
queryString := params.Encode()
|
||||
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
version := "v1beta"
|
||||
fullRequestURL := fmt.Sprintf("%s/%s%s?%s", baseURL, version, p.Config.ModelList, queryString)
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, errors.New("new_request_failed")
|
||||
}
|
||||
|
||||
response := &ModelListResponse{}
|
||||
_, errWithCode := p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errors.New(errWithCode.Message)
|
||||
}
|
||||
|
||||
var modelList []string
|
||||
for _, model := range response.Models {
|
||||
for _, modelType := range model.SupportedGenerationMethods {
|
||||
if modelType == "generateContent" {
|
||||
modelName := strings.TrimPrefix(model.Name, "models/")
|
||||
modelList = append(modelList, modelName)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modelList, nil
|
||||
}
|
||||
@@ -218,3 +218,12 @@ func ConvertRole(roleName string) string {
|
||||
return types.ChatMessageRoleUser
|
||||
}
|
||||
}
|
||||
|
||||
// ModelListResponse is the Gemini models-listing response body, consumed
// by GetModelList.
type ModelListResponse struct {
	Models []ModelDetails `json:"models"`
}

// ModelDetails describes one Gemini model. Name carries the "models/"
// prefix; SupportedGenerationMethods lists methods such as
// "generateContent".
type ModelDetails struct {
	Name                       string   `json:"name"`
	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.groq.com/openai",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
ModelList: "/v1/models",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ func getMistralConfig(baseURL string) base.ProviderConfig {
|
||||
BaseURL: baseURL,
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
ModelList: "/v1/models",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
29
providers/mistral/model.go
Normal file
29
providers/mistral/model.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (p *MistralProvider) GetModelList() ([]string, error) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.Config.ModelList, "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, errors.New("new_request_failed")
|
||||
}
|
||||
|
||||
response := &ModelListResponse{}
|
||||
_, errWithCode := p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errors.New(errWithCode.Message)
|
||||
}
|
||||
|
||||
var modelList []string
|
||||
for _, model := range response.Data {
|
||||
modelList = append(modelList, model.Id)
|
||||
}
|
||||
|
||||
return modelList, nil
|
||||
}
|
||||
@@ -53,3 +53,15 @@ type ChatCompletionStreamResponse struct {
|
||||
types.ChatCompletionStreamResponse
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ModelListResponse is the Mistral (OpenAI-compatible) /v1/models
// listing response body, consumed by GetModelList.
type ModelListResponse struct {
	Object string         `json:"object"`
	Data   []ModelDetails `json:"data"`
}

// ModelDetails describes one model entry in the listing.
type ModelDetails struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
|
||||
|
||||
@@ -57,6 +57,7 @@ func getOpenAIConfig(baseURL string) base.ProviderConfig {
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
ModelList: "/v1/models",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
29
providers/openai/model.go
Normal file
29
providers/openai/model.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) GetModelList() ([]string, error) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.Config.ModelList, "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, errors.New("new_request_failed")
|
||||
}
|
||||
|
||||
response := &ModelListResponse{}
|
||||
_, errWithCode := p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errors.New(errWithCode.Message)
|
||||
}
|
||||
|
||||
var modelList []string
|
||||
for _, model := range response.Data {
|
||||
modelList = append(modelList, model.Id)
|
||||
}
|
||||
|
||||
return modelList, nil
|
||||
}
|
||||
@@ -73,3 +73,15 @@ type OpenAIUsageResponse struct {
|
||||
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
|
||||
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
||||
}
|
||||
|
||||
// ModelListResponse is the OpenAI /v1/models listing response body,
// consumed by GetModelList.
type ModelListResponse struct {
	Object string         `json:"object"`
	Data   []ModelDetails `json:"data"`
}

// ModelDetails describes one model entry in the listing.
type ModelDetails struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
|
||||
|
||||
Reference in New Issue
Block a user