♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
Buer
2024-01-19 02:47:10 +08:00
committed by GitHub
parent 0bfe1f5779
commit ef041e28a1
96 changed files with 4339 additions and 3276 deletions

View File

@@ -1,299 +0,0 @@
package common
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"one-api/types"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/proxy"
)
var clientPool = &sync.Pool{
New: func() interface{} {
return &http.Client{}
},
}
func GetHttpClient(proxyAddr string) *http.Client {
client := clientPool.Get().(*http.Client)
if RelayTimeout > 0 {
client.Timeout = time.Duration(RelayTimeout) * time.Second
}
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
SysError("Error parsing proxy address: " + err.Error())
return client
}
switch proxyURL.Scheme {
case "http", "https":
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
SysError("Error creating SOCKS5 dialer: " + err.Error())
return client
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
}
}
return client
}
func PutHttpClient(c *http.Client) {
clientPool.Put(c)
}
type Client struct {
requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder
}
func NewClient() *Client {
return &Client{
requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
},
}
}
type requestOptions struct {
body any
header http.Header
}
type requestOption func(*requestOptions)
type Stringer interface {
GetString() *string
}
func WithBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}
func WithHeader(header map[string]string) requestOption {
return func(args *requestOptions) {
for k, v := range header {
args.header.Set(k, v)
}
}
}
func WithContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}
type RequestError struct {
HTTPStatusCode int
Err error
}
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
// Default Options
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
if err != nil {
return nil, err
}
return req, nil
}
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
client := GetHttpClient(proxyAddr)
resp, err := client.Do(req)
if err != nil {
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
PutHttpClient(client)
if !outputResp {
defer resp.Body.Close()
}
// 处理响应
if IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp)
}
// 解析响应
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// 将响应体重新写入 resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = DecodeResponse(resp.Body, response)
}
if err != nil {
return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
if outputResp {
return resp, nil
}
return nil, nil
}
type GeneralErrorResponse struct {
Error types.OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
// 处理错误响应
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
// var errorResponse types.OpenAIErrorResponse
var errorResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errorResponse)
if err != nil {
return
}
if errorResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
client := GetHttpClient(proxyAddr)
resp, err := client.Do(req)
PutHttpClient(client)
if err != nil {
return
}
return resp.Body, nil
}
func IsFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}
func DecodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return DecodeString(body, result)
}
if stringer, ok := v.(Stringer); ok {
return DecodeString(body, stringer.GetString())
}
return json.NewDecoder(body).Decode(v)
}
func DecodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -37,6 +37,14 @@ func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWith
return StringErrorWrapper(err.Error(), code, statusCode)
}
func ErrorToOpenAIError(err error) *types.OpenAIError {
return &types.OpenAIError{
Code: "system error",
Message: err.Error(),
Type: "one_api_error",
}
}
func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
openAIError := types.OpenAIError{
Message: err,

View File

@@ -1,59 +0,0 @@
package common
// type Quota struct {
// ModelName string
// ModelRatio float64
// GroupRatio float64
// Ratio float64
// UserQuota int
// }
// func CreateQuota(modelName string, userQuota int, group string) *Quota {
// modelRatio := GetModelRatio(modelName)
// groupRatio := GetGroupRatio(group)
// return &Quota{
// ModelName: modelName,
// ModelRatio: modelRatio,
// GroupRatio: groupRatio,
// Ratio: modelRatio * groupRatio,
// UserQuota: userQuota,
// }
// }
// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
// if ApproximateTokenEnabled {
// return int(float64(len(text)) * 0.38)
// }
// return len(tokenEncoder.Encode(text, nil, nil))
// }
// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
// tokenEncoder := q.getTokenEncoder(model)
// // Reference:
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// // https://github.com/pkoukk/tiktoken-go/issues/6
// //
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
// var tokensPerMessage int
// var tokensPerName int
// if model == "gpt-3.5-turbo-0301" {
// tokensPerMessage = 4
// tokensPerName = -1 // If there's a name, the role is omitted
// } else {
// tokensPerMessage = 3
// tokensPerName = 1
// }
// tokenNum := 0
// for _, message := range messages {
// tokenNum += tokensPerMessage
// tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
// tokenNum += q.getTokenNum(tokenEncoder, message.Role)
// if message.Name != nil {
// tokenNum += tokensPerName
// tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
// }
// }
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// return tokenNum
// }

View File

@@ -1,4 +1,4 @@
package common
package requester
import (
"fmt"

View File

@@ -0,0 +1,68 @@
package requester
import (
"fmt"
"net/http"
"net/url"
"one-api/common"
"sync"
"time"
"golang.org/x/net/proxy"
)
type HTTPClient struct{}
var clientPool = &sync.Pool{
New: func() interface{} {
return &http.Client{}
},
}
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
client := clientPool.Get().(*http.Client)
if common.RelayTimeout > 0 {
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
}
if proxyAddr != "" {
err := h.setProxy(client, proxyAddr)
if err != nil {
common.SysError(err.Error())
return client
}
}
return client
}
func (h *HTTPClient) returnClientToPool(client *http.Client) {
clientPool.Put(client)
}
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
case "socks5":
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
client.Transport = &http.Transport{
Dial: dialer.Dial,
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
return nil
}

View File

@@ -0,0 +1,229 @@
package requester
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strconv"
"github.com/gin-gonic/gin"
)
type HttpErrorHandler func(*http.Response) *types.OpenAIError
type HTTPRequester struct {
HTTPClient HTTPClient
requestBuilder RequestBuilder
CreateFormBuilder func(io.Writer) FormBuilder
ErrorHandler HttpErrorHandler
proxyAddr string
}
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
// proxyAddr: 是代理服务器的地址。
// errorHandler: 是一个错误处理函数,它接收一个 *http.Response 参数并返回一个 *types.OpenAIErrorResponse。
// 如果 errorHandler 为 nil那么会使用一个默认的错误处理函数。
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
return &HTTPRequester{
HTTPClient: HTTPClient{},
requestBuilder: NewRequestBuilder(),
CreateFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
},
ErrorHandler: errorHandler,
proxyAddr: proxyAddr,
}
}
type requestOptions struct {
body any
header http.Header
}
type requestOption func(*requestOptions)
// 创建请求
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := r.requestBuilder.Build(method, url, args.body, args.header)
if err != nil {
return nil, err
}
return req, nil
}
// 发送请求
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
if !outputResp {
defer resp.Body.Close()
}
// 处理响应
if r.IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp, r.ErrorHandler)
}
// 解析响应
if outputResp {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
err = DecodeResponse(tee, response)
// 将响应体重新写入 resp.Body
resp.Body = io.NopCloser(&buf)
} else {
err = json.NewDecoder(resp.Body).Decode(response)
}
if err != nil {
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
}
return resp, nil
}
// 发送请求 RAW
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
// 发送请求
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
resp, err := client.Do(req)
r.HTTPClient.returnClientToPool(client)
if err != nil {
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
// 处理响应
if r.IsFailureStatusCode(resp) {
return nil, HandleErrorResp(resp, r.ErrorHandler)
}
return resp, nil
}
// 获取流式响应
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
// 如果返回的头是json格式 说明有错误
if resp.Header.Get("Content-Type") == "application/json" {
return nil, HandleErrorResp(resp, requester.ErrorHandler)
}
return &streamReader[T]{
reader: bufio.NewReader(resp.Body),
response: resp,
handlerPrefix: handlerPrefix,
}, nil
}
// 设置请求体
func (r *HTTPRequester) WithBody(body any) requestOption {
return func(args *requestOptions) {
args.body = body
}
}
// 设置请求头
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
return func(args *requestOptions) {
for k, v := range header {
args.header.Set(k, v)
}
}
}
// 设置Content-Type
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
}
}
// 判断是否为失败状态码
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}
// 处理错误响应
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: types.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
defer resp.Body.Close()
if toOpenAIError != nil {
errorResponse := toOpenAIError(resp)
if errorResponse != nil && errorResponse.Message != "" {
openAIErrorWithStatusCode.OpenAIError = *errorResponse
}
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return openAIErrorWithStatusCode
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
type Stringer interface {
GetString() *string
}
func DecodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
}
if result, ok := v.(*string); ok {
return DecodeString(body, result)
}
if stringer, ok := v.(Stringer); ok {
return DecodeString(body, stringer.GetString())
}
return json.NewDecoder(body).Decode(v)
}
func DecodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil
}

View File

@@ -0,0 +1,79 @@
package requester
import (
"bufio"
"bytes"
"io"
"net/http"
)
// 流处理函数,判断依据如下:
// 1.如果有错误信息,则直接返回错误信息
// 2.如果isFinished=true则返回io.EOF并且如果response不为空还将返回response
// 3.如果rawLine=nil 或者 response长度为0则直接跳过
// 4.如果以上条件都不满足则返回response
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
type streamable interface {
// types.ChatCompletionStreamResponse | types.CompletionResponse
any
}
type StreamReaderInterface[T streamable] interface {
Recv() (*[]T, error)
Close()
}
type streamReader[T streamable] struct {
isFinished bool
reader *bufio.Reader
response *http.Response
handlerPrefix HandlerPrefix[T]
}
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
}
//nolint:gocognit
func (stream *streamReader[T]) processLines() (*[]T, error) {
for {
rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil {
return nil, readErr
}
noSpaceLine := bytes.TrimSpace(rawLine)
var response []T
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if noSpaceLine == nil || len(response) == 0 {
continue
}
return &response, nil
}
}
func (stream *streamReader[T]) Close() {
stream.response.Body.Close()
}

View File

@@ -1,9 +1,10 @@
package common
package requester
import (
"bytes"
"io"
"net/http"
"one-api/common"
)
type RequestBuilder interface {
@@ -11,12 +12,12 @@ type RequestBuilder interface {
}
type HTTPRequestBuilder struct {
marshaller Marshaller
marshaller common.Marshaller
}
func NewRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &JSONMarshaller{},
marshaller: &common.JSONMarshaller{},
}
}

View File

@@ -0,0 +1,53 @@
package requester
import (
"fmt"
"net"
"net/http"
"net/url"
"one-api/common"
"time"
"github.com/gorilla/websocket"
"golang.org/x/net/proxy"
)
func GetWSClient(proxyAddr string) *websocket.Dialer {
dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
if proxyAddr != "" {
err := setWSProxy(dialer, proxyAddr)
if err != nil {
common.SysError(err.Error())
return dialer
}
}
return dialer
}
func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("error parsing proxy address: %w", err)
}
switch proxyURL.Scheme {
case "http", "https":
dialer.Proxy = http.ProxyURL(proxyURL)
case "socks5":
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating socks5 dialer: %w", err)
}
dialer.NetDial = func(network, addr string) (net.Conn, error) {
return socks5Proxy.Dial(network, addr)
}
default:
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}
return nil
}

View File

@@ -0,0 +1,58 @@
package requester
import (
"io"
"github.com/gorilla/websocket"
)
type wsReader[T streamable] struct {
isFinished bool
reader *websocket.Conn
handlerPrefix HandlerPrefix[T]
}
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
}
func (stream *wsReader[T]) processLines() (*[]T, error) {
for {
_, msg, err := stream.reader.ReadMessage()
if err != nil {
return nil, err
}
var response []T
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if msg == nil || len(response) == 0 {
continue
}
return &response, nil
}
}
func (stream *wsReader[T]) Close() {
stream.reader.Close()
}

View File

@@ -0,0 +1,54 @@
package requester
import (
"errors"
"net/http"
"one-api/common"
"one-api/types"
"github.com/gorilla/websocket"
)
type WSRequester struct {
WSClient *websocket.Dialer
}
func NewWSRequester(proxyAddr string) *WSRequester {
return &WSRequester{
WSClient: GetWSClient(proxyAddr),
}
}
func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
conn, resp, err := w.WSClient.Dial(url, header)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, errors.New("ws unexpected status code")
}
return conn, nil
}
func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
err := conn.WriteJSON(data)
if err != nil {
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
}
return &wsReader[T]{
reader: conn,
handlerPrefix: handlerPrefix,
}, nil
}
// 设置请求头
func (r *WSRequester) WithHeader(headers map[string]string) http.Header {
header := make(http.Header)
for k, v := range headers {
header.Set(k, v)
}
return header
}

55
common/test/api.go Normal file
View File

@@ -0,0 +1,55 @@
package test
import (
"io"
"net/http"
"net/http/httptest"
"one-api/model"
"github.com/gin-gonic/gin"
)
func RequestJSONConfig() map[string]string {
return map[string]string{
"Content-Type": "application/json",
"Accept": "application/json",
}
}
func GetContext(method, path string, headers map[string]string, body io.Reader) (*gin.Context, *httptest.ResponseRecorder) {
var req *http.Request
req, _ = http.NewRequest(method, path, body)
for k, v := range headers {
req.Header.Set(k, v)
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
return c, w
}
func GetGinRouter(method, path string, headers map[string]string, body *io.Reader) *httptest.ResponseRecorder {
var req *http.Request
r := gin.Default()
w := httptest.NewRecorder()
req, _ = http.NewRequest(method, path, *body)
for k, v := range headers {
req.Header.Set(k, v)
}
r.ServeHTTP(w, req)
return w
}
func GetChannel(channelType int, baseUrl, other, porxy, modelMapping string) model.Channel {
return model.Channel{
Type: channelType,
BaseURL: &baseUrl,
Other: other,
Proxy: porxy,
ModelMapping: &modelMapping,
Key: GetTestToken(),
}
}

132
common/test/chat_config.go Normal file
View File

@@ -0,0 +1,132 @@
package test
import (
"encoding/json"
"one-api/types"
"strings"
)
func GetChatCompletionRequest(chatType, modelName, stream string) *types.ChatCompletionRequest {
chatJSON := GetChatRequest(chatType, modelName, stream)
chatCompletionRequest := &types.ChatCompletionRequest{}
json.NewDecoder(chatJSON).Decode(chatCompletionRequest)
return chatCompletionRequest
}
func GetChatRequest(chatType, modelName, stream string) *strings.Reader {
var chatJSON string
switch chatType {
case "image":
chatJSON = `{
"model": "` + modelName + `",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
"max_tokens": 300,
"stream": ` + stream + `
}`
case "default":
chatJSON = `{
"model": "` + modelName + `",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello!"
}
],
"stream": ` + stream + `
}`
case "function":
chatJSON = `{
"model": "` + modelName + `",
"stream": ` + stream + `,
"messages": [
{
"role": "user",
"content": "What is the weather like in Boston?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
],
"tool_choice": "auto"
}`
case "tools":
chatJSON = `{
"model": "` + modelName + `",
"stream": ` + stream + `,
"messages": [
{
"role": "user",
"content": "What is the weather like in Boston?"
}
],
"functions": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
]
}
},
"required": [
"location"
]
}
}
]
}`
}
return strings.NewReader(chatJSON)
}

65
common/test/check_chat.go Normal file
View File

@@ -0,0 +1,65 @@
package test
import (
"one-api/types"
"testing"
"github.com/stretchr/testify/assert"
)
func CheckChat(t *testing.T, response *types.ChatCompletionResponse, modelName string, usage *types.Usage) {
assert.NotEmpty(t, response.ID)
assert.NotEmpty(t, response.Object)
assert.NotEmpty(t, response.Created)
assert.Equal(t, response.Model, modelName)
assert.IsType(t, []types.ChatCompletionChoice{}, response.Choices)
// check choices 长度大于1
assert.True(t, len(response.Choices) > 0)
for _, choice := range response.Choices {
assert.NotNil(t, choice.Index)
assert.IsType(t, types.ChatCompletionMessage{}, choice.Message)
assert.NotEmpty(t, choice.Message.Role)
assert.NotEmpty(t, choice.FinishReason)
// check message
if choice.Message.Content != nil {
multiContents, ok := choice.Message.Content.([]types.ChatMessagePart)
if ok {
for _, content := range multiContents {
assert.NotEmpty(t, content.Type)
if content.Type == "text" {
assert.NotEmpty(t, content.Text)
} else if content.Type == "image_url" {
assert.IsType(t, types.ChatMessageImageURL{}, content.ImageURL)
}
}
} else {
content, ok := choice.Message.Content.(string)
assert.True(t, ok)
assert.NotEmpty(t, content)
}
} else if choice.Message.FunctionCall != nil {
assert.NotEmpty(t, choice.Message.FunctionCall.Name)
assert.Equal(t, choice.FinishReason, types.FinishReasonFunctionCall)
} else if choice.Message.ToolCalls != nil {
assert.IsType(t, []types.ChatCompletionToolCalls{}, choice.Message.ToolCalls)
assert.NotEmpty(t, choice.Message.ToolCalls[0].Id)
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function)
assert.Equal(t, choice.Message.ToolCalls[0].Function, "function")
assert.IsType(t, types.ChatCompletionToolCallsFunction{}, choice.Message.ToolCalls[0].Function)
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function.Name)
assert.Equal(t, choice.FinishReason, types.FinishReasonToolCalls)
} else {
assert.Fail(t, "message content is nil")
}
}
// check usage
assert.IsType(t, &types.Usage{}, response.Usage)
assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens)
assert.Equal(t, response.Usage.CompletionTokens, usage.CompletionTokens)
assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens)
}

48
common/test/checks.go Normal file
View File

@@ -0,0 +1,48 @@
package test
import (
"errors"
"testing"
)
func NoError(t *testing.T, err error, message ...string) {
t.Helper()
if err != nil {
t.Error(err, message)
}
}
func HasError(t *testing.T, err error, message ...string) {
t.Helper()
if err == nil {
t.Error(err, message)
}
}
func ErrorIs(t *testing.T, err, target error, msg ...string) {
t.Helper()
if !errors.Is(err, target) {
t.Fatal(msg)
}
}
func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
t.Helper()
if !errors.Is(err, target) {
t.Fatalf(format, msg)
}
}
func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
t.Helper()
if errors.Is(err, target) {
t.Fatal(msg)
}
}
func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
t.Helper()
if errors.Is(err, target) {
t.Fatalf(format, msg)
}
}

7
common/test/init/init.go Normal file
View File

@@ -0,0 +1,7 @@
package init
import "testing"
func init() {
testing.Init()
}

63
common/test/server.go Normal file
View File

@@ -0,0 +1,63 @@
package test
import (
"log"
"net/http"
"net/http/httptest"
"regexp"
"strings"
)
const testAPI = "this-is-my-secure-token-do-not-steal!!"
func GetTestToken() string {
return testAPI
}
type ServerTest struct {
handlers map[string]handler
}
type handler func(w http.ResponseWriter, r *http.Request)
func NewTestServer() *ServerTest {
return &ServerTest{handlers: make(map[string]handler)}
}
func OpenAICheck(w http.ResponseWriter, r *http.Request) bool {
if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() {
w.WriteHeader(http.StatusUnauthorized)
return false
}
return true
}
func (ts *ServerTest) RegisterHandler(path string, handler handler) {
// to make the registered paths friendlier to a regex match in the route handler
// in OpenAITestServer
path = strings.ReplaceAll(path, "*", ".*")
ts.handlers[path] = handler
}
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
func (ts *ServerTest) TestServer(headerCheck func(w http.ResponseWriter, r *http.Request) bool) *httptest.Server {
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path)
// check auth
if headerCheck != nil && !headerCheck(w, r) {
return
}
// Handle /path/* routes.
// Note: the * is converted to a .* in register handler for proper regex handling
for route, handler := range ts.handlers {
// Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered
pattern, _ := regexp.Compile("^" + route + "$")
if pattern.MatchString(r.URL.Path) {
handler(w, r)
return
}
}
http.Error(w, "the resource path doesn't exist", http.StatusNotFound)
}))
}