feat: support mistral (#94)

This commit is contained in:
Buer
2024-03-10 01:53:33 +08:00
committed by GitHub
parent d8d880bf85
commit 6329db1a49
10 changed files with 340 additions and 0 deletions

77
providers/mistral/base.go Normal file
View File

@@ -0,0 +1,77 @@
package mistral
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"one-api/providers/base"
)
// MistralProviderFactory creates MistralProvider instances for Mistral channels.
type MistralProviderFactory struct{}

// MistralProvider implements the provider interface for the Mistral API,
// delegating shared behavior to base.BaseProvider.
type MistralProvider struct {
	base.BaseProvider
}
// Create builds a MistralProvider bound to the official Mistral API endpoint.
func (f MistralProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
	// Return directly; the previous local variable shadowed the MistralProvider type name.
	return CreateMistralProvider(channel, "https://api.mistral.ai")
}
// CreateMistralProvider constructs a MistralProvider for the given channel,
// pointing at baseURL (lets callers override the default endpoint).
func CreateMistralProvider(channel *model.Channel, baseURL string) *MistralProvider {
	// NOTE(review): channel.Proxy is dereferenced unconditionally — presumably
	// guaranteed non-nil by the channel model; confirm, otherwise this panics.
	httpRequester := requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle)
	provider := &MistralProvider{
		BaseProvider: base.BaseProvider{
			Config:    getMistralConfig(baseURL),
			Channel:   channel,
			Requester: httpRequester,
		},
	}
	return provider
}
// getMistralConfig returns the endpoint configuration for the given base URL.
func getMistralConfig(baseURL string) base.ProviderConfig {
	config := base.ProviderConfig{
		BaseURL:         baseURL,
		ChatCompletions: "/v1/chat/completions",
		Embeddings:      "/v1/embeddings",
	}
	return config
}
// RequestErrorHandle decodes a Mistral error payload from resp and converts
// it to an OpenAI-style error; returns nil when the body cannot be decoded
// (letting the requester fall back to its default error handling).
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
	var mistralErr MistralError
	if err := json.NewDecoder(resp.Body).Decode(&mistralErr); err != nil {
		return nil
	}
	return errorHandle(&mistralErr)
}
// errorHandle converts a decoded MistralError into an OpenAI-style error.
// It returns nil when the payload is not an error object (Object != "error").
func errorHandle(mistralErr *MistralError) *types.OpenAIError {
	if mistralErr.Object != "error" {
		return nil
	}
	// Guard against an empty Detail slice: the previous code indexed
	// Detail[0] unconditionally and panicked on error payloads without
	// validation details. Fall back to the error type as the message.
	message := mistralErr.Type
	if len(mistralErr.Message.Detail) > 0 {
		message = mistralErr.Message.Detail[0].errorMsg()
	}
	return &types.OpenAIError{
		Message: message,
		Type:    mistralErr.Type,
	}
}
// GetRequestHeaders returns the HTTP headers for Mistral requests, adding a
// bearer token from the channel key on top of the common headers.
func (p *MistralProvider) GetRequestHeaders() (headers map[string]string) {
	headers = map[string]string{}
	p.CommonRequestHeaders(headers)
	// Set after the common headers so the channel key always wins.
	headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
	return headers
}

137
providers/mistral/chat.go Normal file
View File

@@ -0,0 +1,137 @@
package mistral
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
// mistralStreamHandler carries per-stream state while converting Mistral SSE
// chunks into OpenAI-style stream payloads.
type mistralStreamHandler struct {
	Usage   *types.Usage                 // token usage accumulator, shared with the provider
	Request *types.ChatCompletionRequest // the originating request
}
// CreateChatCompletion performs a non-streaming chat completion request
// against the Mistral API and returns the parsed response.
func (p *MistralProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
	req, errWithCode := p.getChatRequest(request)
	if errWithCode != nil {
		return nil, errWithCode
	}
	defer req.Body.Close()

	response := &types.ChatCompletionResponse{}
	// Send the request and decode the JSON body into response.
	_, errWithCode = p.Requester.SendRequest(req, response, false)
	if errWithCode != nil {
		return nil, errWithCode
	}

	// Guard against a missing usage object: the previous code dereferenced
	// response.Usage unconditionally and panicked when the upstream omitted it.
	if response.Usage != nil {
		*p.Usage = *response.Usage
	}
	return response, nil
}
// CreateChatCompletionStream starts a streaming chat completion request and
// returns a stream reader yielding OpenAI-style SSE payloads.
func (p *MistralProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
	req, errWithCode := p.getChatRequest(request)
	if errWithCode != nil {
		return nil, errWithCode
	}
	defer req.Body.Close()

	// Issue the request without decoding; the stream handler consumes the raw body.
	resp, errWithCode := p.Requester.SendRequestRaw(req)
	if errWithCode != nil {
		return nil, errWithCode
	}

	handler := &mistralStreamHandler{
		Usage:   p.Usage,
		Request: request,
	}
	return requester.RequestStream[string](p.Requester, resp, handler.handlerStream)
}
// getChatRequest builds the HTTP request for a chat completion call: it
// resolves the endpoint, converts the payload to the Mistral shape and
// attaches the auth headers.
func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
	if errWithCode != nil {
		return nil, errWithCode
	}

	// Resolve the full request URL.
	fullRequestURL := p.GetFullRequestURL(url, request.Model)
	// Consistency with CreateEmbeddings: reject an unresolvable URL instead
	// of issuing a request against an empty address.
	if fullRequestURL == "" {
		return nil, common.ErrorWrapper(nil, "invalid_mistral_config", http.StatusInternalServerError)
	}

	headers := p.GetRequestHeaders()
	mistralRequest := convertFromChatOpenai(request)

	req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(mistralRequest), p.Requester.WithHeader(headers))
	if err != nil {
		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}
	return req, nil
}
// convertFromChatOpenai maps an OpenAI chat completion request onto the
// Mistral request shape, flattening each message's content to a plain string.
func convertFromChatOpenai(request *types.ChatCompletionRequest) *MistralChatCompletionRequest {
	messages := make([]types.ChatCompletionMessage, 0, len(request.Messages))
	for _, msg := range request.Messages {
		messages = append(messages, types.ChatCompletionMessage{
			Role:    msg.Role,
			Content: msg.StringContent(),
		})
	}

	out := &MistralChatCompletionRequest{
		Model:       request.Model,
		Messages:    messages,
		Temperature: request.Temperature,
		MaxTokens:   request.MaxTokens,
		TopP:        request.TopP,
		N:           request.N,
		Stream:      request.Stream,
		Seed:        request.Seed,
	}
	if request.Tools != nil {
		// Tool choice is forced to "auto" whenever tools are supplied.
		out.Tools = request.Tools
		out.ToolChoice = "auto"
	}
	return out
}
// handlerStream converts one raw SSE line from Mistral into an OpenAI-style
// stream payload, forwarding it on dataChan (or io.EOF / an error on errChan).
func (h *mistralStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
	// Only SSE data lines are relevant; everything else is dropped.
	if !strings.HasPrefix(string(*rawLine), "data: ") {
		*rawLine = nil
		return
	}
	// Strip the "data: " prefix.
	*rawLine = (*rawLine)[6:]

	if string(*rawLine) == "[DONE]" {
		errChan <- io.EOF
		*rawLine = requester.StreamClosed
		return
	}

	chunk := &ChatCompletionStreamResponse{}
	if err := json.Unmarshal(*rawLine, chunk); err != nil {
		errChan <- common.ErrorToOpenAIError(err)
		return
	}

	// Copy usage when the chunk carries it (shared with the provider).
	if chunk.Usage != nil {
		*h.Usage = *chunk.Usage
	}

	payload, _ := json.Marshal(chunk.ChatCompletionStreamResponse)
	dataChan <- string(payload)
}

View File

@@ -0,0 +1,39 @@
package mistral
import (
"net/http"
"one-api/common"
"one-api/types"
)
// CreateEmbeddings sends an embeddings request to the Mistral API and
// returns the decoded response.
func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
	url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
	if errWithCode != nil {
		return nil, errWithCode
	}

	// Resolve the full request URL; an empty result means misconfiguration.
	fullRequestURL := p.GetFullRequestURL(url, request.Model)
	if fullRequestURL == "" {
		return nil, common.ErrorWrapper(nil, "invalid_mistral_config", http.StatusInternalServerError)
	}

	headers := p.GetRequestHeaders()
	req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers))
	if err != nil {
		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}
	defer req.Body.Close()

	response := &types.EmbeddingResponse{}
	if _, errWithCode = p.Requester.SendRequest(req, response, false); errWithCode != nil {
		return nil, errWithCode
	}
	return response, nil
}

55
providers/mistral/type.go Normal file
View File

@@ -0,0 +1,55 @@
package mistral
import (
"encoding/json"
"one-api/types"
)
// MistralChatCompletionRequest is the request body sent to the Mistral
// chat completions endpoint; field names mirror the Mistral API schema.
type MistralChatCompletionRequest struct {
	Model       string                        `json:"model" binding:"required"`
	Messages    []types.ChatCompletionMessage `json:"messages" binding:"required"`
	Temperature float64                       `json:"temperature,omitempty"` // 0-1
	MaxTokens   int                           `json:"max_tokens,omitempty"`
	TopP        float64                       `json:"top_p,omitempty"` // 0-1
	N           int                           `json:"n,omitempty"`
	Stream      bool                          `json:"stream,omitempty"`
	Tools       []*types.ChatCompletionTool   `json:"tools,omitempty"`
	ToolChoice  string                        `json:"tool_choice,omitempty"` // set to "auto" when Tools is non-nil (see convertFromChatOpenai)
	Seed        *int                          `json:"seed,omitempty"`
	SafePrompt  bool                          `json:"safe_prompt,omitempty"`
}
// MistralError is the top-level error payload returned by the Mistral API;
// an Object value of "error" marks it as an actual error (see errorHandle).
type MistralError struct {
	Object  string               `json:"object"`
	Type    string               `json:"type,omitempty"`
	Message MistralErrorMessages `json:"message,omitempty"`
}

// MistralErrorMessages wraps the list of validation error details.
type MistralErrorMessages struct {
	Detail []MistralErrorDetail `json:"detail,omitempty"`
}
// MistralErrorDetail is a single entry of a Mistral validation error payload.
type MistralErrorDetail struct {
	Type  string `json:"type"`
	Loc   any    `json:"loc"`
	Msg   string `json:"msg"`
	Input string `json:"input"`
	Ctx   any    `json:"ctx"`
}

// errorMsg renders the detail as "Loc:<loc-as-json>Msg:<msg>".
func (m *MistralErrorDetail) errorMsg() string {
	// Best effort: if Loc cannot be marshalled, the location part is empty.
	loc, _ := json.Marshal(m.Loc)
	return "Loc:" + string(loc) + "Msg:" + m.Msg
}
// ChatCompletionStreamResponse extends the OpenAI stream response with an
// optional usage field that may appear in Mistral stream chunks (consumed by
// mistralStreamHandler to update the accumulated usage).
type ChatCompletionStreamResponse struct {
	types.ChatCompletionStreamResponse
	Usage *types.Usage `json:"usage,omitempty"`
}

View File

@@ -18,6 +18,7 @@ import (
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/minimax"
"one-api/providers/mistral"
"one-api/providers/openai"
"one-api/providers/openaisb"
"one-api/providers/palm"
@@ -58,6 +59,7 @@ func init() {
providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{}
providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{}
providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
}