diff --git a/controller/channel-test.go b/controller/channel-test.go index e1af673..90d02d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return err, nil } if resp != nil && resp.StatusCode != http.StatusOK { - err := relaycommon.RelayErrorHandler(resp) + err := service.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err } usage, respErr := adaptor.DoResponse(c, resp, meta) diff --git a/middleware/distributor.go b/middleware/distributor.go index 2552f29..1ce787e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -161,9 +161,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranslation } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranscription } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 2f3c46d..a518da8 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -1,6 +1,7 @@ package cloudflare import ( + "bytes" "errors" "fmt" "github.com/gin-gonic/gin" @@ -15,16 +16,6 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } @@ -65,11 +56,42 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + // 添加文件字段 + file, _, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + // 打开临时文件用于保存上传的文件内容 + requestBody := &bytes.Buffer{} + + // 将上传的文件内容复制到临时文件 + if _, err := io.Copy(requestBody, file); err != nil { + return nil, err + } + return requestBody, nil +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = cfStreamHandler(c, resp, info) - } else { - err, usage = cfHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fallthrough + case constant.RelayModeChatCompletions: + if info.IsStream { + err, usage = cfStreamHandler(c, resp, info) + } else { + err, usage = cfHandler(c, resp, info) + } + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = cfSTTHandler(c, resp, info) } return } diff --git a/relay/channel/cloudflare/model.go b/relay/channel/cloudflare/dto.go similarity index 78% rename from relay/channel/cloudflare/model.go rename to relay/channel/cloudflare/dto.go index c870813..2f6531c 100644 --- a/relay/channel/cloudflare/model.go +++ b/relay/channel/cloudflare/dto.go @@ -11,3 +11,11 @@ type CfRequest struct { Stream bool `json:"stream,omitempty"` Temperature float64 `json:"temperature,omitempty"` } + +type CfAudioResponse struct { + Result CfSTTResult `json:"result"` +} + +type CfSTTResult struct { + Text string `json:"text"` +} diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index d9319ef..69d6b85 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -119,3 +119,38 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) _, _ = c.Writer.Write(jsonResponse) return nil, usage } + +func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var cfResp CfAudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &cfResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + audioResp := &dto.AudioResponse{ + Text: cfResp.Result.Text, + } + + jsonResponse, err := json.Marshal(audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return nil, usage +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4b27a07..651e82e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -165,10 +165,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. service.Done(c) - err := resp.Body.Close() - if err != nil { - common.LogError(c, "close_response_body_failed: "+err.Error()) - } + resp.Body.Close() return nil, usage } @@ -206,11 +203,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - + resp.Body.Close() if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { @@ -257,7 +250,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens - return nil, usage } @@ -290,10 +282,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() var text string switch responseFormat { @@ -313,7 +302,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage.PromptTokens = info.PromptTokens usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return nil, usage } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 9ef9a8b..6daf003 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -1,50 +1,17 @@ package common import ( - "encoding/json" "fmt" "github.com/gin-gonic/gin" _ "image/gif" _ "image/jpeg" _ "image/png" - "io" - "net/http" "one-api/common" - "one-api/dto" - "strconv" "strings" ) var StopFinishReason = "stop" -func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { - OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: dto.OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var textResponse dto.TextResponseWithError - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody) - return - } - OpenAIErrorWithStatusCode.Error = textResponse.Error - return -} - func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 05b723c..2a0278e 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -105,6 +105,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { audioRequest.Model = modelMap[audioRequest.Model] } } + relayInfo.UpstreamModelName = audioRequest.Model adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { diff --git a/relay/relay-image.go b/relay/relay-image.go index d83ec26..6d6e4d4 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -180,7 +180,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) + return service.RelayErrorHandler(resp) } var textResponse dto.ImageResponse diff --git a/service/error.go b/service/error.go index 0f6d472..3410de8 100644 --- a/service/error.go +++ b/service/error.go @@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, Error: dto.OpenAIError{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), }, } responseBody, err := io.ReadAll(resp.Body)