one-api/common/requester/ws_requester.go
Buer ef041e28a1
♻️ refactor: provider refactor (#41)
* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
2024-01-19 02:47:10 +08:00

55 lines
1.2 KiB
Go

package requester
import (
"errors"
"net/http"
"one-api/common"
"one-api/types"
"github.com/gorilla/websocket"
)
// WSRequester opens outbound WebSocket connections through the dialer it holds.
type WSRequester struct {
// WSClient is the dialer used by NewRequest to perform the WebSocket handshake.
WSClient *websocket.Dialer
}
// NewWSRequester builds a WSRequester whose dialer is the one returned by
// GetWSClient for proxyAddr.
func NewWSRequester(proxyAddr string) *WSRequester {
	requester := new(WSRequester)
	requester.WSClient = GetWSClient(proxyAddr)
	return requester
}
// NewRequest dials the WebSocket endpoint at url, sending header with the
// opening handshake, and returns the established connection.
//
// A dial error is returned unchanged. If the dial succeeds but the handshake
// response status is not 101 Switching Protocols, the connection is closed
// and an error is returned, so the caller never has to clean up on failure.
func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
	conn, resp, err := w.WSClient.Dial(url, header)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusSwitchingProtocols {
		// The caller never receives conn on this path, so release it here
		// instead of leaking it. The close error is deliberately ignored:
		// the unexpected status is the failure worth reporting.
		_ = conn.Close()
		return nil, errors.New("ws unexpected status code")
	}

	return conn, nil
}
// SendWSJsonRequest writes data to conn as a JSON message and, on success,
// returns a wsReader that reads from the same connection using handlerPrefix.
//
// If the write fails, the error is passed to common.ErrorWrapper with the code
// "ws_request_failed" and HTTP status 500, and no reader is returned.
func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
	if writeErr := conn.WriteJSON(data); writeErr != nil {
		return nil, common.ErrorWrapper(writeErr, "ws_request_failed", http.StatusInternalServerError)
	}

	stream := &wsReader[T]{
		reader:        conn,
		handlerPrefix: handlerPrefix,
	}
	return stream, nil
}
// WithHeader builds an http.Header from a plain key/value map.
//
// Each pair is stored with http.Header.Set, so keys are canonicalized and each
// key ends up with exactly one value. The receiver is not read; the result
// depends only on headers. A nil or empty map yields an empty, non-nil header.
func (w *WSRequester) WithHeader(headers map[string]string) http.Header {
	// Pre-size for the known number of entries to avoid map regrowth.
	header := make(http.Header, len(headers))
	for k, v := range headers {
		header.Set(k, v)
	}
	return header
}