mirror of https://github.com/songquanpeng/one-api.git
synced 2025-11-14 20:23:46 +08:00
🎨 Adjust the provider directory structure and merge the text output functions
205  providers/openai/base.go  Normal file
@@ -0,0 +1,205 @@
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"one-api/common"
	"one-api/types"
	"strings"

	"one-api/providers/base"

	"github.com/gin-gonic/gin"
)

type OpenAIProvider struct {
	base.BaseProvider
	IsAzure bool
}

// CreateOpenAIProvider creates an OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
	if baseURL == "" {
		baseURL = "https://api.openai.com"
	}

	return &OpenAIProvider{
		BaseProvider: base.BaseProvider{
			BaseURL:             baseURL,
			Completions:         "/v1/completions",
			ChatCompletions:     "/v1/chat/completions",
			Embeddings:          "/v1/embeddings",
			AudioSpeech:         "/v1/audio/speech",
			AudioTranscriptions: "/v1/audio/transcriptions",
			AudioTranslations:   "/v1/audio/translations",
			Context:             c,
		},
		IsAzure: false,
	}
}

// GetFullRequestURL builds the full upstream request URL
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")

	if p.IsAzure {
		apiVersion := p.Context.GetString("api_version")
		requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
	}

	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		if p.IsAzure {
			requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
		} else {
			requestURL = strings.TrimPrefix(requestURL, "/v1")
		}
	}

	return fmt.Sprintf("%s%s", baseURL, requestURL)
}

// GetRequestHeaders builds the request headers
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
	headers = make(map[string]string)
	p.CommonRequestHeaders(headers)
	if p.IsAzure {
		headers["api-key"] = p.Context.GetString("api_key")
	} else {
		headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
	}

	return headers
}

// getRequestBody builds the request body
func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
	if isModelMapped {
		jsonStr, err := json.Marshal(request)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	} else {
		requestBody = p.Context.Request.Body
	}
	return
}

// sendRequest sends a non-streaming request and relays the upstream response
func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
	// Send the request
	resp, err := common.HttpClient.Do(req)
	if err != nil {
		return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
	}

	defer resp.Body.Close()

	// Handle failure responses
	if common.IsFailureStatusCode(resp) {
		return p.HandleErrorResp(resp)
	}

	// Create a bytes.Buffer to keep a copy of the response body
	var buf bytes.Buffer
	tee := io.TeeReader(resp.Body, &buf)

	// Decode the response
	err = common.DecodeResponse(tee, response)
	if err != nil {
		return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
	}

	openAIErrorWithStatusCode = response.responseHandler(resp)
	if openAIErrorWithStatusCode != nil {
		return
	}

	for k, v := range resp.Header {
		p.Context.Writer.Header().Set(k, v[0])
	}

	p.Context.Writer.WriteHeader(resp.StatusCode)
	_, err = io.Copy(p.Context.Writer, &buf)
	if err != nil {
		return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}

	return nil
}

// sendStreamRequest sends a streaming (SSE) request
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
	resp, err := common.HttpClient.Do(req)
	if err != nil {
		return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
	}

	if common.IsFailureStatusCode(resp) {
		return p.HandleErrorResp(resp), ""
	}

	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if len(data) < 6 { // ignore blank line or wrong format
				continue
			}
			if data[:6] != "data: " && data[:6] != "[DONE]" {
				continue
			}
			dataChan <- data
			data = data[6:]
			if !strings.HasPrefix(data, "[DONE]") {
				err := json.Unmarshal([]byte(data), response)
				if err != nil {
					common.SysError("error unmarshalling stream response: " + err.Error())
					continue // just ignore the error
				}
				responseText += response.responseStreamHandler()
			}
		}
		stopChan <- true
	}()
	common.SetEventStreamHeaders(p.Context)
	p.Context.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			if strings.HasPrefix(data, "data: [DONE]") {
				data = data[:12]
			}
			// some implementations may add \r at the end of data
			data = strings.TrimSuffix(data, "\r")
			p.Context.Render(-1, common.CustomEvent{Data: data})
			return true
		case <-stopChan:
			return false
		}
	})

	return nil, responseText
}
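The streaming path above tokenizes the upstream body with a custom bufio.SplitFunc and only forwards lines that carry an SSE "data:" payload. A minimal, standard-library-only sketch of that parsing step (the sample stream and the parseSSE helper are illustrative, not part of this commit):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

// parseSSE scans an SSE body line by line, mirroring the split function used
// in sendStreamRequest: newline-delimited tokens, with any final partial line
// returned at EOF.
func parseSSE(body string) []string {
	scanner := bufio.NewScanner(strings.NewReader(body))
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil // consume through the newline, emit the line
		}
		if atEOF {
			return len(data), data, nil // emit whatever is left at EOF
		}
		return 0, nil, nil // need more data
	})

	var payloads []string
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) < 6 || line[:6] != "data: " {
			continue // skip blank lines and non-data fields
		}
		payload := strings.TrimSuffix(line[6:], "\r")
		if payload == "[DONE]" {
			break // end-of-stream sentinel
		}
		payloads = append(payloads, payload)
	}
	return payloads
}

func main() {
	body := "data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\r\ndata: [DONE]\n"
	fmt.Println(parseSSE(body)) // [{"id":"1"} {"id":"2"}]
}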
82  providers/openai/chat.go  Normal file
@@ -0,0 +1,82 @@
package openai

import (
	"net/http"
	"one-api/common"
	"one-api/types"
)

func (c *OpenAIProviderChatResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
	if c.Error.Type != "" {
		errWithCode = &types.OpenAIErrorWithStatusCode{
			OpenAIError: c.Error,
			StatusCode:  resp.StatusCode,
		}
		return
	}
	return nil
}

func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
	for _, choice := range c.Choices {
		responseText += choice.Delta.Content
	}

	return
}

func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	requestBody, err := p.getRequestBody(&request, isModelMapped)
	if err != nil {
		return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
	}

	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
	headers := p.GetRequestHeaders()
	if request.Stream && headers["Accept"] == "" {
		headers["Accept"] = "text/event-stream"
	}

	client := common.NewClient()
	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
	if err != nil {
		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}

	if request.Stream {
		openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
		var textResponse string
		errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
		if errWithCode != nil {
			return
		}

		usage = &types.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: common.CountTokenText(textResponse, request.Model),
			TotalTokens:      promptTokens + common.CountTokenText(textResponse, request.Model),
		}

	} else {
		openAIProviderChatResponse := &OpenAIProviderChatResponse{}
		errWithCode = p.sendRequest(req, openAIProviderChatResponse)
		if errWithCode != nil {
			return
		}

		usage = openAIProviderChatResponse.Usage

		if usage.TotalTokens == 0 {
			completionTokens := 0
			for _, choice := range openAIProviderChatResponse.Choices {
				completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
			}
			usage = &types.Usage{
				PromptTokens:     promptTokens,
				CompletionTokens: completionTokens,
				TotalTokens:      promptTokens + completionTokens,
			}
		}
	}
	return
}
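When the upstream response reports no usage (TotalTokens == 0), ChatAction recounts completion tokens from the returned choices and adds the prompt tokens counted earlier. A self-contained sketch of that fallback, with a word-count stand-in for common.CountTokenText (all names here are illustrative):

package main

import (
	"fmt"
	"strings"
)

// Usage mirrors the shape of types.Usage for illustration only.
type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// countTokens is a stand-in for common.CountTokenText; the real function is
// model-aware, this one just counts whitespace-separated words.
func countTokens(text, model string) int {
	return len(strings.Fields(text))
}

// fallbackUsage rebuilds usage when the provider did not report it, mirroring
// the TotalTokens == 0 branch of ChatAction.
func fallbackUsage(upstream *Usage, choices []string, model string, promptTokens int) *Usage {
	if upstream != nil && upstream.TotalTokens != 0 {
		return upstream // trust the provider's own accounting when present
	}
	completionTokens := 0
	for _, content := range choices {
		completionTokens += countTokens(content, model)
	}
	return &Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
}

func main() {
	u := fallbackUsage(&Usage{}, []string{"Hello there", "General Kenobi"}, "gpt-3.5-turbo", 12)
	fmt.Printf("%+v\n", *u) // {PromptTokens:12 CompletionTokens:4 TotalTokens:16}
}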
82  providers/openai/completion.go  Normal file
@@ -0,0 +1,82 @@
package openai

import (
	"net/http"
	"one-api/common"
	"one-api/types"
)

func (c *OpenAIProviderCompletionResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
	if c.Error.Type != "" {
		errWithCode = &types.OpenAIErrorWithStatusCode{
			OpenAIError: c.Error,
			StatusCode:  resp.StatusCode,
		}
		return
	}
	return nil
}

func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
	for _, choice := range c.Choices {
		responseText += choice.Text
	}

	return
}

func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	requestBody, err := p.getRequestBody(&request, isModelMapped)
	if err != nil {
		return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
	}

	fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
	headers := p.GetRequestHeaders()
	if request.Stream && headers["Accept"] == "" {
		headers["Accept"] = "text/event-stream"
	}

	client := common.NewClient()
	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
	if err != nil {
		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}

	openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
	if request.Stream {
		// TODO
		var textResponse string
		errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
		if errWithCode != nil {
			return
		}

		usage = &types.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: common.CountTokenText(textResponse, request.Model),
			TotalTokens:      promptTokens + common.CountTokenText(textResponse, request.Model),
		}

	} else {
		errWithCode = p.sendRequest(req, openAIProviderCompletionResponse)
		if errWithCode != nil {
			return
		}

		usage = openAIProviderCompletionResponse.Usage

		if usage.TotalTokens == 0 {
			completionTokens := 0
			for _, choice := range openAIProviderCompletionResponse.Choices {
				completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
			}
			usage = &types.Usage{
				PromptTokens:     promptTokens,
				CompletionTokens: completionTokens,
				TotalTokens:      promptTokens + completionTokens,
			}
		}
	}
	return
}
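CompleteAction resolves its endpoint through the same GetFullRequestURL shown in base.go, so the Azure and Cloudflare-gateway rewrites apply to completions too. A standalone sketch of how those rewrites compose (an illustrative re-implementation, not the repo's function; the example inputs are made up):

package main

import (
	"fmt"
	"strings"
)

// fullRequestURL mirrors the rewrite rules in OpenAIProvider.GetFullRequestURL:
// Azure deployments get /openai/deployments/<model> plus an api-version query,
// and the Cloudflare AI gateway drops the provider-specific path prefix.
func fullRequestURL(baseURL, requestURL, model string, isAzure bool, apiVersion string) string {
	baseURL = strings.TrimSuffix(baseURL, "/")

	if isAzure {
		requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", model, requestURL, apiVersion)
	}

	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		if isAzure {
			requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
		} else {
			requestURL = strings.TrimPrefix(requestURL, "/v1")
		}
	}

	return baseURL + requestURL
}

func main() {
	// Plain OpenAI: the /v1 path is kept as-is.
	fmt.Println(fullRequestURL("https://api.openai.com", "/v1/completions", "gpt-3.5-turbo-instruct", false, ""))
	// Azure: the model becomes the deployment name and api-version is appended.
	fmt.Println(fullRequestURL("https://example.openai.azure.com", "/completions", "my-deployment", true, "2023-05-15"))
}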
45  providers/openai/embeddings.go  Normal file
@@ -0,0 +1,45 @@
package openai

import (
	"net/http"
	"one-api/common"
	"one-api/types"
)

func (c *OpenAIProviderEmbeddingsResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
	if c.Error.Type != "" {
		errWithCode = &types.OpenAIErrorWithStatusCode{
			OpenAIError: c.Error,
			StatusCode:  resp.StatusCode,
		}
		return
	}
	return nil
}

func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	requestBody, err := p.getRequestBody(&request, isModelMapped)
	if err != nil {
		return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
	}

	fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
	headers := p.GetRequestHeaders()

	client := common.NewClient()
	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
	if err != nil {
		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}

	openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
	errWithCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
	if errWithCode != nil {
		return
	}

	usage = openAIProviderEmbeddingsResponse.Usage

	return
}
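The embeddings path is non-streaming, so it goes through sendRequest's io.TeeReader pattern from base.go: the body is decoded once for usage and error inspection while an exact copy is buffered and replayed to the client. A standard-library-only sketch of that decode-and-forward idea (types and helper names are illustrative):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"strings"
)

// usageOnly is just enough structure to pull token accounting out of the body.
type usageOnly struct {
	Usage struct {
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}

// decodeAndForward decodes the body for inspection while buffering an exact
// copy, then writes that copy to the client untouched — the same TeeReader
// pattern sendRequest uses before replaying the response.
func decodeAndForward(body io.Reader, client io.Writer) (int, error) {
	var buf bytes.Buffer
	tee := io.TeeReader(body, &buf)

	var parsed usageOnly
	if err := json.NewDecoder(tee).Decode(&parsed); err != nil {
		return 0, err
	}
	// Drain whatever the decoder did not consume so buf holds the full body.
	if _, err := io.Copy(io.Discard, tee); err != nil {
		return 0, err
	}

	if _, err := io.Copy(client, &buf); err != nil {
		return 0, err
	}
	return parsed.Usage.TotalTokens, nil
}

func main() {
	upstream := strings.NewReader(`{"object":"list","usage":{"total_tokens":21}}` + "\n")
	total, err := decodeAndForward(upstream, os.Stdout)
	if err != nil {
		panic(err)
	}
	fmt.Println("total tokens:", total) // total tokens: 21
}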
16  providers/openai/interface.go  Normal file
@@ -0,0 +1,16 @@
package openai

import (
	"net/http"
	"one-api/types"
)

type OpenAIProviderResponseHandler interface {
	// Handles a non-streaming response
	responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode)
}

type OpenAIProviderStreamResponseHandler interface {
	// Handles a streaming response chunk
	responseStreamHandler() (responseText string)
}
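These two interfaces are what let a single sendRequest / sendStreamRequest pair serve chat, completion, and embedding responses: each response type only knows how to report its own error or extract its own text. A standalone sketch of the same dispatch idea (the chunk types here are invented for illustration):

package main

import "fmt"

// streamHandler mirrors OpenAIProviderStreamResponseHandler: every response
// type knows how to turn its own payload into plain text.
type streamHandler interface {
	responseStreamHandler() string
}

// chatChunk and completionChunk stand in for the chat and completion stream
// responses; each extracts text from its own shape.
type chatChunk struct{ delta string }

func (c chatChunk) responseStreamHandler() string { return c.delta }

type completionChunk struct{ text string }

func (c completionChunk) responseStreamHandler() string { return c.text }

// collectText is the merged "text output" function: it does not care which
// concrete response type it is fed.
func collectText(chunks []streamHandler) (out string) {
	for _, c := range chunks {
		out += c.responseStreamHandler()
	}
	return
}

func main() {
	fmt.Println(collectText([]streamHandler{chatChunk{"Hello, "}, chatChunk{"world"}}))
	fmt.Println(collectText([]streamHandler{completionChunk{"42"}}))
}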
23  providers/openai/type.go  Normal file
@@ -0,0 +1,23 @@
package openai

import "one-api/types"

type OpenAIProviderChatResponse struct {
	types.ChatCompletionResponse
	types.OpenAIErrorResponse
}

type OpenAIProviderChatStreamResponse struct {
	types.ChatCompletionStreamResponse
	types.OpenAIErrorResponse
}

type OpenAIProviderCompletionResponse struct {
	types.CompletionResponse
	types.OpenAIErrorResponse
}

type OpenAIProviderEmbeddingsResponse struct {
	types.EmbeddingResponse
	types.OpenAIErrorResponse
}
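Each response type embeds both its success payload and types.OpenAIErrorResponse, so a single json.Unmarshal fills whichever fields the upstream actually sent and responseHandler only has to check Error.Type. A standard-library-only sketch of that embedding trick (the structs are illustrative stand-ins):

package main

import (
	"encoding/json"
	"fmt"
)

// The success payload and the error envelope are separate types...
type chatPayload struct {
	ID    string `json:"id"`
	Model string `json:"model"`
}

type apiError struct {
	Error struct {
		Type    string `json:"type"`
		Message string `json:"message"`
	} `json:"error"`
}

// ...and the combined response embeds both, the same way
// OpenAIProviderChatResponse embeds ChatCompletionResponse and
// OpenAIErrorResponse: one Unmarshal covers both outcomes.
type chatResponse struct {
	chatPayload
	apiError
}

func main() {
	for _, body := range []string{
		`{"id":"chatcmpl-1","model":"gpt-3.5-turbo"}`,
		`{"error":{"type":"invalid_request_error","message":"bad key"}}`,
	} {
		var resp chatResponse
		_ = json.Unmarshal([]byte(body), &resp)
		if resp.Error.Type != "" {
			fmt.Println("upstream error:", resp.Error.Message)
			continue
		}
		fmt.Println("ok:", resp.ID, resp.Model)
	}
}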