mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
♻️ refactor: provider refactor (#41)
* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
24
providers/ali/ali_test.go
Normal file
24
providers/ali/ali_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package ali_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/test"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown func()) {
|
||||
server = test.NewTestServer()
|
||||
ts := server.TestServer(func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return test.OpenAICheck(w, r)
|
||||
})
|
||||
ts.Start()
|
||||
teardown = ts.Close
|
||||
|
||||
baseUrl = ts.URL
|
||||
return
|
||||
}
|
||||
|
||||
func getAliChannel(baseUrl string) model.Channel {
|
||||
return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "")
|
||||
}
|
||||
@@ -1,32 +1,66 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
type AliProviderFactory struct{}
|
||||
|
||||
type AliProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
// 创建 AliProvider
|
||||
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||
func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f AliProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &AliProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type AliProvider struct {
|
||||
base.BaseProvider
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var aliError *AliError
|
||||
err := json.NewDecoder(resp.Body).Decode(aliError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(aliError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(aliError *AliError) *types.OpenAIError {
|
||||
if aliError.Code == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: aliError.Message,
|
||||
Type: aliError.Code,
|
||||
Param: aliError.RequestId,
|
||||
Code: aliError.Code,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
|
||||
@@ -1,51 +1,116 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 阿里云响应处理
|
||||
func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: aliResponse.Model,
|
||||
Choices: aliResponse.Output.ToChatCompletionChoices(),
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: aliResponse.Usage.InputTokens,
|
||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
return
|
||||
type aliStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
lastStreamResponse string
|
||||
}
|
||||
|
||||
const AliEnableSearchModelSuffix = "-internet"
|
||||
|
||||
// 获取聊天请求体
|
||||
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getAliChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
aliResponse := &AliChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, aliResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToChatOpenai(aliResponse, request)
|
||||
}
|
||||
|
||||
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getAliChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &aliStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["X-DashScope-SSE"] = "enable"
|
||||
}
|
||||
|
||||
aliRequest := convertFromChatOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天请求体
|
||||
func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.AliError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: request.Model,
|
||||
Choices: response.Output.ToChatCompletionChoices(),
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
*p.Usage = *openaiResponse.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 阿里云聊天请求体
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
@@ -96,163 +161,68 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
|
||||
}
|
||||
}
|
||||
|
||||
// 聊天
|
||||
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["X-DashScope-SSE"] = "enable"
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[5:]
|
||||
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal(*rawLine, &aliResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
usage, errWithCode = p.sendStreamRequest(req, request.Model)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &types.Usage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
aliResponse := &AliChatResponse{
|
||||
Model: request.Model,
|
||||
}
|
||||
errWithCode = p.SendRequest(req, aliResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: aliResponse.Usage.InputTokens,
|
||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||
}
|
||||
error := errorHandle(&aliResponse.AliError)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
return
|
||||
|
||||
return h.convertToOpenaiStream(&aliResponse, response)
|
||||
|
||||
}
|
||||
|
||||
// 阿里云响应转OpenAI响应
|
||||
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||
// chatChoice := aliResponse.Output.ToChatCompletionChoices()
|
||||
// jsonBody, _ := json.MarshalIndent(chatChoice, "", " ")
|
||||
// fmt.Println("requestBody:", string(jsonBody))
|
||||
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
content := aliResponse.Output.Choices[0].Message.StringContent()
|
||||
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Index = aliResponse.Output.Choices[0].Index
|
||||
choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent()
|
||||
// fmt.Println("choice.Delta.Content:", chatChoice[0].Message)
|
||||
if aliResponse.Output.Choices[0].FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.Choices[0].FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
choice.Delta.Content = strings.TrimPrefix(content, h.lastStreamResponse)
|
||||
if aliResponse.Output.Choices[0].FinishReason != "" {
|
||||
if aliResponse.Output.Choices[0].FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.Choices[0].FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
if aliResponse.Output.FinishReason != "" {
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
|
||||
h.lastStreamResponse = content
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: aliResponse.Model,
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
// 发送流请求
|
||||
func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return nil, common.HandleErrorResp(resp)
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
h.Usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
h.Usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
lastResponseText := ""
|
||||
index := 0
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
aliResponse.Model = model
|
||||
aliResponse.Output.Choices[0].Index = index
|
||||
index++
|
||||
response := p.streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
lastResponseText = aliResponse.Output.Choices[0].Message.StringContent()
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
330
providers/ali/chat_test.go
Normal file
330
providers/ali/chat_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package ali_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/test"
|
||||
_ "one-api/common/test/init"
|
||||
"one-api/providers"
|
||||
providers_base "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func getChatProvider(url string, context *gin.Context) providers_base.ChatInterface {
|
||||
channel := getAliChannel(url)
|
||||
provider := providers.GetProvider(&channel, context)
|
||||
chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
|
||||
return chatProvider
|
||||
}
|
||||
|
||||
func TestChatCompletions(t *testing.T) {
|
||||
url, server, teardown := setupAliTestServer()
|
||||
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
defer teardown()
|
||||
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionEndpoint)
|
||||
|
||||
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
|
||||
|
||||
chatProvider := getChatProvider(url, context)
|
||||
usage := &types.Usage{}
|
||||
chatProvider.SetUsage(usage)
|
||||
response, errWithCode := chatProvider.CreateChatCompletion(chatRequest)
|
||||
|
||||
assert.Nil(t, errWithCode)
|
||||
assert.IsType(t, &types.Usage{}, usage)
|
||||
assert.Equal(t, 33, usage.TotalTokens)
|
||||
assert.Equal(t, 14, usage.PromptTokens)
|
||||
assert.Equal(t, 19, usage.CompletionTokens)
|
||||
|
||||
// 转换成JSON字符串
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
assert.Fail(t, "json marshal error")
|
||||
}
|
||||
fmt.Println(string(responseBody))
|
||||
|
||||
test.CheckChat(t, response, "qwen-turbo", usage)
|
||||
}
|
||||
|
||||
func TestChatCompletionsError(t *testing.T) {
|
||||
url, server, teardown := setupAliTestServer()
|
||||
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
defer teardown()
|
||||
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionErrorEndpoint)
|
||||
|
||||
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
|
||||
|
||||
chatProvider := getChatProvider(url, context)
|
||||
_, err := chatProvider.CreateChatCompletion(chatRequest)
|
||||
usage := chatProvider.GetUsage()
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, usage)
|
||||
assert.Equal(t, "InvalidParameter", err.Code)
|
||||
}
|
||||
|
||||
// func TestChatCompletionsStream(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
|
||||
|
||||
// usage := &types.Usage{}
|
||||
// chatProvider.SetUsage(usage)
|
||||
// response, errWithCode := chatProvider.CreateChatCompletionStream(chatRequest)
|
||||
// assert.Nil(t, errWithCode)
|
||||
|
||||
// assert.IsType(t, &types.Usage{}, usage)
|
||||
// assert.Equal(t, 16, usage.TotalTokens)
|
||||
// assert.Equal(t, 8, usage.PromptTokens)
|
||||
// assert.Equal(t, 8, usage.CompletionTokens)
|
||||
|
||||
// streamResponseCheck(t, w.Body.String())
|
||||
// }
|
||||
|
||||
// func TestChatCompletionsStreamError(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamErrorEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
|
||||
|
||||
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
||||
|
||||
// // 打印 context 写入的内容
|
||||
// fmt.Println(w.Body.String())
|
||||
|
||||
// assert.NotNil(t, err)
|
||||
// assert.Nil(t, usage)
|
||||
// }
|
||||
|
||||
// func TestChatImageCompletions(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "false")
|
||||
|
||||
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
||||
|
||||
// assert.Nil(t, err)
|
||||
// assert.IsType(t, &types.Usage{}, usage)
|
||||
// assert.Equal(t, 1306, usage.TotalTokens)
|
||||
// assert.Equal(t, 1279, usage.PromptTokens)
|
||||
// assert.Equal(t, 27, usage.CompletionTokens)
|
||||
// }
|
||||
|
||||
// func TestChatImageCompletionsStream(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionStreamEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "true")
|
||||
|
||||
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
||||
|
||||
// fmt.Println(w.Body.String())
|
||||
|
||||
// assert.Nil(t, err)
|
||||
// assert.IsType(t, &types.Usage{}, usage)
|
||||
// assert.Equal(t, 1342, usage.TotalTokens)
|
||||
// assert.Equal(t, 1279, usage.PromptTokens)
|
||||
// assert.Equal(t, 63, usage.CompletionTokens)
|
||||
// streamResponseCheck(t, w.Body.String())
|
||||
// }
|
||||
|
||||
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
response := `{"output":{"choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"您好!我可以帮您查询最近的公园,请问您现在所在的位置是哪里呢?"}}]},"usage":{"total_tokens":33,"output_tokens":19,"input_tokens":14},"request_id":"2479f818-9717-9b0b-9769-0d26e873a3f6"}`
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, response)
|
||||
}
|
||||
|
||||
func handleChatCompletionErrorEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
response := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"4883ee8d-f095-94ff-a94a-5ce0a94bc81f"}`
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, response)
|
||||
}
|
||||
|
||||
func handleChatCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 检测头部是否有X-DashScope-SSE: enable
|
||||
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
||||
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data := `{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":10,"input_tokens":8,"output_tokens":2},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:2\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":"有什么我可以帮助你的吗?","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:3\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":"","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleChatCompletionStreamErrorEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 检测头部是否有X-DashScope-SSE: enable
|
||||
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
||||
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:error\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/400\n")...)
|
||||
//nolint:lll
|
||||
data := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"6b932ba9-41bd-9ad3-b430-24bc1e125880"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleChatImageCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
response := `{"output":{"finish_reason":"stop","choices":[{"message":{"role":"assistant","content":[{"text":"这张照片展示的是一个海滩的场景,但是并没有明确指出具体的位置。可以看到海浪和日落背景下的沙滩景色。"}]}}]},"usage":{"output_tokens":27,"input_tokens":1279,"image_tokens":1247},"request_id":"a360d53b-b993-927f-9a68-bef6b2b2042e"}`
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, response)
|
||||
}
|
||||
|
||||
func handleChatImageCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 检测头部是否有X-DashScope-SSE: enable
|
||||
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
||||
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data := `{"output":{"choices":[{"message":{"content":[{"text":"这张"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":1,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:2\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":2,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:3\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片展示的是一个海滩的场景,具体来说是在日落时分。由于没有明显的地标或建筑物等特征可以辨认出具体的地点信息,所以无法确定这是哪个地方的海滩。但是根据图像中的元素和环境特点,我们可以推测这可能是一个位于沿海地区的沙滩海岸线。"}],"role":"assistant"}}],"finish_reason":"stop"},"usage":{"input_tokens":1279,"output_tokens":63,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func streamResponseCheck(t *testing.T, response string) {
|
||||
// 以换行符分割response
|
||||
lines := strings.Split(response, "\n\n")
|
||||
// 如果最后一行为空,则删除最后一行
|
||||
if lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1]
|
||||
}
|
||||
|
||||
// 循环遍历每一行
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
// assert判断 是否以data: 开头
|
||||
assert.True(t, strings.HasPrefix(line, "data: "))
|
||||
}
|
||||
|
||||
// 检测最后一行是否以data: [DONE] 结尾
|
||||
assert.True(t, strings.HasSuffix(lines[len(lines)-1], "data: [DONE]"))
|
||||
// 检测倒数第二行是否存在 `"finish_reason":"stop"`
|
||||
assert.True(t, strings.Contains(lines[len(lines)-2], `"finish_reason":"stop"`))
|
||||
}
|
||||
@@ -6,40 +6,37 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
// 嵌入请求处理
|
||||
func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
aliRequest := convertFromEmbeddingOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
aliResponse := &AliEmbeddingResponse{}
|
||||
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, aliResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range aliResponse.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
|
||||
return openAIEmbeddingResponse, nil
|
||||
return p.convertToEmbeddingOpenai(aliResponse, request)
|
||||
}
|
||||
|
||||
// 获取嵌入请求体
|
||||
func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
@@ -50,24 +47,36 @@ func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
aliEmbeddingResponse := &AliEmbeddingResponse{}
|
||||
errWithCode = p.SendRequest(req, aliEmbeddingResponse, false)
|
||||
if errWithCode != nil {
|
||||
func (p *AliProvider) convertToEmbeddingOpenai(response *AliEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.AliError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
|
||||
|
||||
return usage, nil
|
||||
openaiResponse = &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(response.Output.Embeddings)),
|
||||
Model: request.Model,
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: response.Usage.TotalTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openaiResponse.Data = append(openaiResponse.Data, types.Embedding{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
|
||||
*p.Usage = *openaiResponse.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -52,7 +52,8 @@ type AliChoice struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Choices []types.ChatCompletionChoice `json:"choices"`
|
||||
Choices []types.ChatCompletionChoice `json:"choices"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
|
||||
@@ -70,7 +71,6 @@ func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
Model string `json:"model,omitempty"`
|
||||
AliError
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user