mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 13:13:41 +08:00
♻️ refactor: provider refactor (#41)
* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
299
common/client.go
299
common/client.go
@@ -1,299 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
var clientPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &http.Client{}
|
||||
},
|
||||
}
|
||||
|
||||
func GetHttpClient(proxyAddr string) *http.Client {
|
||||
client := clientPool.Get().(*http.Client)
|
||||
|
||||
if RelayTimeout > 0 {
|
||||
client.Timeout = time.Duration(RelayTimeout) * time.Second
|
||||
}
|
||||
|
||||
if proxyAddr != "" {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
SysError("Error parsing proxy address: " + err.Error())
|
||||
return client
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
case "socks5":
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
SysError("Error creating SOCKS5 dialer: " + err.Error())
|
||||
return client
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
default:
|
||||
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
}
|
||||
|
||||
func PutHttpClient(c *http.Client) {
|
||||
clientPool.Put(c)
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
requestBuilder RequestBuilder
|
||||
CreateFormBuilder func(io.Writer) FormBuilder
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
requestBuilder: NewRequestBuilder(),
|
||||
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
||||
return NewFormBuilder(body)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type requestOptions struct {
|
||||
body any
|
||||
header http.Header
|
||||
}
|
||||
|
||||
type requestOption func(*requestOptions)
|
||||
|
||||
type Stringer interface {
|
||||
GetString() *string
|
||||
}
|
||||
|
||||
func WithBody(body any) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.body = body
|
||||
}
|
||||
}
|
||||
|
||||
func WithHeader(header map[string]string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
for k, v := range header {
|
||||
args.header.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithContentType(contentType string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.header.Set("Content-Type", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
type RequestError struct {
|
||||
HTTPStatusCode int
|
||||
Err error
|
||||
}
|
||||
|
||||
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||||
// Default Options
|
||||
args := &requestOptions{
|
||||
body: nil,
|
||||
header: make(http.Header),
|
||||
}
|
||||
for _, setter := range setters {
|
||||
setter(args)
|
||||
}
|
||||
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
client := GetHttpClient(proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
PutHttpClient(client)
|
||||
|
||||
if !outputResp {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
if IsFailureStatusCode(resp) {
|
||||
return nil, HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
if outputResp {
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
err = DecodeResponse(tee, response)
|
||||
|
||||
// 将响应体重新写入 resp.Body
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
} else {
|
||||
err = DecodeResponse(resp.Body, response)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if outputResp {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error types.OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
Response struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
func (e GeneralErrorResponse) ToMessage() string {
|
||||
if e.Error.Message != "" {
|
||||
return e.Error.Message
|
||||
}
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
if e.Err != "" {
|
||||
return e.Err
|
||||
}
|
||||
if e.ErrorMsg != "" {
|
||||
return e.ErrorMsg
|
||||
}
|
||||
if e.Header.Message != "" {
|
||||
return e.Header.Message
|
||||
}
|
||||
if e.Response.Error.Message != "" {
|
||||
return e.Response.Error.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// var errorResponse types.OpenAIErrorResponse
|
||||
var errorResponse GeneralErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errorResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if errorResponse.Error.Message != "" {
|
||||
// OpenAI format error, so we override the default one
|
||||
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
|
||||
} else {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
|
||||
}
|
||||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
|
||||
client := GetHttpClient(proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
PutHttpClient(client)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func IsFailureStatusCode(resp *http.Response) bool {
|
||||
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
func DecodeResponse(body io.Reader, v any) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result, ok := v.(*string); ok {
|
||||
return DecodeString(body, result)
|
||||
}
|
||||
|
||||
if stringer, ok := v.(Stringer); ok {
|
||||
return DecodeString(body, stringer.GetString())
|
||||
}
|
||||
|
||||
return json.NewDecoder(body).Decode(v)
|
||||
}
|
||||
|
||||
func DecodeString(body io.Reader, output *string) error {
|
||||
b, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*output = string(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
@@ -37,6 +37,14 @@ func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWith
|
||||
return StringErrorWrapper(err.Error(), code, statusCode)
|
||||
}
|
||||
|
||||
func ErrorToOpenAIError(err error) *types.OpenAIError {
|
||||
return &types.OpenAIError{
|
||||
Code: "system error",
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
}
|
||||
}
|
||||
|
||||
func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
|
||||
openAIError := types.OpenAIError{
|
||||
Message: err,
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package common
|
||||
|
||||
// type Quota struct {
|
||||
// ModelName string
|
||||
// ModelRatio float64
|
||||
// GroupRatio float64
|
||||
// Ratio float64
|
||||
// UserQuota int
|
||||
// }
|
||||
|
||||
// func CreateQuota(modelName string, userQuota int, group string) *Quota {
|
||||
// modelRatio := GetModelRatio(modelName)
|
||||
// groupRatio := GetGroupRatio(group)
|
||||
|
||||
// return &Quota{
|
||||
// ModelName: modelName,
|
||||
// ModelRatio: modelRatio,
|
||||
// GroupRatio: groupRatio,
|
||||
// Ratio: modelRatio * groupRatio,
|
||||
// UserQuota: userQuota,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
// if ApproximateTokenEnabled {
|
||||
// return int(float64(len(text)) * 0.38)
|
||||
// }
|
||||
// return len(tokenEncoder.Encode(text, nil, nil))
|
||||
// }
|
||||
|
||||
// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
|
||||
// tokenEncoder := q.getTokenEncoder(model)
|
||||
// // Reference:
|
||||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
// //
|
||||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
// var tokensPerMessage int
|
||||
// var tokensPerName int
|
||||
// if model == "gpt-3.5-turbo-0301" {
|
||||
// tokensPerMessage = 4
|
||||
// tokensPerName = -1 // If there's a name, the role is omitted
|
||||
// } else {
|
||||
// tokensPerMessage = 3
|
||||
// tokensPerName = 1
|
||||
// }
|
||||
// tokenNum := 0
|
||||
// for _, message := range messages {
|
||||
// tokenNum += tokensPerMessage
|
||||
// tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
|
||||
// tokenNum += q.getTokenNum(tokenEncoder, message.Role)
|
||||
// if message.Name != nil {
|
||||
// tokenNum += tokensPerName
|
||||
// tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
|
||||
// }
|
||||
// }
|
||||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
// return tokenNum
|
||||
// }
|
||||
@@ -1,4 +1,4 @@
|
||||
package common
|
||||
package requester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
68
common/requester/http_client.go
Normal file
68
common/requester/http_client.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type HTTPClient struct{}
|
||||
|
||||
var clientPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &http.Client{}
|
||||
},
|
||||
}
|
||||
|
||||
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
|
||||
client := clientPool.Get().(*http.Client)
|
||||
|
||||
if common.RelayTimeout > 0 {
|
||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||
}
|
||||
|
||||
if proxyAddr != "" {
|
||||
err := h.setProxy(client, proxyAddr)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (h *HTTPClient) returnClientToPool(client *http.Client) {
|
||||
clientPool.Put(client)
|
||||
}
|
||||
|
||||
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing proxy address: %w", err)
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
case "socks5":
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating socks5 dialer: %w", err)
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
229
common/requester/http_requester.go
Normal file
229
common/requester/http_requester.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type HttpErrorHandler func(*http.Response) *types.OpenAIError
|
||||
|
||||
type HTTPRequester struct {
|
||||
HTTPClient HTTPClient
|
||||
requestBuilder RequestBuilder
|
||||
CreateFormBuilder func(io.Writer) FormBuilder
|
||||
ErrorHandler HttpErrorHandler
|
||||
proxyAddr string
|
||||
}
|
||||
|
||||
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
|
||||
// proxyAddr: 是代理服务器的地址。
|
||||
// errorHandler: 是一个错误处理函数,它接收一个 *http.Response 参数并返回一个 *types.OpenAIErrorResponse。
|
||||
// 如果 errorHandler 为 nil,那么会使用一个默认的错误处理函数。
|
||||
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
|
||||
return &HTTPRequester{
|
||||
HTTPClient: HTTPClient{},
|
||||
requestBuilder: NewRequestBuilder(),
|
||||
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
||||
return NewFormBuilder(body)
|
||||
},
|
||||
ErrorHandler: errorHandler,
|
||||
proxyAddr: proxyAddr,
|
||||
}
|
||||
}
|
||||
|
||||
type requestOptions struct {
|
||||
body any
|
||||
header http.Header
|
||||
}
|
||||
|
||||
type requestOption func(*requestOptions)
|
||||
|
||||
// 创建请求
|
||||
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||||
args := &requestOptions{
|
||||
body: nil,
|
||||
header: make(http.Header),
|
||||
}
|
||||
for _, setter := range setters {
|
||||
setter(args)
|
||||
}
|
||||
req, err := r.requestBuilder.Build(method, url, args.body, args.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
r.HTTPClient.returnClientToPool(client)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if !outputResp {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
if r.IsFailureStatusCode(resp) {
|
||||
return nil, HandleErrorResp(resp, r.ErrorHandler)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
if outputResp {
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
err = DecodeResponse(tee, response)
|
||||
|
||||
// 将响应体重新写入 resp.Body
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
} else {
|
||||
err = json.NewDecoder(resp.Body).Decode(response)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// 发送请求 RAW
|
||||
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
r.HTTPClient.returnClientToPool(client)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
if r.IsFailureStatusCode(resp) {
|
||||
return nil, HandleErrorResp(resp, r.ErrorHandler)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// 获取流式响应
|
||||
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
|
||||
// 如果返回的头是json格式 说明有错误
|
||||
if resp.Header.Get("Content-Type") == "application/json" {
|
||||
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
||||
}
|
||||
|
||||
return &streamReader[T]{
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置请求体
|
||||
func (r *HTTPRequester) WithBody(body any) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.body = body
|
||||
}
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
for k, v := range header {
|
||||
args.header.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置Content-Type
|
||||
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.header.Set("Content-Type", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
// 判断是否为失败状态码
|
||||
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
|
||||
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
|
||||
|
||||
openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if toOpenAIError != nil {
|
||||
errorResponse := toOpenAIError(resp)
|
||||
|
||||
if errorResponse != nil && errorResponse.Message != "" {
|
||||
openAIErrorWithStatusCode.OpenAIError = *errorResponse
|
||||
}
|
||||
}
|
||||
|
||||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return openAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
type Stringer interface {
|
||||
GetString() *string
|
||||
}
|
||||
|
||||
func DecodeResponse(body io.Reader, v any) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result, ok := v.(*string); ok {
|
||||
return DecodeString(body, result)
|
||||
}
|
||||
|
||||
if stringer, ok := v.(Stringer); ok {
|
||||
return DecodeString(body, stringer.GetString())
|
||||
}
|
||||
|
||||
return json.NewDecoder(body).Decode(v)
|
||||
}
|
||||
|
||||
func DecodeString(body io.Reader, output *string) error {
|
||||
b, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*output = string(b)
|
||||
return nil
|
||||
}
|
||||
79
common/requester/http_stream_reader.go
Normal file
79
common/requester/http_stream_reader.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 流处理函数,判断依据如下:
|
||||
// 1.如果有错误信息,则直接返回错误信息
|
||||
// 2.如果isFinished=true,则返回io.EOF,并且如果response不为空,还将返回response
|
||||
// 3.如果rawLine=nil 或者 response长度为0,则直接跳过
|
||||
// 4.如果以上条件都不满足,则返回response
|
||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
|
||||
|
||||
type streamable interface {
|
||||
// types.ChatCompletionStreamResponse | types.CompletionResponse
|
||||
any
|
||||
}
|
||||
|
||||
type StreamReaderInterface[T streamable] interface {
|
||||
Recv() (*[]T, error)
|
||||
Close()
|
||||
}
|
||||
|
||||
type streamReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (stream *streamReader[T]) processLines() (*[]T, error) {
|
||||
for {
|
||||
rawLine, readErr := stream.reader.ReadBytes('\n')
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||
|
||||
var response []T
|
||||
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if noSpaceLine == nil || len(response) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Close() {
|
||||
stream.response.Body.Close()
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
package common
|
||||
package requester
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type RequestBuilder interface {
|
||||
@@ -11,12 +12,12 @@ type RequestBuilder interface {
|
||||
}
|
||||
|
||||
type HTTPRequestBuilder struct {
|
||||
marshaller Marshaller
|
||||
marshaller common.Marshaller
|
||||
}
|
||||
|
||||
func NewRequestBuilder() *HTTPRequestBuilder {
|
||||
return &HTTPRequestBuilder{
|
||||
marshaller: &JSONMarshaller{},
|
||||
marshaller: &common.JSONMarshaller{},
|
||||
}
|
||||
}
|
||||
|
||||
53
common/requester/ws_client.go
Normal file
53
common/requester/ws_client.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
func GetWSClient(proxyAddr string) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
if proxyAddr != "" {
|
||||
err := setWSProxy(dialer, proxyAddr)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
return dialer
|
||||
}
|
||||
}
|
||||
|
||||
return dialer
|
||||
}
|
||||
|
||||
func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing proxy address: %w", err)
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
dialer.Proxy = http.ProxyURL(proxyURL)
|
||||
case "socks5":
|
||||
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating socks5 dialer: %w", err)
|
||||
}
|
||||
dialer.NetDial = func(network, addr string) (net.Conn, error) {
|
||||
return socks5Proxy.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
58
common/requester/ws_reader.go
Normal file
58
common/requester/ws_reader.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type wsReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *websocket.Conn
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) processLines() (*[]T, error) {
|
||||
for {
|
||||
_, msg, err := stream.reader.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response []T
|
||||
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if msg == nil || len(response) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) Close() {
|
||||
stream.reader.Close()
|
||||
}
|
||||
54
common/requester/ws_requester.go
Normal file
54
common/requester/ws_requester.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WSRequester struct {
|
||||
WSClient *websocket.Dialer
|
||||
}
|
||||
|
||||
func NewWSRequester(proxyAddr string) *WSRequester {
|
||||
return &WSRequester{
|
||||
WSClient: GetWSClient(proxyAddr),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
|
||||
conn, resp, err := w.WSClient.Dial(url, header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, errors.New("ws unexpected status code")
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
|
||||
err := conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return &wsReader[T]{
|
||||
reader: conn,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
func (r *WSRequester) WithHeader(headers map[string]string) http.Header {
|
||||
header := make(http.Header)
|
||||
for k, v := range headers {
|
||||
header.Set(k, v)
|
||||
}
|
||||
return header
|
||||
}
|
||||
55
common/test/api.go
Normal file
55
common/test/api.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RequestJSONConfig() map[string]string {
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
}
|
||||
|
||||
func GetContext(method, path string, headers map[string]string, body io.Reader) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
var req *http.Request
|
||||
req, _ = http.NewRequest(method, path, body)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
return c, w
|
||||
}
|
||||
|
||||
func GetGinRouter(method, path string, headers map[string]string, body *io.Reader) *httptest.ResponseRecorder {
|
||||
var req *http.Request
|
||||
r := gin.Default()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ = http.NewRequest(method, path, *body)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func GetChannel(channelType int, baseUrl, other, porxy, modelMapping string) model.Channel {
|
||||
return model.Channel{
|
||||
Type: channelType,
|
||||
BaseURL: &baseUrl,
|
||||
Other: other,
|
||||
Proxy: porxy,
|
||||
ModelMapping: &modelMapping,
|
||||
Key: GetTestToken(),
|
||||
}
|
||||
}
|
||||
132
common/test/chat_config.go
Normal file
132
common/test/chat_config.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetChatCompletionRequest(chatType, modelName, stream string) *types.ChatCompletionRequest {
|
||||
chatJSON := GetChatRequest(chatType, modelName, stream)
|
||||
chatCompletionRequest := &types.ChatCompletionRequest{}
|
||||
json.NewDecoder(chatJSON).Decode(chatCompletionRequest)
|
||||
return chatCompletionRequest
|
||||
}
|
||||
|
||||
func GetChatRequest(chatType, modelName, stream string) *strings.Reader {
|
||||
var chatJSON string
|
||||
switch chatType {
|
||||
case "image":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What’s in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300,
|
||||
"stream": ` + stream + `
|
||||
}`
|
||||
case "default":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
],
|
||||
"stream": ` + stream + `
|
||||
}`
|
||||
case "function":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"stream": ` + stream + `,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Boston?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "auto"
|
||||
}`
|
||||
|
||||
case "tools":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"stream": ` + stream + `,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Boston?"
|
||||
}
|
||||
],
|
||||
"functions": [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"celsius",
|
||||
"fahrenheit"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
}
|
||||
|
||||
return strings.NewReader(chatJSON)
|
||||
}
|
||||
65
common/test/check_chat.go
Normal file
65
common/test/check_chat.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"one-api/types"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func CheckChat(t *testing.T, response *types.ChatCompletionResponse, modelName string, usage *types.Usage) {
|
||||
assert.NotEmpty(t, response.ID)
|
||||
assert.NotEmpty(t, response.Object)
|
||||
assert.NotEmpty(t, response.Created)
|
||||
assert.Equal(t, response.Model, modelName)
|
||||
assert.IsType(t, []types.ChatCompletionChoice{}, response.Choices)
|
||||
// check choices 长度大于1
|
||||
assert.True(t, len(response.Choices) > 0)
|
||||
for _, choice := range response.Choices {
|
||||
assert.NotNil(t, choice.Index)
|
||||
assert.IsType(t, types.ChatCompletionMessage{}, choice.Message)
|
||||
assert.NotEmpty(t, choice.Message.Role)
|
||||
assert.NotEmpty(t, choice.FinishReason)
|
||||
|
||||
// check message
|
||||
if choice.Message.Content != nil {
|
||||
multiContents, ok := choice.Message.Content.([]types.ChatMessagePart)
|
||||
if ok {
|
||||
for _, content := range multiContents {
|
||||
assert.NotEmpty(t, content.Type)
|
||||
if content.Type == "text" {
|
||||
assert.NotEmpty(t, content.Text)
|
||||
} else if content.Type == "image_url" {
|
||||
assert.IsType(t, types.ChatMessageImageURL{}, content.ImageURL)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
content, ok := choice.Message.Content.(string)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, content)
|
||||
}
|
||||
} else if choice.Message.FunctionCall != nil {
|
||||
assert.NotEmpty(t, choice.Message.FunctionCall.Name)
|
||||
assert.Equal(t, choice.FinishReason, types.FinishReasonFunctionCall)
|
||||
} else if choice.Message.ToolCalls != nil {
|
||||
assert.IsType(t, []types.ChatCompletionToolCalls{}, choice.Message.ToolCalls)
|
||||
assert.NotEmpty(t, choice.Message.ToolCalls[0].Id)
|
||||
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function)
|
||||
assert.Equal(t, choice.Message.ToolCalls[0].Function, "function")
|
||||
|
||||
assert.IsType(t, types.ChatCompletionToolCallsFunction{}, choice.Message.ToolCalls[0].Function)
|
||||
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function.Name)
|
||||
|
||||
assert.Equal(t, choice.FinishReason, types.FinishReasonToolCalls)
|
||||
} else {
|
||||
assert.Fail(t, "message content is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// check usage
|
||||
assert.IsType(t, &types.Usage{}, response.Usage)
|
||||
assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens)
|
||||
assert.Equal(t, response.Usage.CompletionTokens, usage.CompletionTokens)
|
||||
assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens)
|
||||
|
||||
}
|
||||
48
common/test/checks.go
Normal file
48
common/test/checks.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func NoError(t *testing.T, err error, message ...string) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Error(err, message)
|
||||
}
|
||||
}
|
||||
|
||||
func HasError(t *testing.T, err error, message ...string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
t.Error(err, message)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIs(t *testing.T, err, target error, msg ...string) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, target) {
|
||||
t.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, target) {
|
||||
t.Fatalf(format, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
|
||||
t.Helper()
|
||||
if errors.Is(err, target) {
|
||||
t.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
|
||||
t.Helper()
|
||||
if errors.Is(err, target) {
|
||||
t.Fatalf(format, msg)
|
||||
}
|
||||
}
|
||||
7
common/test/init/init.go
Normal file
7
common/test/init/init.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package init
|
||||
|
||||
import "testing"
|
||||
|
||||
func init() {
|
||||
testing.Init()
|
||||
}
|
||||
63
common/test/server.go
Normal file
63
common/test/server.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const testAPI = "this-is-my-secure-token-do-not-steal!!"
|
||||
|
||||
func GetTestToken() string {
|
||||
return testAPI
|
||||
}
|
||||
|
||||
type ServerTest struct {
|
||||
handlers map[string]handler
|
||||
}
|
||||
type handler func(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
func NewTestServer() *ServerTest {
|
||||
return &ServerTest{handlers: make(map[string]handler)}
|
||||
}
|
||||
|
||||
func OpenAICheck(w http.ResponseWriter, r *http.Request) bool {
|
||||
if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *ServerTest) RegisterHandler(path string, handler handler) {
|
||||
// to make the registered paths friendlier to a regex match in the route handler
|
||||
// in OpenAITestServer
|
||||
path = strings.ReplaceAll(path, "*", ".*")
|
||||
ts.handlers[path] = handler
|
||||
}
|
||||
|
||||
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
|
||||
func (ts *ServerTest) TestServer(headerCheck func(w http.ResponseWriter, r *http.Request) bool) *httptest.Server {
|
||||
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path)
|
||||
|
||||
// check auth
|
||||
if headerCheck != nil && !headerCheck(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle /path/* routes.
|
||||
// Note: the * is converted to a .* in register handler for proper regex handling
|
||||
for route, handler := range ts.handlers {
|
||||
// Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered
|
||||
pattern, _ := regexp.Compile("^" + route + "$")
|
||||
if pattern.MatchString(r.URL.Path) {
|
||||
handler(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
http.Error(w, "the resource path doesn't exist", http.StatusNotFound)
|
||||
}))
|
||||
}
|
||||
Reference in New Issue
Block a user