merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-07-19 10:58:21 +08:00
72 changed files with 1989 additions and 1193 deletions

View File

@@ -1,10 +1,13 @@
package openai
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"mime/multipart"
"net/http"
"one-api/common"
"one-api/dto"
@@ -14,22 +17,16 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
ChannelType int
ChannelType int
ResponseFormat string
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
}
@@ -74,28 +71,84 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
}
return bytes.NewReader(jsonData), nil
} else {
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 添加文件字段
file, header, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
part, err := writer.CreateFormFile("file", header.Filename)
if err != nil {
return nil, errors.New("create form file failed")
}
if _, err := io.Copy(part, file); err != nil {
return nil, errors.New("copy file failed")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return &requestBody, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
switch info.RelayMode {
case constant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}

View File

@@ -9,6 +9,7 @@ var ModelList = []string{
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview",
"gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
"text-moderation-latest", "text-moderation-stable",

View File

@@ -4,6 +4,8 @@ import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -14,38 +16,36 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"sync"
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
containStreamUsage := false
responseId := ""
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
var responseTextBuilder strings.Builder
var usage dto.Usage
var usage = &dto.Usage{}
var streamItems []string // store stream items
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
stopChan := make(chan bool)
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string // store stream items
gopool.Go(func() {
for scanner.Scan() {
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
@@ -53,54 +53,46 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
err := service.StringData(c, data)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
streamItems = append(streamItems, data)
}
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
common.SafeSendBool(stopChan, true)
})
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -116,67 +108,69 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
} else {
for _, streamResponse := range streamResponses {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
} else {
for _, streamResponse := range streamResponses {
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
}
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
case <-ticker.C:
common.LogError(c, "reading data from upstream timeout")
return false
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
}
wg.Wait()
return nil, &usage, responseTextBuilder.String(), toolCount
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
}
service.Done(c)
resp.Body.Close()
return nil, usage
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -213,11 +207,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
@@ -232,3 +222,134 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
return nil, &simpleResponse.Usage
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
return nil, usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var audioResp dto.AudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &audioResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
var text string
switch responseFormat {
case "json":
text, err = getTextFromJSON(responseBody)
case "text":
text, err = getTextFromText(responseBody)
case "srt":
text, err = getTextFromSRT(responseBody)
case "verbose_json":
text, err = getTextFromVerboseJSON(responseBody)
case "vtt":
text, err = getTextFromVTT(responseBody)
}
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func getTextFromVTT(body []byte) (string, error) {
return getTextFromSRT(body)
}
func getTextFromVerboseJSON(body []byte) (string, error) {
var whisperResponse dto.WhisperVerboseJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
return whisperResponse.Text, nil
}
func getTextFromSRT(body []byte) (string, error) {
scanner := bufio.NewScanner(strings.NewReader(string(body)))
var builder strings.Builder
var textLine bool
for scanner.Scan() {
line := scanner.Text()
if textLine {
builder.WriteString(line)
textLine = false
continue
} else if strings.Contains(line, "-->") {
textLine = true
continue
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return builder.String(), nil
}
func getTextFromText(body []byte) (string, error) {
return strings.TrimSuffix(string(body), "\n"), nil
}
func getTextFromJSON(body []byte) (string, error) {
var whisperResponse dto.AudioResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
return whisperResponse.Text, nil
}