♻️ refactor: split relay

Martial BE
2023-11-28 18:32:26 +08:00
parent 53da7134b2
commit 902c2faa2c
58 changed files with 4248 additions and 3369 deletions

50
providers/ali_base.go Normal file

@@ -0,0 +1,50 @@
package providers
import (
"fmt"
"github.com/gin-gonic/gin"
)
type AliAIProvider struct {
ProviderConfig
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
// Create an AliAIProvider
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
func CreateAliAIProvider(c *gin.Context) *AliAIProvider {
return &AliAIProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
Context: c,
},
}
}
// Get request headers
func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}

256
providers/ali_chat.go Normal file

@@ -0,0 +1,256 @@
package providers
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
)
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: aliResponse.Output.Text,
},
FinishReason: aliResponse.Output.FinishReason,
}
fullTextResponse := types.ChatCompletionResponse{
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
},
}
return fullTextResponse, nil
}
func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: request.Messages[i+1].StringContent(),
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
}
}
func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
headers["X-DashScope-SSE"] = "enable"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
return
}
if usage == nil {
usage = &types.Usage{
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
}
} else {
aliResponse := &AliChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, aliResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
}
}
return
}
func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := types.ChatCompletionStreamResponse{
ID: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}
func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
usage = &types.Usage{}
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), nil
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
lastResponseText := ""
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := p.streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, usage
}
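
For illustration, here is a minimal standalone sketch of the message-pairing rule that getChatRequestBody applies above: a system message becomes a user/"Okay" history pair, each user/assistant pair becomes one history entry, and the final message becomes the prompt. The sample messages and type names below are made up for the sketch.

package main

import "fmt"

type msg struct{ Role, Content string }

type aliMessage struct{ User, Bot string }

// pairMessages mirrors getChatRequestBody: a system message becomes a
// user/"Okay" pair, user/assistant turns are paired into history, and the
// last message is kept as the prompt.
func pairMessages(msgs []msg) (history []aliMessage, prompt string) {
    for i := 0; i < len(msgs); i++ {
        m := msgs[i]
        if m.Role == "system" {
            history = append(history, aliMessage{User: m.Content, Bot: "Okay"})
            continue
        }
        if i == len(msgs)-1 {
            prompt = m.Content
            break
        }
        history = append(history, aliMessage{User: m.Content, Bot: msgs[i+1].Content})
        i++ // skip the assistant reply consumed as Bot
    }
    return history, prompt
}

func main() {
    history, prompt := pairMessages([]msg{
        {Role: "system", Content: "You are a helpful assistant."},
        {Role: "user", Content: "Hi"},
        {Role: "assistant", Content: "Hello!"},
        {Role: "user", Content: "What is Qwen?"},
    })
    fmt.Println(history) // [{You are a helpful assistant. Okay} {Hi Hello!}]
    fmt.Println(prompt)  // What is Qwen?
}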


@@ -0,0 +1,94 @@
package providers
import (
"net/http"
"one-api/common"
"one-api/types"
)
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if aliResponse.Code != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}
}
openAIEmbeddingResponse := &types.EmbeddingResponse{
Object: "list",
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
}
for _, item := range aliResponse.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return openAIEmbeddingResponse, nil
}
func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
aliEmbeddingResponse := &AliEmbeddingResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
return usage, nil
}

14
providers/api2d_base.go Normal file

@@ -0,0 +1,14 @@
package providers
import "github.com/gin-gonic/gin"
type Api2dProvider struct {
*OpenAIProvider
}
// Create an Api2dProvider
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
return &Api2dProvider{
OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"),
}
}

41
providers/azure_base.go Normal file

@@ -0,0 +1,41 @@
package providers
import (
"github.com/gin-gonic/gin"
)
type AzureProvider struct {
OpenAIProvider
}
// Create an AzureProvider
func CreateAzureProvider(c *gin.Context) *AzureProvider {
return &AzureProvider{
OpenAIProvider: OpenAIProvider{
ProviderConfig: ProviderConfig{
BaseURL: "",
Completions: "/completions",
ChatCompletions: "/chat/completions",
Embeddings: "/embeddings",
AudioSpeech: "/audio/speech",
AudioTranscriptions: "/audio/transcriptions",
AudioTranslations: "/audio/translations",
Context: c,
},
isAzure: true,
},
}
}
// // Get the full request URL
// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string {
// apiVersion := p.Context.GetString("api_version")
// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion)
// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
// }
// return fmt.Sprintf("%s%s", baseURL, requestURL)
// }

136
providers/baidu_base.go Normal file

@@ -0,0 +1,136 @@
package providers
import (
"encoding/json"
"errors"
"fmt"
"one-api/common"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
var baiduTokenStore sync.Map
type BaiduProvider struct {
ProviderConfig
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
return &BaiduProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://aip.baidubce.com",
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
Context: c,
},
}
}
// Get the full request URL
func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string {
var modelNameMap = map[string]string{
"ERNIE-Bot": "completions",
"ERNIE-Bot-turbo": "eb-instant",
"ERNIE-Bot-4": "completions_pro",
"BLOOMZ-7B": "bloomz_7b1",
"Embedding-V1": "embedding-v1",
}
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
apiKey, err := p.getBaiduAccessToken()
if err != nil {
return ""
}
return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey)
}
// Get request headers
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
apiKey := p.Context.GetString("api_key")
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken
if accessToken, ok = val.(BaiduAccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = p.getBaiduAccessTokenHelper(apiKey)
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := p.getBaiduAccessTokenHelper(apiKey)
if err != nil {
return "", err
}
if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
client := common.NewClient()
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
var headers = map[string]string{
"Content-Type": "application/json",
"Accept": "application/json",
}
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
if err != nil {
return nil, err
}
resp, err := common.HttpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var accessToken BaiduAccessToken
err = json.NewDecoder(resp.Body).Decode(&accessToken)
if err != nil {
return nil, err
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}
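
As a minimal sketch of the credential handling above (values are placeholders): the channel api_key is expected in the "client_id|client_secret" form, and getBaiduAccessTokenHelper splits it and exchanges it for an access token at the OAuth endpoint, caching the result in baiduTokenStore.

package main

import (
    "fmt"
    "strings"
)

func main() {
    // Placeholder channel key in the "client_id|client_secret" form expected
    // by getBaiduAccessTokenHelper.
    apiKey := "example-client-id|example-client-secret"

    parts := strings.Split(apiKey, "|")
    if len(parts) != 2 {
        panic("invalid baidu apikey")
    }

    // The OAuth endpoint the provider POSTs to before any chat or embeddings call;
    // the returned access_token is appended to request URLs as a query parameter.
    url := fmt.Sprintf(
        "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
        parts[0], parts[1],
    )
    fmt.Println(url)
}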

228
providers/baidu_chat.go Normal file

@@ -0,0 +1,228 @@
package providers
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
)
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"`
BaiduError
}
func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if baiduResponse.ErrorMsg != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: baiduResponse.Result,
},
FinishReason: "stop",
}
fullTextResponse := types.ChatCompletionResponse{
ID: baiduResponse.Id,
Object: "chat.completion",
Created: baiduResponse.Created,
Choices: []types.ChatCompletionChoice{choice},
Usage: baiduResponse.Usage,
}
return fullTextResponse, nil
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
}
}
func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
if fullRequestURL == "" {
return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
}
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
return
}
} else {
baiduChatRequest := &BaiduChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest)
if openAIErrorWithStatusCode != nil {
return
}
usage = baiduChatRequest.Usage
}
return
}
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason
}
response := types.ChatCompletionStreamResponse{
ID: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
usage = &types.Usage{}
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), nil
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, usage
}


@@ -0,0 +1,88 @@
package providers
import (
"net/http"
"one-api/common"
"one-api/types"
)
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage types.Usage `json:"usage"`
BaiduError
}
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
}
func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if baiduResponse.ErrorMsg != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}
}
openAIEmbeddingResponse := &types.EmbeddingResponse{
Object: "list",
Data: make([]types.Embedding, 0, len(baiduResponse.Data)),
Model: "text-embedding-v1",
Usage: &baiduResponse.Usage,
}
for _, item := range baiduResponse.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return openAIEmbeddingResponse, nil
}
func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getEmbeddingsRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
if fullRequestURL == "" {
return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
}
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = &baiduEmbeddingResponse.Usage
return usage, nil
}

150
providers/base.go Normal file

@@ -0,0 +1,150 @@
package providers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
var stopFinishReason = "stop"
type ProviderConfig struct {
BaseURL string
Completions string
ChatCompletions string
Embeddings string
AudioSpeech string
AudioTranscriptions string
AudioTranslations string
Proxy string
Context *gin.Context
}
type BaseProviderAction interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
}
type CompletionProviderAction interface {
BaseProviderAction
CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type ChatProviderAction interface {
BaseProviderAction
ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type EmbeddingsProviderAction interface {
BaseProviderAction
EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type BalanceProviderAction interface {
Balance(channel *model.Channel) (float64, error)
}
func (p *ProviderConfig) GetBaseURL() string {
if p.Context.GetString("base_url") != "" {
return p.Context.GetString("base_url")
}
return p.BaseURL
}
func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s", baseURL, requestURL)
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var errorResponse types.OpenAIErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Type != "" {
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody)
}
return
}
// Provider response handler
type ProviderResponseHandler interface {
// Request handler
requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
// Send the request
func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
// Handle the response
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp)
}
// Decode the response
err = common.DecodeResponse(resp.Body, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
return nil
}
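
A minimal sketch of the contract that sendRequest relies on: any provider response type implements requestHandler to convert its own payload (or provider-specific error) into an OpenAI-shaped result. The toyResponse type and simplified error struct below are hypothetical stand-ins, not part of the codebase.

package main

import (
    "encoding/json"
    "fmt"
    "net/http"
)

// apiError is a simplified stand-in for types.OpenAIErrorWithStatusCode.
type apiError struct {
    Message string
    Status  int
}

// responseHandler mirrors the ProviderResponseHandler interface: a provider
// response type knows how to turn itself into an OpenAI-shaped response or error.
type responseHandler interface {
    requestHandler(resp *http.Response) (openAIResponse any, apiErr *apiError)
}

// toyResponse is a hypothetical upstream payload with a provider-specific error field.
type toyResponse struct {
    Code    string `json:"code"`
    Message string `json:"message"`
    Text    string `json:"text"`
}

var _ responseHandler = (*toyResponse)(nil)

func (r *toyResponse) requestHandler(resp *http.Response) (any, *apiError) {
    if r.Code != "" {
        return nil, &apiError{Message: r.Message, Status: resp.StatusCode}
    }
    // Map the provider payload into an OpenAI-shaped map for the demo.
    return map[string]any{"object": "chat.completion", "content": r.Text}, nil
}

func main() {
    var r toyResponse
    _ = json.Unmarshal([]byte(`{"text":"hello"}`), &r)
    out, apiErr := r.requestHandler(&http.Response{StatusCode: http.StatusOK})
    fmt.Println(out, apiErr)
}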

55
providers/claude_base.go Normal file

@@ -0,0 +1,55 @@
package providers
import (
"github.com/gin-gonic/gin"
)
type ClaudeProvider struct {
ProviderConfig
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
return &ClaudeProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://api.anthropic.com",
ChatCompletions: "/v1/complete",
Context: c,
},
}
}
// Get request headers
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["x-api-key"] = p.Context.GetString("api_key")
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
headers["anthropic-version"] = anthropicVersion
return headers
}
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}

232
providers/claude_chat.go Normal file

@@ -0,0 +1,232 @@
package providers
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
}
func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if claudeResponse.Error.Type != "" {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
}
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
claudeResponse.Usage.CompletionTokens = completionTokens
claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens
fullTextResponse.Usage = claudeResponse.Usage
return fullTextResponse, nil
}
func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) {
claudeRequest := ClaudeRequest{
Model: request.Model,
Prompt: "",
MaxTokensToSample: request.MaxTokens,
StopSequences: nil,
Temperature: request.Temperature,
TopP: request.TopP,
Stream: request.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range request.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
var responseText string
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(responseText, request.Model),
}
usage.TotalTokens = promptTokens + usage.CompletionTokens
} else {
var claudeResponse = &ClaudeResponse{
Usage: &types.Usage{
PromptTokens: promptTokens,
},
}
openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = claudeResponse.Usage
}
return
}
func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response types.ChatCompletionStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []types.ChatCompletionStreamChoice{choice}
return &response
}
func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
}
defer resp.Body.Close()
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := p.streamResponseClaude2OpenAI(&claudeResponse)
response.ID = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, responseText
}
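
A minimal sketch of the prompt flattening performed by getChatRequestBody above: OpenAI-style messages are concatenated into Anthropic's "\n\nHuman:" / "\n\nAssistant:" text format, ending with an open "\n\nAssistant:" turn. The sample messages are made up.

package main

import (
    "fmt"
    "strings"
)

type msg struct{ Role, Content string }

// buildClaudePrompt mirrors getChatRequestBody: messages are flattened into
// Anthropic's text-completion prompt format and left open on an Assistant turn.
func buildClaudePrompt(msgs []msg) string {
    var b strings.Builder
    for _, m := range msgs {
        switch m.Role {
        case "user":
            b.WriteString(fmt.Sprintf("\n\nHuman: %s", m.Content))
        case "assistant":
            b.WriteString(fmt.Sprintf("\n\nAssistant: %s", m.Content))
        case "system":
            b.WriteString(fmt.Sprintf("\n\nSystem: %s", m.Content))
        }
    }
    b.WriteString("\n\nAssistant:")
    return b.String()
}

func main() {
    fmt.Print(buildClaudePrompt([]msg{
        {Role: "system", Content: "Be brief."},
        {Role: "user", Content: "Hello"},
    }))
    // Output:
    //
    // System: Be brief.
    //
    // Human: Hello
    //
    // Assistant:
}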


@@ -0,0 +1,50 @@
package providers
import (
"fmt"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
type CloseaiProxyProvider struct {
*OpenAIProvider
}
type OpenAICreditGrants struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalAvailable float64 `json:"total_available"`
}
// Create a CloseaiProxyProvider
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
return &CloseaiProxyProvider{
OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
}
}
func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response OpenAICreditGrants
err = client.SendRequest(req, &response)
if err != nil {
return 0, err
}
channel.UpdateBalance(response.TotalAvailable)
return response.TotalAvailable, nil
}

215
providers/openai_base.go Normal file

@@ -0,0 +1,215 @@
package providers
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type OpenAIProvider struct {
ProviderConfig
isAzure bool
}
type OpenAIProviderResponseHandler interface {
// Request handler
requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
}
type OpenAIProviderStreamResponseHandler interface {
// Stream request handler
requestStreamHandler() (responseText string)
}
// Create an OpenAIProvider
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
if baseURL == "" {
baseURL = "https://api.openai.com"
}
return &OpenAIProvider{
ProviderConfig: ProviderConfig{
BaseURL: baseURL,
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",
Context: c,
},
isAzure: false,
}
}
// Get the full request URL
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
if p.isAzure {
apiVersion := p.Context.GetString("api_version")
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
}
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
if p.isAzure {
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
} else {
requestURL = strings.TrimPrefix(requestURL, "/v1")
}
}
return fmt.Sprintf("%s%s", baseURL, requestURL)
}
// Get request headers
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
if p.isAzure {
headers["api-key"] = p.Context.GetString("api_key")
} else {
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
}
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json; charset=utf-8"
}
return headers
}
// Get the request body
func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
if isModelMapped {
jsonStr, err := json.Marshal(request)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = p.Context.Request.Body
}
return
}
// Send the request
func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
// Handle the response
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp)
}
// Create a bytes.Buffer to store the response body
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
// Decode the response
err = common.DecodeResponse(tee, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIErrorWithStatusCode = response.requestHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(p.Context.Writer, &buf)
if err != nil {
return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
return nil
}
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
err := json.Unmarshal([]byte(data), response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
responseText += response.requestStreamHandler()
}
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
p.Context.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
return nil, responseText
}
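
A minimal sketch of the URL shapes GetFullRequestURL produces, mirroring the logic above: Azure requests are rewritten to the deployments path with an api-version query, and Cloudflare AI Gateway base URLs drop the provider-specific prefix. The base URLs, deployment name, and api-version below are example values.

package main

import (
    "fmt"
    "strings"
)

// fullRequestURL mirrors OpenAIProvider.GetFullRequestURL: Azure requests are
// rewritten to the deployments path with an api-version query, and Cloudflare
// AI Gateway base URLs drop the provider-specific path prefix.
func fullRequestURL(baseURL, requestURL, modelName, apiVersion string, isAzure bool) string {
    baseURL = strings.TrimSuffix(baseURL, "/")
    if isAzure {
        requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
    }
    if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
        if isAzure {
            requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
        } else {
            requestURL = strings.TrimPrefix(requestURL, "/v1")
        }
    }
    return baseURL + requestURL
}

func main() {
    // Plain OpenAI: https://api.openai.com/v1/chat/completions
    fmt.Println(fullRequestURL("https://api.openai.com", "/v1/chat/completions", "gpt-3.5-turbo", "", false))
    // Azure (example resource and api-version):
    // https://example.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-05-15
    fmt.Println(fullRequestURL("https://example.openai.azure.com", "/chat/completions", "gpt-35-turbo", "2023-05-15", true))
}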

92
providers/openai_chat.go Normal file

@@ -0,0 +1,92 @@
package providers
import (
"net/http"
"one-api/common"
"one-api/types"
)
type OpenAIProviderChatResponse struct {
types.ChatCompletionResponse
types.OpenAIErrorResponse
}
type OpenAIProviderChatStreamResponse struct {
types.ChatCompletionStreamResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil
}
func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Delta.Content
}
return
}
func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream && headers["Accept"] == "" {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
var textResponse string
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(textResponse, request.Model),
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
}
} else {
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = openAIProviderChatResponse.Usage
if usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range openAIProviderChatResponse.Choices {
completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
}
return
}


@@ -0,0 +1,87 @@
package providers
import (
"net/http"
"one-api/common"
"one-api/types"
)
type OpenAIProviderCompletionResponse struct {
types.CompletionResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil
}
func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Text
}
return
}
func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream && headers["Accept"] == "" {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
if request.Stream {
// TODO
var textResponse string
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(textResponse, request.Model),
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
}
} else {
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = openAIProviderCompletionResponse.Usage
if usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range openAIProviderCompletionResponse.Choices {
completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
}
return
}


@@ -0,0 +1,50 @@
package providers
import (
"net/http"
"one-api/common"
"one-api/types"
)
type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil
}
func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.getRequestBody(&request, isModelMapped)
if err != nil {
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = openAIProviderEmbeddingsResponse.Usage
return
}


@@ -0,0 +1,58 @@
package providers
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
type OpenaiSBProvider struct {
*OpenAIProvider
}
type OpenAISBUsageResponse struct {
Msg string `json:"msg"`
Data *struct {
Credit string `json:"credit"`
} `json:"data"`
}
// Create an OpenaiSBProvider
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
return &OpenaiSBProvider{
OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"),
}
}
func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
if err != nil {
return 0, err
}
// Send the request
var response OpenAISBUsageResponse
err = client.SendRequest(req, &response)
if err != nil {
return 0, err
}
if response.Data == nil {
return 0, errors.New(response.Msg)
}
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}

43
providers/palm_base.go Normal file

@@ -0,0 +1,43 @@
package providers
import (
"fmt"
"strings"
"github.com/gin-gonic/gin"
)
type PalmProvider struct {
ProviderConfig
}
// Create a PalmProvider
func CreatePalmProvider(c *gin.Context) *PalmProvider {
return &PalmProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
Context: c,
},
}
}
// Get request headers
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}
// Get the full request URL
func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key"))
}

232
providers/palm_chat.go Normal file

@@ -0,0 +1,232 @@
package providers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
)
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []types.ChatCompletionMessage `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
Usage *types.Usage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
}
func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
Code: palmResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}
}
fullTextResponse := types.ChatCompletionResponse{
Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)),
}
for i, candidate := range palmResponse.Candidates {
choice := types.ChatCompletionChoice{
Index: i,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: candidate.Content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
palmResponse.Usage.CompletionTokens = completionTokens
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
fullTextResponse.Usage = palmResponse.Usage
return fullTextResponse, nil
}
func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
},
Temperature: request.Temperature,
CandidateCount: request.N,
TopP: request.TopP,
TopK: request.MaxTokens,
}
for _, message := range request.Messages {
palmMessage := PaLMChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
} else {
palmMessage.Author = "1"
}
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
}
return &palmRequest
}
func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
var responseText string
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(responseText, request.Model),
}
usage.TotalTokens = promptTokens + usage.CompletionTokens
} else {
var palmChatResponse = &PaLMChatResponse{
Model: request.Model,
Usage: &types.Usage{
PromptTokens: promptTokens,
},
}
openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = palmChatResponse.Usage
}
return
}
func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &stopFinishReason
var response types.ChatCompletionStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []types.ChatCompletionStreamChoice{choice}
return &response
}
func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
}
defer resp.Body.Close()
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
common.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.ID = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, responseText
}

94
providers/tencent_base.go Normal file
View File

@@ -0,0 +1,94 @@
package providers
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"sort"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type TencentProvider struct {
ProviderConfig
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
// Create a TencentProvider
func CreateTencentProvider(c *gin.Context) *TencentProvider {
return &TencentProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://hunyuan.cloud.tencent.com",
ChatCompletions: "/hyllm/v1/chat/completions",
Context: c,
},
}
}
// Get the request headers
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}
func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
apiKey := p.Context.GetString("api_key")
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
if err != nil {
return ""
}
req.AppId = appId
req.SecretId = secretId
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}
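// exampleTencentSign is an illustrative, uncalled sketch of the signing scheme implemented
// in getTencentSign above. All values are placeholders; the real parameter list is built
// from the incoming request. The scheme sorts the query parameters, appends them to the
// endpoint host/path, and returns the base64-encoded HMAC-SHA1 of that string.
func exampleTencentSign(secretKey string) string {
	params := []string{
		"app_id=1234567",
		"secret_id=example-secret-id",
		"timestamp=1700000000",
		"expired=1700086400",
		"query_id=example-query-id",
		"temperature=1",
		"top_p=1",
		"stream=0",
		`messages=[{"role":"user","content":"hi"}]`,
	}
	sort.Strings(params)
	payload := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
	mac := hmac.New(sha1.New, []byte(secretKey))
	mac.Write([]byte(payload))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}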

265
providers/tencent_chat.go Normal file
View File

@@ -0,0 +1,265 @@
package providers
import (
"bufio"
"encoding/json"
"errors"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
)
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // Tencent Cloud account APPID
SecretId string `json:"secret_id"` // SecretId from the Tencent Cloud console
// Timestamp is the current UNIX timestamp in seconds, recording when the API request was made.
// For example 1529223702; if it differs too much from the current time, a signature-expired error occurs.
Timestamp int64 `json:"timestamp"`
// Expired is the signature expiry, a UNIX-epoch timestamp in seconds.
// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days.
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` // Request id, used for troubleshooting
// Temperature: higher values make the output more random, lower values make it more focused and deterministic.
// Default 1.0, range [0.0, 2.0]; not recommended unless necessary, and unreasonable values degrade results.
// Set only one of this parameter and top_p; do not change both at the same time.
Temperature float64 `json:"temperature"`
// TopP controls output diversity: the larger the value, the more diverse the generated text.
// Default 1.0, range [0.0, 1.0]; not recommended unless necessary, and unreasonable values degrade results.
// Set only one of this parameter and temperature; do not change both at the same time.
TopP float64 `json:"top_p"`
// Stream: 0 for synchronous, 1 for streaming (SSE protocol by default).
// Synchronous requests time out after 60s; streaming is recommended for long content.
Stream int `json:"stream"`
// Messages holds the conversation, at most 40 entries, ordered from oldest to newest.
// The total input content supports at most 3000 tokens.
Messages []TencentMessage `json:"messages"`
}
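// Illustrative request body produced from the struct above (placeholder values):
//
//	{
//	  "app_id": 1234567,
//	  "secret_id": "example-secret-id",
//	  "timestamp": 1700000000,
//	  "expired": 1700086400,
//	  "query_id": "example-query-id",
//	  "temperature": 1.0,
//	  "top_p": 1.0,
//	  "stream": 0,
//	  "messages": [{"role": "user", "content": "hi"}]
//	}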
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // Streaming end flag; "stop" marks the final packet
Messages TencentMessage `json:"messages,omitempty"` // Content returned in synchronous mode, null in streaming mode; the output content supports at most 1024 tokens
Delta TencentMessage `json:"delta,omitempty"` // Content returned in streaming mode, null in synchronous mode; the output content supports at most 1024 tokens
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // Results
Created string `json:"created,omitempty"` // UNIX timestamp as a string
Id string `json:"id,omitempty"` // Conversation id
Usage *types.Usage `json:"usage,omitempty"` // Token counts
Error TencentError `json:"error,omitempty"` // Error information; note this field may be null, meaning no valid value was available
Note string `json:"note,omitempty"` // Remarks
ReqID string `json:"req_id,omitempty"` // Unique request id, returned with every request; used when reporting issues with the request parameters
}
func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if TencentResponse.Error.Code != 0 {
return &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := types.ChatCompletionResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: TencentResponse.Usage,
}
if len(TencentResponse.Choices) > 0 {
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: TencentResponse.Choices[0].Messages.Content,
},
FinishReason: TencentResponse.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return fullTextResponse, nil
}
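// getChatRequestBody converts an OpenAI-style chat request into a Hunyuan request. A system
// message is flattened into a user message followed by a stub "Okay" assistant reply,
// presumably because the Hunyuan endpoint only accepts alternating user/assistant roles.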
func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
sign := p.getTencentSign(*requestBody)
if sign == "" {
return nil, types.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
headers["Authorization"] = sign
if request.Stream {
headers["Accept"] = "text/event-stream"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
var responseText string
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
return
}
usage = &types.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
usage.TotalTokens = promptTokens + usage.CompletionTokens
} else {
tencentResponse := &TencentChatResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = tencentResponse.Usage
}
return
}
func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
response := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), ""
}
defer resp.Body.Close()
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := p.streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, responseText
}

96
providers/xunfei_base.go Normal file
View File

@@ -0,0 +1,96 @@
package providers
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/url"
"one-api/common"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiProvider struct {
ProviderConfig
domain string
apiId string
}
// Create a XunfeiProvider
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
return &XunfeiProvider{
ProviderConfig: ProviderConfig{
BaseURL: "wss://spark-api.xf-yun.com",
ChatCompletions: "",
Context: c,
},
}
}
// Get the request headers
func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
return headers
}
// Get the full request URL
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
splits := strings.Split(p.Context.GetString("api_key"), "|")
if len(splits) != 3 {
return ""
}
domain, authUrl := p.getXunfeiAuthUrl(splits[2], splits[1])
p.domain = domain
p.apiId = splits[0]
return authUrl
}
func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (string, string) {
query := p.Context.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = p.Context.GetString("api_version")
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
}
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
return domain, authUrl
}
func (p *XunfeiProvider) buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
HmacWithShaToBase64 := func(algorithm, data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
ul, err := url.Parse(hostUrl)
if err != nil {
fmt.Println(err)
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
sign := strings.Join(signString, "\n")
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
callUrl := hostUrl + "?" + v.Encode()
return callUrl
}
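// Illustrative summary of the handshake (placeholder values): with an api_key configured as
// "appId|apiSecret|apiKey", the string that gets signed for the default v1.1 endpoint is
//
//	host: spark-api.xf-yun.com
//	date: Tue, 28 Nov 2023 10:32:26 GMT
//	GET /v1.1/chat HTTP/1.1
//
// It is HMAC-SHA256 signed with apiSecret, base64 encoded, wrapped into an
// "hmac username=..." authorization string, base64 encoded again, and appended to the
// websocket URL together with the host and date query parameters.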

263
providers/xunfei_chat.go Normal file
View File

@@ -0,0 +1,263 @@
package providers
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/types"
"time"
"github.com/gorilla/websocket"
)
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
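// Illustrative websocket frame produced from the struct above (placeholder values):
//
//	{
//	  "header": {"app_id": "example-app-id"},
//	  "parameter": {"chat": {"domain": "general", "temperature": 1.0, "max_tokens": 256}},
//	  "payload": {"message": {"text": [{"role": "user", "content": "hi"}]}}
//	}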
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text types.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
if request.Stream {
return p.sendStreamRequest(request, authUrl)
} else {
return p.sendRequest(request, authUrl)
}
}
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
}
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
}
// Guard against an empty response so the index below cannot panic.
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := p.responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return nil, types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
_, _ = p.Context.Writer.Write(jsonResponse)
return usage, nil
}
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
}
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return usage, nil
}
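// requestOpenAI2Xunfei converts an OpenAI-style chat request into a Spark websocket request.
// As with the other providers, a system message is turned into a user message plus a stub
// "Okay" assistant reply, presumably because only user/assistant roles are handled here.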
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
xunfeiRequest := XunfeiChatRequest{}
xunfeiRequest.Header.AppId = p.apiId
xunfeiRequest.Parameter.Chat.Domain = p.domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest
}
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *types.ChatCompletionResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: stopFinishReason,
}
fullTextResponse := types.ChatCompletionResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Usage: &response.Payload.Usage.Text,
}
return &fullTextResponse
}
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return nil, nil, err
}
data := p.requestOpenAI2Xunfei(textRequest)
err = conn.WriteJSON(data)
if err != nil {
return nil, nil, err
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
}
break
}
}
stopChan <- true
}()
return dataChan, stopChan, nil
}
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *types.ChatCompletionStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &stopFinishReason
}
response := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}

104
providers/zhipu_base.go Normal file
View File

@@ -0,0 +1,104 @@
package providers
import (
"fmt"
"one-api/common"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
type ZhipuProvider struct {
ProviderConfig
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
// Create a ZhipuProvider
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
return &ZhipuProvider{
ProviderConfig: ProviderConfig{
BaseURL: "https://open.bigmodel.cn",
ChatCompletions: "/api/paas/v3/model-api",
Context: c,
},
}
}
// Get the request headers
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["Authorization"] = p.getZhipuToken()
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
if headers["Content-Type"] == "" {
headers["Content-Type"] = "application/json"
}
return headers
}
// Get the full request URL
func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
}
func (p *ZhipuProvider) getZhipuToken() string {
apikey := p.Context.GetString("api_key")
data, ok := zhipuTokens.Load(apikey)
if ok {
tokenData := data.(zhipuTokenData)
if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token
}
}
split := strings.Split(apikey, ".")
if len(split) != 2 {
common.SysError("invalid zhipu key: " + apikey)
return ""
}
id := split[0]
secret := split[1]
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
timestamp := time.Now().UnixNano() / 1e6
payload := jwt.MapClaims{
"api_key": id,
"exp": expMillis,
"timestamp": timestamp,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
return ""
}
zhipuTokens.Store(apikey, zhipuTokenData{
Token: tokenString,
ExpiryTime: expiryTime,
})
return tokenString
}
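// Illustrative shape of the generated token (placeholder values), assuming an api key of the
// form "id.secret":
//
//	header: {"alg": "HS256", "sign_type": "SIGN", "typ": "JWT"}
//	claims: {"api_key": "<id>", "exp": <now+24h in ms>, "timestamp": <now in ms>}
//
// The token is signed with HMAC-SHA256 using the secret half of the key and cached until it expires.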

260
providers/zhipu_chat.go Normal file
View File

@@ -0,0 +1,260 @@
package providers
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
)
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
types.Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
types.Usage `json:"usage"`
}
func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
if !zhipuResponse.Success {
return &types.OpenAIErrorWithStatusCode{
OpenAIError: types.OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
Code: zhipuResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := types.ChatCompletionResponse{
ID: zhipuResponse.Data.TaskId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
Usage: &zhipuResponse.Data.Usage,
}
for i, choice := range zhipuResponse.Data.Choices {
openaiChoice := types.ChatCompletionChoice{
Index: i,
Message: types.ChatCompletionMessage{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
FinishReason: "",
}
if i == len(zhipuResponse.Data.Choices)-1 {
openaiChoice.FinishReason = "stop"
}
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
}
return fullTextResponse, nil
}
func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest {
messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
Content: message.StringContent(),
})
messages = append(messages, ZhipuMessage{
Role: "user",
Content: "Okay",
})
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &ZhipuRequest{
Prompt: messages,
Temperature: request.Temperature,
TopP: request.TopP,
Incremental: false,
}
}
func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
headers["Accept"] = "text/event-stream"
fullRequestURL += "/sse-invoke"
} else {
fullRequestURL += "/invoke"
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if request.Stream {
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
if openAIErrorWithStatusCode != nil {
return
}
} else {
zhipuResponse := &ZhipuResponse{}
openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse)
if openAIErrorWithStatusCode != nil {
return
}
usage = &zhipuResponse.Data.Usage
}
return
}
func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = zhipuResponse
response := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = ""
choice.FinishReason = &stopFinishReason
response := types.ChatCompletionStreamResponse{
ID: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response, &zhipuResponse.Usage
}
func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
// Send the request
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
}
if common.IsFailureStatusCode(resp) {
return p.handleErrorResp(resp), nil
}
defer resp.Body.Close()
var usage *types.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
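// The goroutine below assumes the upstream stream consists of blocks separated by a blank
// line, where "data:" lines carry incremental text and a trailing "meta:" line carries the
// JSON-encoded ZhipuStreamMetaResponse (request id, task status and usage).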
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
lines := strings.Split(data, "\n")
for i, line := range lines {
if len(line) < 5 {
continue
}
if line[:5] == "data:" {
dataChan <- line[5:]
if i != len(lines)-1 {
dataChan <- "\n"
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
}
}
}
stopChan <- true
}()
setEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
response := p.streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case data := <-metaChan:
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, usage
}