♻️ refactor: 重构http请求函数

This commit is contained in:
Martial BE
2023-11-30 13:49:35 +08:00
parent 96dc7614e6
commit 7c6dee7390
17 changed files with 116 additions and 143 deletions

View File

@@ -7,7 +7,6 @@ import (
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"strings"
"github.com/gin-gonic/gin"
@@ -54,42 +53,42 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
}
// 发送请求
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
// 发送请求
resp, err := common.HttpClient.Do(req)
if err != nil {
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true)
if openAIErrorWithStatusCode != nil {
return
}
defer resp.Body.Close()
// 处理响应
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp)
}
// 解析响应
err = common.DecodeResponse(resp.Body, response)
if err != nil {
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if rawOutput {
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
if err != nil {
return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(p.Context.Writer, resp.Body)
if err != nil {
return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
} else {
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if err != nil {
return types.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
}
return nil
@@ -107,7 +106,7 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
// 处理响应
if common.IsFailureStatusCode(resp) {
return p.HandleErrorResp(resp)
return common.HandleErrorResp(resp)
}
for k, v := range resp.Header {
@@ -124,38 +123,6 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
return nil
}
// 处理错误响应
func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var errorResponse types.OpenAIErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Type != "" {
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody)
}
return
}
func (p *BaseProvider) SupportAPI(relayMode int) bool {
switch relayMode {
case common.RelayModeChatCompletions: