mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-30 07:06:38 +08:00
* 💄 improve: http client changes proxy using context * 🐛 fix: glm-4v support base64 image
247 lines
6.2 KiB
Go
247 lines
6.2 KiB
Go
package requester
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/types"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
type HttpErrorHandler func(*http.Response) *types.OpenAIError
|
||
|
||
type HTTPRequester struct {
|
||
requestBuilder RequestBuilder
|
||
CreateFormBuilder func(io.Writer) FormBuilder
|
||
ErrorHandler HttpErrorHandler
|
||
proxyAddr string
|
||
}
|
||
|
||
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
|
||
// proxyAddr: 是代理服务器的地址。
|
||
// errorHandler: 是一个错误处理函数,它接收一个 *http.Response 参数并返回一个 *types.OpenAIErrorResponse。
|
||
// 如果 errorHandler 为 nil,那么会使用一个默认的错误处理函数。
|
||
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
|
||
return &HTTPRequester{
|
||
requestBuilder: NewRequestBuilder(),
|
||
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
||
return NewFormBuilder(body)
|
||
},
|
||
ErrorHandler: errorHandler,
|
||
proxyAddr: proxyAddr,
|
||
}
|
||
}
|
||
|
||
type requestOptions struct {
|
||
body any
|
||
header http.Header
|
||
}
|
||
|
||
type requestOption func(*requestOptions)
|
||
|
||
func (r *HTTPRequester) getContext() context.Context {
|
||
if r.proxyAddr == "" {
|
||
return context.Background()
|
||
}
|
||
|
||
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
||
if strings.HasPrefix(r.proxyAddr, "socks5://") {
|
||
return context.WithValue(context.Background(), ProxySock5AddrKey, r.proxyAddr)
|
||
}
|
||
|
||
// 否则使用 http 代理
|
||
return context.WithValue(context.Background(), ProxyHTTPAddrKey, r.proxyAddr)
|
||
|
||
}
|
||
|
||
// 创建请求
|
||
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||
args := &requestOptions{
|
||
body: nil,
|
||
header: make(http.Header),
|
||
}
|
||
for _, setter := range setters {
|
||
setter(args)
|
||
}
|
||
req, err := r.requestBuilder.Build(r.getContext(), method, url, args.body, args.header)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return req, nil
|
||
}
|
||
|
||
// 发送请求
|
||
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||
resp, err := HTTPClient.Do(req)
|
||
if err != nil {
|
||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||
}
|
||
|
||
if !outputResp {
|
||
defer resp.Body.Close()
|
||
}
|
||
|
||
// 处理响应
|
||
if r.IsFailureStatusCode(resp) {
|
||
return nil, HandleErrorResp(resp, r.ErrorHandler)
|
||
}
|
||
|
||
// 解析响应
|
||
if outputResp {
|
||
var buf bytes.Buffer
|
||
tee := io.TeeReader(resp.Body, &buf)
|
||
err = DecodeResponse(tee, response)
|
||
|
||
// 将响应体重新写入 resp.Body
|
||
resp.Body = io.NopCloser(&buf)
|
||
} else {
|
||
err = json.NewDecoder(resp.Body).Decode(response)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||
}
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// 发送请求 RAW
|
||
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||
// 发送请求
|
||
resp, err := HTTPClient.Do(req)
|
||
if err != nil {
|
||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||
}
|
||
|
||
// 处理响应
|
||
if r.IsFailureStatusCode(resp) {
|
||
return nil, HandleErrorResp(resp, r.ErrorHandler)
|
||
}
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// 获取流式响应
|
||
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
|
||
// 如果返回的头是json格式 说明有错误
|
||
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
|
||
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
||
}
|
||
|
||
stream := &streamReader[T]{
|
||
reader: bufio.NewReader(resp.Body),
|
||
response: resp,
|
||
handlerPrefix: handlerPrefix,
|
||
|
||
DataChan: make(chan T),
|
||
ErrChan: make(chan error),
|
||
}
|
||
|
||
return stream, nil
|
||
}
|
||
|
||
// 设置请求体
|
||
func (r *HTTPRequester) WithBody(body any) requestOption {
|
||
return func(args *requestOptions) {
|
||
args.body = body
|
||
}
|
||
}
|
||
|
||
// 设置请求头
|
||
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
|
||
return func(args *requestOptions) {
|
||
for k, v := range header {
|
||
args.header.Set(k, v)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 设置Content-Type
|
||
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
|
||
return func(args *requestOptions) {
|
||
args.header.Set("Content-Type", contentType)
|
||
}
|
||
}
|
||
|
||
// 判断是否为失败状态码
|
||
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
|
||
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
||
}
|
||
|
||
// 处理错误响应
|
||
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
|
||
|
||
openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
|
||
StatusCode: resp.StatusCode,
|
||
OpenAIError: types.OpenAIError{
|
||
Message: "",
|
||
Type: "upstream_error",
|
||
Code: "bad_response_status_code",
|
||
Param: strconv.Itoa(resp.StatusCode),
|
||
},
|
||
}
|
||
|
||
defer resp.Body.Close()
|
||
|
||
if toOpenAIError != nil {
|
||
errorResponse := toOpenAIError(resp)
|
||
|
||
if errorResponse != nil && errorResponse.Message != "" {
|
||
openAIErrorWithStatusCode.OpenAIError = *errorResponse
|
||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: %s", openAIErrorWithStatusCode.OpenAIError.Message)
|
||
}
|
||
}
|
||
|
||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: bad response status code %d", resp.StatusCode)
|
||
}
|
||
|
||
return openAIErrorWithStatusCode
|
||
}
|
||
|
||
func SetEventStreamHeaders(c *gin.Context) {
|
||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||
c.Writer.Header().Set("Connection", "keep-alive")
|
||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||
}
|
||
|
||
// Stringer is implemented by response types that want the raw response body
// as text instead of JSON decoding. GetString returns the string the body is
// stored into.
type Stringer interface {
	GetString() *string
}

// DecodeResponse decodes body into v. The cases are checked in this order:
//   - v == nil: nothing is read and nil is returned.
//   - v is a *string: the whole body is stored verbatim in *v.
//   - v implements Stringer: the whole body is stored in *v.GetString().
//   - otherwise: body is decoded as JSON into v.
func DecodeResponse(body io.Reader, v any) error {
	if v == nil {
		return nil
	}

	if result, ok := v.(*string); ok {
		return DecodeString(body, result)
	}

	if stringer, ok := v.(Stringer); ok {
		return DecodeString(body, stringer.GetString())
	}

	return json.NewDecoder(body).Decode(v)
}

// DecodeString reads all of body and stores it in *output.
//
// If output is nil it returns an error without reading body. This covers a
// nil *string, or a Stringer whose GetString returns nil, and replaces a
// nil-pointer panic. On a read error *output is left unchanged.
func DecodeString(body io.Reader, output *string) error {
	if output == nil {
		return fmt.Errorf("decode string: nil output pointer")
	}

	b, err := io.ReadAll(body)
	if err != nil {
		return err
	}

	*output = string(b)
	return nil
}
|