mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-29 06:36:38 +08:00
331 lines
13 KiB
Go
331 lines
13 KiB
Go
package ali_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"one-api/common/test"
|
|
_ "one-api/common/test/init"
|
|
"one-api/providers"
|
|
providers_base "one-api/providers/base"
|
|
"one-api/types"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func getChatProvider(url string, context *gin.Context) providers_base.ChatInterface {
|
|
channel := getAliChannel(url)
|
|
provider := providers.GetProvider(&channel, context)
|
|
chatProvider, _ := provider.(providers_base.ChatInterface)
|
|
|
|
return chatProvider
|
|
}
|
|
|
|
func TestChatCompletions(t *testing.T) {
|
|
url, server, teardown := setupAliTestServer()
|
|
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
|
defer teardown()
|
|
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionEndpoint)
|
|
|
|
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
|
|
|
|
chatProvider := getChatProvider(url, context)
|
|
usage := &types.Usage{}
|
|
chatProvider.SetUsage(usage)
|
|
response, errWithCode := chatProvider.CreateChatCompletion(chatRequest)
|
|
|
|
assert.Nil(t, errWithCode)
|
|
assert.IsType(t, &types.Usage{}, usage)
|
|
assert.Equal(t, 33, usage.TotalTokens)
|
|
assert.Equal(t, 14, usage.PromptTokens)
|
|
assert.Equal(t, 19, usage.CompletionTokens)
|
|
|
|
// 转换成JSON字符串
|
|
responseBody, err := json.Marshal(response)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
assert.Fail(t, "json marshal error")
|
|
}
|
|
fmt.Println(string(responseBody))
|
|
|
|
test.CheckChat(t, response, "qwen-turbo", usage)
|
|
}
|
|
|
|
func TestChatCompletionsError(t *testing.T) {
|
|
url, server, teardown := setupAliTestServer()
|
|
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
|
defer teardown()
|
|
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionErrorEndpoint)
|
|
|
|
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
|
|
|
|
chatProvider := getChatProvider(url, context)
|
|
_, err := chatProvider.CreateChatCompletion(chatRequest)
|
|
usage := chatProvider.GetUsage()
|
|
|
|
assert.NotNil(t, err)
|
|
assert.Nil(t, usage)
|
|
assert.Equal(t, "InvalidParameter", err.Code)
|
|
}
|
|
|
|
// func TestChatCompletionsStream(t *testing.T) {
|
|
// url, server, teardown := setupAliTestServer()
|
|
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
|
// defer teardown()
|
|
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamEndpoint)
|
|
|
|
// channel := getAliChannel(url)
|
|
// provider := providers.GetProvider(&channel, context)
|
|
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
|
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
|
|
|
|
// usage := &types.Usage{}
|
|
// chatProvider.SetUsage(usage)
|
|
// response, errWithCode := chatProvider.CreateChatCompletionStream(chatRequest)
|
|
// assert.Nil(t, errWithCode)
|
|
|
|
// assert.IsType(t, &types.Usage{}, usage)
|
|
// assert.Equal(t, 16, usage.TotalTokens)
|
|
// assert.Equal(t, 8, usage.PromptTokens)
|
|
// assert.Equal(t, 8, usage.CompletionTokens)
|
|
|
|
// streamResponseCheck(t, w.Body.String())
|
|
// }
|
|
|
|
// func TestChatCompletionsStreamError(t *testing.T) {
|
|
// url, server, teardown := setupAliTestServer()
|
|
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
|
// defer teardown()
|
|
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamErrorEndpoint)
|
|
|
|
// channel := getAliChannel(url)
|
|
// provider := providers.GetProvider(&channel, context)
|
|
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
|
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
|
|
|
|
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
|
|
|
// // 打印 context 写入的内容
|
|
// fmt.Println(w.Body.String())
|
|
|
|
// assert.NotNil(t, err)
|
|
// assert.Nil(t, usage)
|
|
// }
|
|
|
|
// func TestChatImageCompletions(t *testing.T) {
|
|
// url, server, teardown := setupAliTestServer()
|
|
// context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
|
// defer teardown()
|
|
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionEndpoint)
|
|
|
|
// channel := getAliChannel(url)
|
|
// provider := providers.GetProvider(&channel, context)
|
|
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
|
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "false")
|
|
|
|
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
|
|
|
// assert.Nil(t, err)
|
|
// assert.IsType(t, &types.Usage{}, usage)
|
|
// assert.Equal(t, 1306, usage.TotalTokens)
|
|
// assert.Equal(t, 1279, usage.PromptTokens)
|
|
// assert.Equal(t, 27, usage.CompletionTokens)
|
|
// }
|
|
|
|
// func TestChatImageCompletionsStream(t *testing.T) {
|
|
// url, server, teardown := setupAliTestServer()
|
|
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
|
// defer teardown()
|
|
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionStreamEndpoint)
|
|
|
|
// channel := getAliChannel(url)
|
|
// provider := providers.GetProvider(&channel, context)
|
|
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
|
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "true")
|
|
|
|
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
|
|
|
// fmt.Println(w.Body.String())
|
|
|
|
// assert.Nil(t, err)
|
|
// assert.IsType(t, &types.Usage{}, usage)
|
|
// assert.Equal(t, 1342, usage.TotalTokens)
|
|
// assert.Equal(t, 1279, usage.PromptTokens)
|
|
// assert.Equal(t, 63, usage.CompletionTokens)
|
|
// streamResponseCheck(t, w.Body.String())
|
|
// }
|
|
|
|
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
// completions only accepts POST requests
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
response := `{"output":{"choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"您好!我可以帮您查询最近的公园,请问您现在所在的位置是哪里呢?"}}]},"usage":{"total_tokens":33,"output_tokens":19,"input_tokens":14},"request_id":"2479f818-9717-9b0b-9769-0d26e873a3f6"}`
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprintln(w, response)
|
|
}
|
|
|
|
func handleChatCompletionErrorEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
// completions only accepts POST requests
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
response := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"4883ee8d-f095-94ff-a94a-5ce0a94bc81f"}`
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprintln(w, response)
|
|
}
|
|
|
|
func handleChatCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
// completions only accepts POST requests
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
// 检测头部是否有X-DashScope-SSE: enable
|
|
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
|
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
// Send test responses
|
|
dataBytes := []byte{}
|
|
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
|
//nolint:lll
|
|
data := `{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":10,"input_tokens":8,"output_tokens":2},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
dataBytes = append(dataBytes, []byte("id:2\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
|
//nolint:lll
|
|
data = `{"output":{"choices":[{"message":{"content":"有什么我可以帮助你的吗?","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
dataBytes = append(dataBytes, []byte("id:3\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
|
//nolint:lll
|
|
data = `{"output":{"choices":[{"message":{"content":"","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
_, err := w.Write(dataBytes)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func handleChatCompletionStreamErrorEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
// completions only accepts POST requests
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
// 检测头部是否有X-DashScope-SSE: enable
|
|
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
|
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
// Send test responses
|
|
dataBytes := []byte{}
|
|
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:error\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/400\n")...)
|
|
//nolint:lll
|
|
data := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"6b932ba9-41bd-9ad3-b430-24bc1e125880"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
_, err := w.Write(dataBytes)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func handleChatImageCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
// completions only accepts POST requests
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
response := `{"output":{"finish_reason":"stop","choices":[{"message":{"role":"assistant","content":[{"text":"这张照片展示的是一个海滩的场景,但是并没有明确指出具体的位置。可以看到海浪和日落背景下的沙滩景色。"}]}}]},"usage":{"output_tokens":27,"input_tokens":1279,"image_tokens":1247},"request_id":"a360d53b-b993-927f-9a68-bef6b2b2042e"}`
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprintln(w, response)
|
|
}
|
|
|
|
func handleChatImageCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
// completions only accepts POST requests
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
// 检测头部是否有X-DashScope-SSE: enable
|
|
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
|
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
// Send test responses
|
|
dataBytes := []byte{}
|
|
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
|
//nolint:lll
|
|
data := `{"output":{"choices":[{"message":{"content":[{"text":"这张"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":1,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
dataBytes = append(dataBytes, []byte("id:2\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
|
//nolint:lll
|
|
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":2,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
dataBytes = append(dataBytes, []byte("id:3\n")...)
|
|
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
|
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
|
//nolint:lll
|
|
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片展示的是一个海滩的场景,具体来说是在日落时分。由于没有明显的地标或建筑物等特征可以辨认出具体的地点信息,所以无法确定这是哪个地方的海滩。但是根据图像中的元素和环境特点,我们可以推测这可能是一个位于沿海地区的沙滩海岸线。"}],"role":"assistant"}}],"finish_reason":"stop"},"usage":{"input_tokens":1279,"output_tokens":63,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
|
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
|
|
|
_, err := w.Write(dataBytes)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func streamResponseCheck(t *testing.T, response string) {
|
|
// 以换行符分割response
|
|
lines := strings.Split(response, "\n\n")
|
|
// 如果最后一行为空,则删除最后一行
|
|
if lines[len(lines)-1] == "" {
|
|
lines = lines[:len(lines)-1]
|
|
}
|
|
|
|
// 循环遍历每一行
|
|
for _, line := range lines {
|
|
if line == "" {
|
|
continue
|
|
}
|
|
// assert判断 是否以data: 开头
|
|
assert.True(t, strings.HasPrefix(line, "data: "))
|
|
}
|
|
|
|
// 检测最后一行是否以data: [DONE] 结尾
|
|
assert.True(t, strings.HasSuffix(lines[len(lines)-1], "data: [DONE]"))
|
|
// 检测倒数第二行是否存在 `"finish_reason":"stop"`
|
|
assert.True(t, strings.Contains(lines[len(lines)-2], `"finish_reason":"stop"`))
|
|
}
|