one-api/common/requester/http_requester.go
Buer c1fc32add7
feat: add Midjourney (#138)
* 🚧 stash

*  feat: add Midjourney

* 📝 doc: update readme
2024-04-05 04:03:46 +08:00

249 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package requester
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// HttpErrorHandler converts an upstream error response into an OpenAI-style
// error. HandleErrorResp only uses the returned value when it is non-nil and
// has a non-empty Message; otherwise it builds a generic upstream error.
type HttpErrorHandler func(*http.Response) *types.OpenAIError
// HTTPRequester builds HTTP requests, sends them through the shared
// HTTPClient and converts failed responses into OpenAI-style errors.
type HTTPRequester struct {
	// requestBuilder turns method, URL, body and headers into an *http.Request
	// (used by NewRequest).
	requestBuilder RequestBuilder
	// CreateFormBuilder returns a FormBuilder writing into the given writer;
	// presumably used for multipart request bodies — confirm at call sites.
	CreateFormBuilder func(io.Writer) FormBuilder
	// ErrorHandler, when non-nil, is passed to HandleErrorResp to decode
	// provider-specific error bodies.
	ErrorHandler HttpErrorHandler
	// proxyAddr is an optional proxy address; setProxy treats "socks5://"
	// addresses as SOCKS5 and anything else as an HTTP proxy.
	proxyAddr string
	// Context is the base context from which every request's context is derived.
	Context context.Context
}
// NewHTTPRequester creates an HTTPRequester with a fresh RequestBuilder, a
// FormBuilder factory backed by NewFormBuilder, and context.Background() as
// its base context.
//
// proxyAddr is the proxy server address. When empty, no proxy address is
// attached to request contexts. An address starting with "socks5://" is
// treated as a SOCKS5 proxy, anything else as an HTTP proxy.
//
// errorHandler converts a failed *http.Response into a *types.OpenAIError.
// It may be nil: no default handler is substituted. In that case
// HandleErrorResp falls back to a generic "bad response status code" error.
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
	return &HTTPRequester{
		requestBuilder: NewRequestBuilder(),
		CreateFormBuilder: func(body io.Writer) FormBuilder {
			return NewFormBuilder(body)
		},
		ErrorHandler: errorHandler,
		proxyAddr:    proxyAddr,
		Context:      context.Background(),
	}
}
// requestOptions collects the optional parts of a request assembled by
// NewRequest; requestOption setters fill it in before the request is built.
type requestOptions struct {
	body   any         // request body handed to the RequestBuilder unchanged; nil by default
	header http.Header // request headers; NewRequest starts with an empty, non-nil map
}
// requestOption mutates a requestOptions value; see WithBody, WithHeader and
// WithContentType.
type requestOption func(*requestOptions)
// setProxy returns the context used to build a request.
//
// With no proxy configured, r.Context is returned as-is. An address starting
// with "socks5://" is stored under ProxySock5AddrKey. Any other address is
// stored under ProxyHTTPAddrKey.
func (r *HTTPRequester) setProxy() context.Context {
	addr := r.proxyAddr
	switch {
	case addr == "":
		return r.Context
	case strings.HasPrefix(addr, "socks5://"):
		return context.WithValue(r.Context, ProxySock5AddrKey, addr)
	default:
		return context.WithValue(r.Context, ProxyHTTPAddrKey, addr)
	}
}
// NewRequest builds an *http.Request for method and url.
//
// The setters (WithBody, WithHeader, WithContentType, ...) are applied in
// order, starting from a nil body and an empty header. The result is passed
// to the RequestBuilder together with the proxy-aware context from setProxy.
// On failure the builder's error is returned with a nil request.
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
	opts := requestOptions{header: http.Header{}}
	for i := range setters {
		setters[i](&opts)
	}

	ctx := r.setProxy()
	req, err := r.requestBuilder.Build(ctx, method, url, opts.body, opts.header)
	if err != nil {
		return nil, err
	}
	return req, nil
}
// SendRequest executes req with the shared HTTPClient and decodes a
// successful response body into response.
//
// Statuses rejected by IsFailureStatusCode are converted by HandleErrorResp
// and returned with a nil *http.Response. Transport errors map to
// "http_request_failed" and read/decode errors to "decode_response_failed",
// both with status 500.
//
// When outputResp is false, the body is JSON-decoded into response and then
// closed.
//
// When outputResp is true:
//   - the whole body is read and the original body is closed;
//   - the bytes are decoded with DecodeResponse, so *string and Stringer
//     targets receive the raw text;
//   - resp.Body is replaced by an in-memory copy of the complete body, so the
//     caller can read it again.
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
	resp, err := HTTPClient.Do(req)
	if err != nil {
		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
	}

	if r.IsFailureStatusCode(resp) {
		// HandleErrorResp closes resp.Body.
		return nil, HandleErrorResp(resp, r.ErrorHandler)
	}

	if !outputResp {
		defer resp.Body.Close()
		if err := json.NewDecoder(resp.Body).Decode(response); err != nil {
			return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
		}
		return resp, nil
	}

	// Read the entire body up front. A TeeReader would only capture the bytes
	// the decoder happened to consume. Swapping resp.Body for a NopCloser
	// without closing the original would mean the real body is never closed.
	body, err := io.ReadAll(resp.Body)
	// The close error carries no useful information once the body has been read.
	resp.Body.Close()
	if err != nil {
		return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
	}
	resp.Body = io.NopCloser(bytes.NewReader(body))

	if err := DecodeResponse(bytes.NewReader(body), response); err != nil {
		return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
	}
	return resp, nil
}
// SendRequestRaw executes req with the shared HTTPClient and returns the
// undecoded response; the caller owns resp.Body and must close it.
//
// Transport errors map to "http_request_failed" (500). Statuses rejected by
// IsFailureStatusCode are converted by HandleErrorResp, which closes the
// body, and a nil response is returned.
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
	resp, err := HTTPClient.Do(req)
	switch {
	case err != nil:
		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
	case r.IsFailureStatusCode(resp):
		return nil, HandleErrorResp(resp, r.ErrorHandler)
	default:
		return resp, nil
	}
}
// RequestStream wraps a streaming response in a streamReader.
//
// A response whose Content-Type contains "application/json" is treated as an
// upstream error rather than a stream. It is converted with HandleErrorResp,
// which closes the body. Otherwise the reader buffers resp.Body and is given
// fresh unbuffered data and error channels.
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
	contentType := resp.Header.Get("Content-Type")
	if strings.Contains(contentType, "application/json") {
		return nil, HandleErrorResp(resp, requester.ErrorHandler)
	}

	reader := &streamReader[T]{
		reader:        bufio.NewReader(resp.Body),
		response:      resp,
		handlerPrefix: handlerPrefix,
		DataChan:      make(chan T),
		ErrChan:       make(chan error),
	}
	return reader, nil
}
// WithBody returns a requestOption that sets the request body to body; the
// value is later handed to the RequestBuilder unchanged.
func (r *HTTPRequester) WithBody(body any) requestOption {
	setBody := func(opts *requestOptions) {
		opts.body = body
	}
	return setBody
}
// WithHeader returns a requestOption that sets every entry of header on the
// request. It uses http.Header.Set, so each key replaces any existing value.
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
	applyHeaders := func(opts *requestOptions) {
		for name, value := range header {
			opts.header.Set(name, value)
		}
	}
	return applyHeaders
}
// WithContentType returns a requestOption that sets the request's
// Content-Type header to contentType, replacing any existing value.
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
	return r.WithHeader(map[string]string{"Content-Type": contentType})
}
// IsFailureStatusCode reports whether resp's status lies outside the
// 200–399 range (so 1xx, 4xx, 5xx and anything higher count as failures).
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
	code := resp.StatusCode
	success := code >= http.StatusOK && code < http.StatusBadRequest
	return !success
}
// HandleErrorResp converts a failed upstream response into a
// *types.OpenAIErrorWithStatusCode and closes resp.Body.
//
// The result always carries resp.StatusCode. If toOpenAIError is non-nil and
// returns an error with a non-empty Message, that error is used, with its
// Message prefixed by "Provider API error: ".
//
// Otherwise a generic error is built:
//   - Type is "upstream_error" and Code is "bad_response_status_code";
//   - Param is the status code;
//   - Message reports the status code.
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
	defer resp.Body.Close()

	status := resp.StatusCode
	result := &types.OpenAIErrorWithStatusCode{StatusCode: status}

	var upstream *types.OpenAIError
	if toOpenAIError != nil {
		upstream = toOpenAIError(resp)
	}

	if upstream != nil && upstream.Message != "" {
		// Copy so the handler's value is not mutated.
		result.OpenAIError = *upstream
		result.OpenAIError.Message = "Provider API error: " + upstream.Message
		return result
	}

	result.OpenAIError = types.OpenAIError{
		Message: fmt.Sprintf("Provider API error: bad response status code %d", status),
		Type:    "upstream_error",
		Code:    "bad_response_status_code",
		Param:   strconv.Itoa(status),
	}
	return result
}
// SetEventStreamHeaders sets the response headers for a Server-Sent Events
// (text/event-stream) reply: no caching, a kept-alive chunked connection, and
// "X-Accel-Buffering: no" (presumably so nginx-style reverse proxies do not
// buffer the stream — confirm against the deployment).
func SetEventStreamHeaders(c *gin.Context) {
	sseHeaders := [...][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"Transfer-Encoding", "chunked"},
		{"X-Accel-Buffering", "no"},
	}
	h := c.Writer.Header()
	for _, kv := range sseHeaders {
		h.Set(kv[0], kv[1])
	}
}
// Stringer is implemented by response targets that want the raw response
// body as text instead of JSON decoding. GetString must return a non-nil
// pointer to the string that DecodeResponse should fill.
type Stringer interface {
	GetString() *string
}

// DecodeResponse decodes body into v according to v's type:
//   - nil: body is ignored and nil is returned.
//   - *string: the whole body is stored verbatim in the string.
//   - Stringer: the whole body is stored verbatim in *v.GetString().
//   - anything else: body is JSON-decoded into v.
//
// A nil *string, or a Stringer whose GetString returns nil, yields an error
// instead of a nil-pointer panic.
func DecodeResponse(body io.Reader, v any) error {
	switch target := v.(type) {
	case nil:
		return nil
	case *string:
		return DecodeString(body, target)
	case Stringer:
		return DecodeString(body, target.GetString())
	default:
		return json.NewDecoder(body).Decode(v)
	}
}

// DecodeString reads all of body and stores it in *output.
//
// It returns an error without reading body if output is nil. Read errors are
// returned as-is, and *output is then left unchanged.
func DecodeString(body io.Reader, output *string) error {
	if output == nil {
		return fmt.Errorf("decode string: nil output pointer")
	}
	b, err := io.ReadAll(body)
	if err != nil {
		return err
	}
	*output = string(b)
	return nil
}