♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
Buer
2024-01-19 02:47:10 +08:00
committed by GitHub
parent 0bfe1f5779
commit ef041e28a1
96 changed files with 4339 additions and 3276 deletions

View File

@@ -0,0 +1,71 @@
package requester
import (
"fmt"
"io"
"mime/multipart"
"path"
)
type FormBuilder interface {
CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
}
type DefaultFormBuilder struct {
writer *multipart.Writer
}
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
return &DefaultFormBuilder{
writer: multipart.NewWriter(body),
}
}
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error {
file, err := fileHeader.Open()
if err != nil {
return err
}
defer file.Close()
return fb.createFormFile(fieldname, file, fileHeader.Filename)
}
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
}
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
if filename == "" {
return fmt.Errorf("filename cannot be empty")
}
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}
return nil
}
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}
func (fb *DefaultFormBuilder) Close() error {
return fb.writer.Close()
}
func (fb *DefaultFormBuilder) FormDataContentType() string {
return fb.writer.FormDataContentType()
}

View File

@@ -0,0 +1,68 @@
package requester
import (
"fmt"
"net/http"
"net/url"
"one-api/common"
"sync"
"time"
"golang.org/x/net/proxy"
)
type HTTPClient struct{}
var clientPool = &sync.Pool{
New: func() interface{} {
return &http.Client{}
},
}
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
client := clientPool.Get().(*http.Client)
if common.RelayTimeout > 0 {
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
}
if proxyAddr != "" {
err := h.setProxy(client, proxyAddr)
if err != nil {
common.SysError(err.Error())
return client
}
}
return client
}
func (h *HTTPClient) returnClientToPool(client *http.Client) {
clientPool.Put(client)
}
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
return nil
}

View File

@@ -0,0 +1,229 @@
package requester
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"github.com/gin-gonic/gin"
)
type HttpErrorHandler func(*http.Response) *types.OpenAIError
type HTTPRequester struct {
HTTPClient HTTPClient
requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder
ErrorHandler HttpErrorHandler
proxyAddr string
}
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
// proxyAddr: 是代理服务器的地址。
// errorHandler: 是一个错误处理函数,它接收一个 *http.Response 参数并返回一个 *types.OpenAIErrorResponse。
// 如果 errorHandler 为 nil那么会使用一个默认的错误处理函数。
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
return &HTTPRequester{
HTTPClient: HTTPClient{},
requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
},
ErrorHandler: errorHandler,
proxyAddr: proxyAddr,
}
}
type requestOptions struct {
body any
header http.Header
}
type requestOption func(*requestOptions)
// 创建请求
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := r.requestBuilder.Build(method, url, args.body, args.header)
if err != nil {
return nil, err
}
return req, nil
}
// 发送请求
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if !outputResp {
defer resp.Body.Close()
}
// 处理响应
if r.IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp, r.ErrorHandler)
}
// 解析响应
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// 将响应体重新写入 resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = json.NewDecoder(resp.Body).Decode(response)
}
if err != nil {
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
return resp, nil
}
// 发送请求 RAW
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
// 处理响应
if r.IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp, r.ErrorHandler)
}
return resp, nil
}
// 获取流式响应
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
// 如果返回的头是json格式 说明有错误
if resp.Header.Get("Content-Type") == "application/json" {
return nil, HandleErrorResp(resp, requester.ErrorHandler)
}
return &streamReader[T]{
reader: bufio.NewReader(resp.Body),
response: resp,
handlerPrefix: handlerPrefix,
}, nil
}
// 设置请求体
func (r *HTTPRequester) WithBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}
// 设置请求头
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
return func(args *requestOptions) {
for k, v := range header {
args.header.Set(k, v)
}
}
}
// 设置Content-Type
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}
// 判断是否为失败状态码
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}
// 处理错误响应
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
defer resp.Body.Close()
if toOpenAIError != nil {
errorResponse := toOpenAIError(resp)
if errorResponse != nil && errorResponse.Message != "" {
openAIErrorWithStatusCode.OpenAIError = *errorResponse
}
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return openAIErrorWithStatusCode
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
type Stringer interface {
GetString() *string
}
func DecodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return DecodeString(body, result)
}
if stringer, ok := v.(Stringer); ok {
return DecodeString(body, stringer.GetString())
}
return json.NewDecoder(body).Decode(v)
}
func DecodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil
}

View File

@@ -0,0 +1,79 @@
package requester
import (
"bufio"
"bytes"
"io"
"net/http"
)
// 流处理函数,判断依据如下:
// 1.如果有错误信息,则直接返回错误信息
// 2.如果isFinished=true则返回io.EOF并且如果response不为空还将返回response
// 3.如果rawLine=nil 或者 response长度为0则直接跳过
// 4.如果以上条件都不满足则返回response
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
type streamable interface {
// types.ChatCompletionStreamResponse | types.CompletionResponse
any
}
type StreamReaderInterface[T streamable] interface {
Recv() (*[]T, error)
Close()
}
type streamReader[T streamable] struct {
isFinished bool
reader *bufio.Reader
response *http.Response
handlerPrefix HandlerPrefix[T]
}
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
}
//nolint:gocognit
func (stream *streamReader[T]) processLines() (*[]T, error) {
for {
rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil {
return nil, readErr
}
noSpaceLine := bytes.TrimSpace(rawLine)
var response []T
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if noSpaceLine == nil || len(response) == 0 {
continue
}
return &response, nil
}
}
func (stream *streamReader[T]) Close() {
stream.response.Body.Close()
}

View File

@@ -0,0 +1,51 @@
package requester
import (
"bytes"
"io"
"net/http"
"one-api/common"
)
type RequestBuilder interface {
Build(method, url string, body any, header http.Header) (*http.Request, error)
}
type HTTPRequestBuilder struct {
marshaller common.Marshaller
}
func NewRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &common.JSONMarshaller{},
}
}
func (b *HTTPRequestBuilder) Build(
method string,
url string,
body any,
header http.Header,
) (req *http.Request, err error) {
var bodyReader io.Reader
if body != nil {
if v, ok := body.(io.Reader); ok {
bodyReader = v
} else {
var reqBytes []byte
reqBytes, err = b.marshaller.Marshal(body)
if err != nil {
return
}
bodyReader = bytes.NewBuffer(reqBytes)
}
}
req, err = http.NewRequest(method, url, bodyReader)
if err != nil {
return
}
if header != nil {
req.Header = header
}
return
}

View File

@@ -0,0 +1,53 @@
package requester
import (
"fmt"
"net"
"net/http"
"net/url"
"one-api/common"
"time"
"github.com/gorilla/websocket"
"golang.org/x/net/proxy"
)
func GetWSClient(proxyAddr string) *websocket.Dialer {
dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
if proxyAddr != "" {
err := setWSProxy(dialer, proxyAddr)
if err != nil {
common.SysError(err.Error())
return dialer
}
}
return dialer
}
func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
dialer.Proxy = http.ProxyURL(proxyURL)
case "socks5":
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
dialer.NetDial = func(network, addr string) (net.Conn, error) {
return socks5Proxy.Dial(network, addr)
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
return nil
}

View File

@@ -0,0 +1,58 @@
package requester
import (
"io"
"github.com/gorilla/websocket"
)
type wsReader[T streamable] struct {
isFinished bool
reader *websocket.Conn
handlerPrefix HandlerPrefix[T]
}
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
}
func (stream *wsReader[T]) processLines() (*[]T, error) {
for {
_, msg, err := stream.reader.ReadMessage()
if err != nil {
return nil, err
}
var response []T
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if msg == nil || len(response) == 0 {
continue
}
return &response, nil
}
}
func (stream *wsReader[T]) Close() {
stream.reader.Close()
}

View File

@@ -0,0 +1,54 @@
package requester
import (
"errors"
"net/http"
"one-api/common"
"one-api/types"
"github.com/gorilla/websocket"
)
type WSRequester struct {
WSClient *websocket.Dialer
}
func NewWSRequester(proxyAddr string) *WSRequester {
return &WSRequester{
WSClient: GetWSClient(proxyAddr),
}
}
func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
conn, resp, err := w.WSClient.Dial(url, header)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, errors.New("ws unexpected status code")
}
return conn, nil
}
func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
err := conn.WriteJSON(data)
if err != nil {
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
}
return &wsReader[T]{
reader: conn,
handlerPrefix: handlerPrefix,
}, nil
}
// 设置请求头
func (r *WSRequester) WithHeader(headers map[string]string) http.Header {
header := make(http.Header)
for k, v := range headers {
header.Set(k, v)
}
return header
}