refactor: use adaptor to do relay & test

This commit is contained in:
JustSong
2024-02-18 00:15:31 +08:00
parent d548a01c59
commit 1aa374ccfb
63 changed files with 1452 additions and 1332 deletions

View File

@@ -1,22 +1,66 @@
package xunfei
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
type Adaptor struct {
request *model.GeneralOpenAIRequest
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
// check DoResponse for auth part
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
a.request = request
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK
return dummyResp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
splits := strings.Split(meta.APIKey, "|")
if len(splits) != 3 {
return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
if a.request == nil {
return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
}
if meta.IsStream {
err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
} else {
err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "xunfei"
}

View File

@@ -0,0 +1,5 @@
package xunfei
var ModelList = []string{
"SparkDesk",
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"net/url"
@@ -23,7 +24,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -62,7 +63,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
}
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
@@ -125,14 +126,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
common.SetEventStreamHeaders(c)
var usage openai.Usage
var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@@ -155,13 +156,13 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
return nil, &usage
}
func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
var usage openai.Usage
var usage model.Usage
var content string
var xunfeiResponse ChatResponse
stop := false
@@ -197,7 +198,7 @@ func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId stri
return nil, &usage
}
func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}

View File

@@ -1,7 +1,7 @@
package xunfei
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -55,7 +55,7 @@ type ChatResponse struct {
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text openai.Usage `json:"text"`
Text model.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}