one-api/providers/mistral/base.go
2024-03-10 01:53:33 +08:00

78 lines
1.7 KiB
Go

package mistral
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"one-api/providers/base"
)
type (
	// MistralProviderFactory creates MistralProvider instances; see Create.
	MistralProviderFactory struct{}

	// MistralProvider is the provider implementation for Mistral. All shared
	// provider plumbing (config, channel, requester) lives in the embedded
	// base.BaseProvider.
	MistralProvider struct {
		base.BaseProvider
	}
)
// Create builds a MistralProvider for channel that targets the official
// Mistral endpoint, https://api.mistral.ai.
func (f MistralProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
	return CreateMistralProvider(channel, "https://api.mistral.ai")
}
// CreateMistralProvider builds a MistralProvider for channel that sends its
// requests to baseURL. The base URL is a parameter (Create passes the official
// Mistral endpoint), so callers can point the same provider at another host.
//
// Errors returned by the upstream API are decoded with RequestErrorHandle.
func CreateMistralProvider(channel *model.Channel, baseURL string) *MistralProvider {
	// channel.Proxy is a pointer; dereferencing it unconditionally panics when
	// no proxy was configured. Treat nil as an empty proxy string.
	// NOTE(review): assumes requester.NewHTTPRequester interprets "" as
	// "no proxy" — confirm against the requester package.
	proxy := ""
	if channel.Proxy != nil {
		proxy = *channel.Proxy
	}

	return &MistralProvider{
		BaseProvider: base.BaseProvider{
			Config:    getMistralConfig(baseURL),
			Channel:   channel,
			Requester: requester.NewHTTPRequester(proxy, RequestErrorHandle),
		},
	}
}
// getMistralConfig returns the provider config for a Mistral-compatible API
// rooted at baseURL. Only chat completions and embeddings routes are set;
// every other field keeps its zero value.
func getMistralConfig(baseURL string) base.ProviderConfig {
	var cfg base.ProviderConfig
	cfg.BaseURL = baseURL
	cfg.ChatCompletions = "/v1/chat/completions"
	cfg.Embeddings = "/v1/embeddings"
	return cfg
}
// RequestErrorHandle decodes a Mistral error body from resp and converts it
// into an OpenAI-style error via errorHandle.
//
// It returns nil when the body cannot be decoded as a MistralError, or when
// errorHandle does not recognise the payload as an error. The body is read
// here but not closed; closing it is left to the caller.
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
	var mistralErr MistralError
	if err := json.NewDecoder(resp.Body).Decode(&mistralErr); err != nil {
		return nil
	}
	return errorHandle(&mistralErr)
}
// errorHandle converts a decoded Mistral error payload into an OpenAI-style
// error.
//
// It returns nil when mistralErr is nil or its Object field is not "error"
// (i.e. the payload is not an error object). The message is taken from the
// first entry of Message.Detail. When Detail is empty, the returned error
// carries only the Type, because indexing an empty Detail would panic.
func errorHandle(mistralErr *MistralError) *types.OpenAIError {
	if mistralErr == nil || mistralErr.Object != "error" {
		return nil
	}

	openAIErr := &types.OpenAIError{Type: mistralErr.Type}
	if len(mistralErr.Message.Detail) > 0 {
		openAIErr.Message = mistralErr.Message.Detail[0].errorMsg()
	}
	return openAIErr
}
// GetRequestHeaders returns the headers for a Mistral API request.
//
// It contains the common headers written by CommonRequestHeaders plus a
// Bearer Authorization header built from the channel key. Authorization is
// written last, so it overrides any value CommonRequestHeaders may have set.
func (p *MistralProvider) GetRequestHeaders() (headers map[string]string) {
	headers = map[string]string{}
	p.CommonRequestHeaders(headers)

	bearer := fmt.Sprintf("Bearer %s", p.Channel.Key)
	headers["Authorization"] = bearer

	return headers
}