mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-14 12:23:41 +08:00
♻️ refactor: split relay
This commit is contained in:
50
providers/ali_base.go
Normal file
50
providers/ali_base.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AliAIProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// 创建 AliAIProvider
|
||||
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||
func CreateAliAIProvider(c *gin.Context) *AliAIProvider {
|
||||
return &AliAIProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
256
providers/ali_chat.go
Normal file
256
providers/ali_chat.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: aliResponse.Output.Text,
|
||||
},
|
||||
FinishReason: aliResponse.Output.FinishReason,
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: aliResponse.Usage.InputTokens,
|
||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.StringContent(),
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.StringContent()
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.StringContent(),
|
||||
Bot: request.Messages[i+1].StringContent(),
|
||||
})
|
||||
i++
|
||||
}
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["X-DashScope-SSE"] = "enable"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &types.Usage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
aliResponse := &AliChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, aliResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: aliResponse.Usage.InputTokens,
|
||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "ernie-bot",
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), nil
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
lastResponseText := ""
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := p.streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
94
providers/ali_embeddings.go
Normal file
94
providers/ali_embeddings.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range aliResponse.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
|
||||
return openAIEmbeddingResponse, nil
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
aliEmbeddingResponse := &AliEmbeddingResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
14
providers/api2d_base.go
Normal file
14
providers/api2d_base.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package providers
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
type Api2dProvider struct {
|
||||
*OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
|
||||
return &Api2dProvider{
|
||||
OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"),
|
||||
}
|
||||
}
|
||||
41
providers/azure_base.go
Normal file
41
providers/azure_base.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AzureProvider struct {
|
||||
OpenAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func CreateAzureProvider(c *gin.Context) *AzureProvider {
|
||||
return &AzureProvider{
|
||||
OpenAIProvider: OpenAIProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "",
|
||||
Completions: "/completions",
|
||||
ChatCompletions: "/chat/completions",
|
||||
Embeddings: "/embeddings",
|
||||
AudioSpeech: "/audio/speech",
|
||||
AudioTranscriptions: "/audio/transcriptions",
|
||||
AudioTranslations: "/audio/translations",
|
||||
Context: c,
|
||||
},
|
||||
isAzure: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// // 获取完整请求 URL
|
||||
// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
// apiVersion := p.Context.GetString("api_version")
|
||||
// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||
// }
|
||||
|
||||
// return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
// }
|
||||
136
providers/baidu_base.go
Normal file
136
providers/baidu_base.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
type BaiduProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
|
||||
return &BaiduProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://aip.baidubce.com",
|
||||
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
|
||||
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
var modelNameMap = map[string]string{
|
||||
"ERNIE-Bot": "completions",
|
||||
"ERNIE-Bot-turbo": "eb-instant",
|
||||
"ERNIE-Bot-4": "completions_pro",
|
||||
"BLOOMZ-7B": "bloomz_7b1",
|
||||
"Embedding-V1": "embedding-v1",
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
apiKey, err := p.getBaiduAccessToken()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
||||
apiKey := p.Context.GetString("api_key")
|
||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||
var accessToken BaiduAccessToken
|
||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||
// soon this will expire
|
||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||
go func() {
|
||||
_, _ = p.getBaiduAccessTokenHelper(apiKey)
|
||||
}()
|
||||
}
|
||||
return accessToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
accessToken, err := p.getBaiduAccessTokenHelper(apiKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if accessToken == nil {
|
||||
return "", errors.New("getBaiduAccessToken return a nil token")
|
||||
}
|
||||
return (*accessToken).AccessToken, nil
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid baidu apikey")
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
|
||||
|
||||
var headers = map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var accessToken BaiduAccessToken
|
||||
err = json.NewDecoder(resp.Body).Decode(&accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accessToken.Error != "" {
|
||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||
}
|
||||
if accessToken.AccessToken == "" {
|
||||
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
||||
}
|
||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
||||
baiduTokenStore.Store(apiKey, accessToken)
|
||||
return &accessToken, nil
|
||||
}
|
||||
228
providers/baidu_chat.go
Normal file
228
providers/baidu_chat.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BaiduMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage *types.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: baiduResponse.Result,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: baiduResponse.Id,
|
||||
Object: "chat.completion",
|
||||
Created: baiduResponse.Created,
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Usage: baiduResponse.Usage,
|
||||
}
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduError struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
baiduChatRequest := &BaiduChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = baiduChatRequest.Usage
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: "ernie-bot",
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), nil
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
88
providers/baidu_embeddings.go
Normal file
88
providers/baidu_embeddings.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage types.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
|
||||
return &BaiduEmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(baiduResponse.Data)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: &baiduResponse.Usage,
|
||||
}
|
||||
|
||||
for _, item := range baiduResponse.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
|
||||
return openAIEmbeddingResponse, nil
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
usage = &baiduEmbeddingResponse.Usage
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
150
providers/base.go
Normal file
150
providers/base.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var stopFinishReason = "stop"
|
||||
|
||||
type ProviderConfig struct {
|
||||
BaseURL string
|
||||
Completions string
|
||||
ChatCompletions string
|
||||
Embeddings string
|
||||
AudioSpeech string
|
||||
AudioTranscriptions string
|
||||
AudioTranslations string
|
||||
Proxy string
|
||||
Context *gin.Context
|
||||
}
|
||||
|
||||
type BaseProviderAction interface {
|
||||
GetBaseURL() string
|
||||
GetFullRequestURL(requestURL string, modelName string) string
|
||||
GetRequestHeaders() (headers map[string]string)
|
||||
}
|
||||
|
||||
type CompletionProviderAction interface {
|
||||
BaseProviderAction
|
||||
CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type ChatProviderAction interface {
|
||||
BaseProviderAction
|
||||
ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type EmbeddingsProviderAction interface {
|
||||
BaseProviderAction
|
||||
EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type BalanceProviderAction interface {
|
||||
Balance(channel *model.Channel) (float64, error)
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetBaseURL() string {
|
||||
if p.Context.GetString("base_url") != "" {
|
||||
return p.Context.GetString("base_url")
|
||||
}
|
||||
|
||||
return p.BaseURL
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
}
|
||||
|
||||
func setEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var errorResponse types.OpenAIErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errorResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if errorResponse.Error.Type != "" {
|
||||
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
|
||||
} else {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 供应商响应处理函数
|
||||
type ProviderResponseHandler interface {
|
||||
// 请求处理函数
|
||||
requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
err = common.DecodeResponse(resp.Body, response)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(openAIResponse)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = p.Context.Writer.Write(jsonResponse)
|
||||
return nil
|
||||
}
|
||||
55
providers/claude_base.go
Normal file
55
providers/claude_base.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ClaudeProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
ChatCompletions: "/v1/complete",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["x-api-key"] = p.Context.GetString("api_key")
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
headers["anthropic-version"] = anthropicVersion
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
232
providers/claude_chat.go
Normal file
232
providers/claude_chat.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
|
||||
claudeResponse.Usage.CompletionTokens = completionTokens
|
||||
claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens
|
||||
|
||||
fullTextResponse.Usage = claudeResponse.Usage
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: request.Model,
|
||||
Prompt: "",
|
||||
MaxTokensToSample: request.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Stream: request.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokensToSample == 0 {
|
||||
claudeRequest.MaxTokensToSample = 1000000
|
||||
}
|
||||
prompt := ""
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "user" {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
claudeRequest.Prompt = prompt
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
|
||||
} else {
|
||||
var claudeResponse = &ClaudeResponse{
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
},
|
||||
}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = claudeResponse.Usage
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var response types.ChatCompletionStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||
return i + 4, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if !strings.HasPrefix(data, "event: completion") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
responseText += claudeResponse.Completion
|
||||
response := p.streamResponseClaude2OpenAI(&claudeResponse)
|
||||
response.ID = responseId
|
||||
response.Created = createdTime
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
50
providers/closeai_proxy_base.go
Normal file
50
providers/closeai_proxy_base.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CloseaiProxyProvider struct {
|
||||
*OpenAIProvider
|
||||
}
|
||||
|
||||
type OpenAICreditGrants struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
}
|
||||
|
||||
// 创建 CloseaiProxyProvider
|
||||
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
|
||||
return &CloseaiProxyProvider{
|
||||
OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
|
||||
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response OpenAICreditGrants
|
||||
err = client.SendRequest(req, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
215
providers/openai_base.go
Normal file
215
providers/openai_base.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
ProviderConfig
|
||||
isAzure bool
|
||||
}
|
||||
|
||||
type OpenAIProviderResponseHandler interface {
|
||||
// 请求处理函数
|
||||
requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type OpenAIProviderStreamResponseHandler interface {
|
||||
// 请求流处理函数
|
||||
requestStreamHandler() (responseText string)
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
|
||||
return &OpenAIProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: baseURL,
|
||||
Completions: "/v1/completions",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
AudioSpeech: "/v1/audio/speech",
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
Context: c,
|
||||
},
|
||||
isAzure: false,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
if p.isAzure {
|
||||
apiVersion := p.Context.GetString("api_version")
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
if p.isAzure {
|
||||
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||
} else {
|
||||
requestURL = strings.TrimPrefix(requestURL, "/v1")
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
if p.isAzure {
|
||||
headers["api-key"] = p.Context.GetString("api_key")
|
||||
} else {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
}
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// 获取请求体
|
||||
func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = p.Context.Request.Body
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 创建一个 bytes.Buffer 来存储响应体
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
|
||||
// 解析响应
|
||||
err = common.DecodeResponse(tee, response)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIErrorWithStatusCode = response.requestHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
p.Context.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(p.Context.Writer, &buf)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
||||
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
err := json.Unmarshal([]byte(data), response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue // just ignore the error
|
||||
}
|
||||
responseText += response.requestStreamHandler()
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
p.Context.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
92
providers/openai_chat.go
Normal file
92
providers/openai_chat.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderChatResponse struct {
|
||||
types.ChatCompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
type OpenAIProviderChatStreamResponse struct {
|
||||
types.ChatCompletionStreamResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream && headers["Accept"] == "" {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
||||
var textResponse string
|
||||
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderChatResponse.Usage
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range openAIProviderChatResponse.Choices {
|
||||
completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
|
||||
}
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
87
providers/openai_completion.go
Normal file
87
providers/openai_completion.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderCompletionResponse struct {
|
||||
types.CompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Text
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream && headers["Accept"] == "" {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
|
||||
if request.Stream {
|
||||
// TODO
|
||||
var textResponse string
|
||||
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderCompletionResponse.Usage
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range openAIProviderCompletionResponse.Choices {
|
||||
completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
|
||||
}
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
50
providers/openai_embeddings.go
Normal file
50
providers/openai_embeddings.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type OpenAIProviderEmbeddingsResponse struct {
|
||||
types.EmbeddingResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderEmbeddingsResponse.Usage
|
||||
|
||||
return
|
||||
}
|
||||
58
providers/openaisb_base.go
Normal file
58
providers/openaisb_base.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenaiSBProvider struct {
|
||||
*OpenAIProvider
|
||||
}
|
||||
|
||||
type OpenAISBUsageResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
Credit string `json:"credit"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// 创建 OpenaiSBProvider
|
||||
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
|
||||
return &OpenaiSBProvider{
|
||||
OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
|
||||
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response OpenAISBUsageResponse
|
||||
err = client.SendRequest(req, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if response.Data == nil {
|
||||
return 0, errors.New(response.Msg)
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
43
providers/palm_base.go
Normal file
43
providers/palm_base.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PalmProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
// 创建 PalmProvider
|
||||
func CreatePalmProvider(c *gin.Context) *PalmProvider {
|
||||
return &PalmProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key"))
|
||||
}
|
||||
232
providers/palm_chat.go
Normal file
232
providers/palm_chat.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []types.ChatCompletionMessage `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
Code: palmResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)),
|
||||
}
|
||||
for i, candidate := range palmResponse.Candidates {
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: i,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: candidate.Content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
|
||||
palmResponse.Usage.CompletionTokens = completionTokens
|
||||
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
|
||||
|
||||
fullTextResponse.Usage = palmResponse.Usage
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
|
||||
},
|
||||
Temperature: request.Temperature,
|
||||
CandidateCount: request.N,
|
||||
TopP: request.TopP,
|
||||
TopK: request.MaxTokens,
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
} else {
|
||||
palmMessage.Author = "1"
|
||||
}
|
||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||
}
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
|
||||
} else {
|
||||
var palmChatResponse = &PaLMChatResponse{
|
||||
Model: request.Model,
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
},
|
||||
}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = palmChatResponse.Usage
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &stopFinishReason
|
||||
var response types.ChatCompletionStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse)
|
||||
fullTextResponse.ID = responseId
|
||||
fullTextResponse.Created = createdTime
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
responseText = palmResponse.Candidates[0].Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
94
providers/tencent_base.go
Normal file
94
providers/tencent_base.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TencentProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// 创建 TencentProvider
|
||||
func CreateTencentProvider(c *gin.Context) *TencentProvider {
|
||||
return &TencentProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://hunyuan.cloud.tencent.com",
|
||||
ChatCompletions: "/hyllm/v1/chat/completions",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||
parts := strings.Split(config, "|")
|
||||
if len(parts) != 3 {
|
||||
err = errors.New("invalid tencent config")
|
||||
return
|
||||
}
|
||||
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||
secretId = parts[1]
|
||||
secretKey = parts[2]
|
||||
return
|
||||
}
|
||||
|
||||
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
||||
apiKey := p.Context.GetString("api_key")
|
||||
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
req.AppId = appId
|
||||
req.SecretId = secretId
|
||||
|
||||
params := make([]string, 0)
|
||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||
params = append(params, "secret_id="+req.SecretId)
|
||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||
params = append(params, "query_id="+req.QueryID)
|
||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||
|
||||
var messageStr string
|
||||
for _, msg := range req.Messages {
|
||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||
}
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Sort(sort.StringSlice(params))
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
mac.Write([]byte(signURL))
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
}
|
||||
265
providers/tencent_chat.go
Normal file
265
providers/tencent_chat.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage *types.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
|
||||
func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: TencentResponse.Usage,
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: TencentResponse.Choices[0].Messages.Content,
|
||||
},
|
||||
FinishReason: TencentResponse.Choices[0].FinishReason,
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest {
|
||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
continue
|
||||
}
|
||||
messages = append(messages, TencentMessage{
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
stream := 0
|
||||
if request.Stream {
|
||||
stream = 1
|
||||
}
|
||||
return &TencentChatRequest{
|
||||
Timestamp: common.GetTimestamp(),
|
||||
Expired: common.GetTimestamp() + 24*60*60,
|
||||
QueryID: common.GetUUID(),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Stream: stream,
|
||||
Messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
sign := p.getTencentSign(*requestBody)
|
||||
if sign == "" {
|
||||
return nil, types.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
headers["Authorization"] = sign
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
|
||||
} else {
|
||||
tencentResponse := &TencentChatResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = tencentResponse.Usage
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "tencent-hunyuan",
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var TencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response := p.streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
96
providers/xunfei_base.go
Normal file
96
providers/xunfei_base.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
type XunfeiProvider struct {
|
||||
ProviderConfig
|
||||
domain string
|
||||
apiId string
|
||||
}
|
||||
|
||||
// 创建 XunfeiProvider
|
||||
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
|
||||
return &XunfeiProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "wss://spark-api.xf-yun.com",
|
||||
ChatCompletions: "",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
return headers
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
splits := strings.Split(p.Context.GetString("api_key"), "|")
|
||||
if len(splits) != 3 {
|
||||
return ""
|
||||
}
|
||||
domain, authUrl := p.getXunfeiAuthUrl(splits[2], splits[1])
|
||||
|
||||
p.domain = domain
|
||||
p.apiId = splits[0]
|
||||
|
||||
return authUrl
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (string, string) {
|
||||
query := p.Context.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = p.Context.GetString("api_version")
|
||||
}
|
||||
if apiVersion == "" {
|
||||
apiVersion = "v1.1"
|
||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||
}
|
||||
domain := "general"
|
||||
if apiVersion != "v1.1" {
|
||||
domain += strings.Split(apiVersion, ".")[0]
|
||||
}
|
||||
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
encodeData := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(encodeData)
|
||||
}
|
||||
ul, err := url.Parse(hostUrl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
date := time.Now().UTC().Format(time.RFC1123)
|
||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||
sign := strings.Join(signString, "\n")
|
||||
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||
"hmac-sha256", "host date request-line", sha)
|
||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||
v := url.Values{}
|
||||
v.Add("host", ul.Host)
|
||||
v.Add("date", date)
|
||||
v.Add("authorization", authorization)
|
||||
callUrl := hostUrl + "?" + v.Encode()
|
||||
return callUrl
|
||||
}
|
||||
263
providers/xunfei_chat.go
Normal file
263
providers/xunfei_chat.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text types.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
|
||||
if request.Stream {
|
||||
return p.sendStreamRequest(request, authUrl)
|
||||
} else {
|
||||
return p.sendRequest(request, authUrl)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case xunfeiResponse = <-dataChan:
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := p.responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = p.Context.Writer.Write(jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
|
||||
}
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
|
||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
xunfeiRequest := XunfeiChatRequest{}
|
||||
xunfeiRequest.Header.AppId = p.apiId
|
||||
xunfeiRequest.Parameter.Chat.Domain = p.domain
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||
xunfeiRequest.Payload.Message.Text = messages
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *types.ChatCompletionResponse {
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: response.Payload.Choices.Text[0].Content,
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
}
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Usage: &response.Payload.Usage.Text,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
conn, resp, err := d.Dial(authUrl, nil)
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return nil, nil, err
|
||||
}
|
||||
data := p.requestOpenAI2Xunfei(textRequest)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
var response XunfeiChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
dataChan <- response
|
||||
if response.Payload.Choices.Status == 2 {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing websocket connection: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
|
||||
return dataChan, stopChan, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *types.ChatCompletionStreamResponse {
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
}
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "SparkDesk",
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
104
providers/zhipu_base.go
Normal file
104
providers/zhipu_base.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
type ZhipuProvider struct {
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
type zhipuTokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
}
|
||||
|
||||
// 创建 ZhipuProvider
|
||||
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
|
||||
return &ZhipuProvider{
|
||||
ProviderConfig: ProviderConfig{
|
||||
BaseURL: "https://open.bigmodel.cn",
|
||||
ChatCompletions: "/api/paas/v3/model-api",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
||||
headers["Authorization"] = p.getZhipuToken()
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
if headers["Content-Type"] == "" {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) getZhipuToken() string {
|
||||
apikey := p.Context.GetString("api_key")
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(zhipuTokenData)
|
||||
if time.Now().Before(tokenData.ExpiryTime) {
|
||||
return tokenData.Token
|
||||
}
|
||||
}
|
||||
|
||||
split := strings.Split(apikey, ".")
|
||||
if len(split) != 2 {
|
||||
common.SysError("invalid zhipu key: " + apikey)
|
||||
return ""
|
||||
}
|
||||
|
||||
id := split[0]
|
||||
secret := split[1]
|
||||
|
||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||
|
||||
timestamp := time.Now().UnixNano() / 1e6
|
||||
|
||||
payload := jwt.MapClaims{
|
||||
"api_key": id,
|
||||
"exp": expMillis,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
|
||||
token.Header["alg"] = "HS256"
|
||||
token.Header["sign_type"] = "SIGN"
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
zhipuTokens.Store(apikey, zhipuTokenData{
|
||||
Token: tokenString,
|
||||
ExpiryTime: expiryTime,
|
||||
})
|
||||
|
||||
return tokenString
|
||||
}
|
||||
260
providers/zhipu_chat.go
Normal file
260
providers/zhipu_chat.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ZhipuMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ZhipuRequest struct {
|
||||
Prompt []ZhipuMessage `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ZhipuResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []ZhipuMessage `json:"choices"`
|
||||
types.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ZhipuResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ZhipuResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type ZhipuStreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
types.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
if !zhipuResponse.Success {
|
||||
return &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
Code: zhipuResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: zhipuResponse.Data.TaskId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
|
||||
Usage: &zhipuResponse.Data.Usage,
|
||||
}
|
||||
for i, choice := range zhipuResponse.Data.Choices {
|
||||
openaiChoice := types.ChatCompletionChoice{
|
||||
Index: i,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: choice.Role,
|
||||
Content: strings.Trim(choice.Content, "\""),
|
||||
},
|
||||
FinishReason: "",
|
||||
}
|
||||
if i == len(zhipuResponse.Data.Choices)-1 {
|
||||
openaiChoice.FinishReason = "stop"
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
}
|
||||
return fullTextResponse, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest {
|
||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "system",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "user",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &ZhipuRequest{
|
||||
Prompt: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Incremental: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
fullRequestURL += "/sse-invoke"
|
||||
} else {
|
||||
fullRequestURL += "/invoke"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
zhipuResponse := &ZhipuResponse{}
|
||||
openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &zhipuResponse.Data.Usage
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &stopFinishReason
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
|
||||
// 发送请求
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return p.handleErrorResp(resp), nil
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var usage *types.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
metaChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
lines := strings.Split(data, "\n")
|
||||
for i, line := range lines {
|
||||
if len(line) < 5 {
|
||||
continue
|
||||
}
|
||||
if line[:5] == "data:" {
|
||||
dataChan <- line[5:]
|
||||
if i != len(lines)-1 {
|
||||
dataChan <- "\n"
|
||||
}
|
||||
} else if line[:5] == "meta:" {
|
||||
metaChan <- line[5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
response := p.streamResponseZhipu2OpenAI(data)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case data := <-metaChan:
|
||||
var zhipuResponse ZhipuStreamMetaResponse
|
||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
usage = zhipuUsage
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
Reference in New Issue
Block a user