mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-09 18:23:40 +08:00
chore: reorganize adaptor related package
This commit is contained in:
@@ -1,8 +0,0 @@
|
||||
package ai360
|
||||
|
||||
var ModelList = []string{
|
||||
"360GPT_S2_V9",
|
||||
"embedding-bert-512-v1",
|
||||
"embedding_s1_v1",
|
||||
"semantic_similarity_s1_v1",
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package aiproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
aiProxyLibraryRequest := ConvertRequest(*request)
|
||||
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
|
||||
return aiProxyLibraryRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "aiproxy"
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package aiproxy
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
|
||||
var ModelList = []string{""}
|
||||
|
||||
func init() {
|
||||
ModelList = openai.ModelList
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
package aiproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
|
||||
query := ""
|
||||
if len(request.Messages) != 0 {
|
||||
query = request.Messages[len(request.Messages)-1].StringContent()
|
||||
}
|
||||
return &LibraryRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Query: query,
|
||||
}
|
||||
}
|
||||
|
||||
func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
|
||||
if len(documents) == 0 {
|
||||
return ""
|
||||
}
|
||||
content := "\n\n参考文档:\n"
|
||||
for i, document := range documents {
|
||||
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
|
||||
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
return &openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = response.Content
|
||||
return &openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: response.Model,
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var usage model.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
var documents []LibraryDocument
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var AIProxyLibraryResponse LibraryStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if len(AIProxyLibraryResponse.Documents) != 0 {
|
||||
documents = AIProxyLibraryResponse.Documents
|
||||
}
|
||||
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
response := documentsAIProxyLibrary(documents)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var AIProxyLibraryResponse LibraryResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if AIProxyLibraryResponse.ErrCode != 0 {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: AIProxyLibraryResponse.Message,
|
||||
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
||||
Code: AIProxyLibraryResponse.ErrCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package aiproxy
|
||||
|
||||
type LibraryRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
LibraryId string `json:"libraryId"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type LibraryError struct {
|
||||
ErrCode int `json:"errCode"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type LibraryDocument struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type LibraryResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Answer string `json:"answer"`
|
||||
Documents []LibraryDocument `json:"documents"`
|
||||
LibraryError
|
||||
}
|
||||
|
||||
type LibraryStreamResponse struct {
|
||||
Content string `json:"content"`
|
||||
Finish bool `json:"finish"`
|
||||
Model string `json:"model"`
|
||||
Documents []LibraryDocument `json:"documents"`
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
fullRequestURL := ""
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
|
||||
case relaymode.ImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
|
||||
default:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
|
||||
}
|
||||
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
if meta.IsStream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
|
||||
if meta.Mode == relaymode.ImagesGenerations {
|
||||
req.Header.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
if c.GetString(common.ConfigKeyPlugin) != "" {
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case relaymode.Embeddings:
|
||||
aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||
return aliEmbeddingRequest, nil
|
||||
default:
|
||||
aliRequest := ConvertRequest(*request)
|
||||
return aliRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
aliRequest := ConvertImageRequest(*request)
|
||||
return aliRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
case relaymode.ImagesGenerations:
|
||||
err, usage = ImageHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "ali"
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package ali
|
||||
|
||||
var ModelList = []string{
|
||||
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
|
||||
"text-embedding-v1",
|
||||
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse TaskResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
logger.SysError("aliAsyncTask err: " + string(responseBody))
|
||||
return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
|
||||
url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)
|
||||
|
||||
var aliResponse TaskResponse
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return &aliResponse, err, nil
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysError("aliAsyncTask client.Do err: " + err.Error())
|
||||
return &aliResponse, err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
var response TaskResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
|
||||
return &aliResponse, err, nil
|
||||
}
|
||||
|
||||
return &response, nil, responseBody
|
||||
}
|
||||
|
||||
func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
|
||||
waitSeconds := 2
|
||||
step := 0
|
||||
maxStep := 20
|
||||
|
||||
var taskResponse TaskResponse
|
||||
var responseBody []byte
|
||||
|
||||
for {
|
||||
step++
|
||||
rsp, err, body := asyncTask(taskID, key)
|
||||
responseBody = body
|
||||
if err != nil {
|
||||
return &taskResponse, responseBody, err
|
||||
}
|
||||
|
||||
if rsp.Output.TaskStatus == "" {
|
||||
return &taskResponse, responseBody, nil
|
||||
}
|
||||
|
||||
switch rsp.Output.TaskStatus {
|
||||
case "FAILED":
|
||||
fallthrough
|
||||
case "CANCELED":
|
||||
fallthrough
|
||||
case "SUCCEEDED":
|
||||
fallthrough
|
||||
case "UNKNOWN":
|
||||
return rsp, responseBody, nil
|
||||
}
|
||||
if step >= maxStep {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Duration(waitSeconds) * time.Second)
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
|
||||
}
|
||||
|
||||
func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
|
||||
imageResponse := openai.ImageResponse{
|
||||
Created: helper.GetTimestamp(),
|
||||
}
|
||||
|
||||
for _, data := range response.Output.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
// 读取 data.Url 的图片数据并转存到 b64Json
|
||||
imageData, err := getImageData(data.Url)
|
||||
if err != nil {
|
||||
// 处理获取图片数据失败的情况
|
||||
logger.SysError("getImageData Error getting image data: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 将图片数据转为 Base64 编码的字符串
|
||||
b64Json = Base64Encode(imageData)
|
||||
} else {
|
||||
// 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageResponse.Data = append(imageResponse.Data, openai.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
}
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func getImageData(url string) ([]byte, error) {
|
||||
response, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
imageData, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return imageData, nil
|
||||
}
|
||||
|
||||
func Base64Encode(data []byte) string {
|
||||
b64Json := base64.StdEncoding.EncodeToString(data)
|
||||
return b64Json
|
||||
}
|
||||
@@ -1,278 +0,0 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
const EnableSearchModelSuffix = "-internet"
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
messages = append(messages, Message{
|
||||
Content: message.StringContent(),
|
||||
Role: strings.ToLower(message.Role),
|
||||
})
|
||||
}
|
||||
enableSearch := false
|
||||
aliModel := request.Model
|
||||
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
|
||||
enableSearch = true
|
||||
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
|
||||
}
|
||||
if request.TopP >= 1 {
|
||||
request.TopP = 0.9999
|
||||
}
|
||||
return &ChatRequest{
|
||||
Model: aliModel,
|
||||
Input: Input{
|
||||
Messages: messages,
|
||||
},
|
||||
Parameters: Parameters{
|
||||
EnableSearch: enableSearch,
|
||||
IncrementalOutput: request.Stream,
|
||||
Seed: uint64(request.Seed),
|
||||
MaxTokens: request.MaxTokens,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
TopK: request.TopK,
|
||||
ResultFormat: "message",
|
||||
Tools: request.Tools,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
|
||||
var imageRequest ImageRequest
|
||||
imageRequest.Input.Prompt = request.Prompt
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
||||
imageRequest.Parameters.N = request.N
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
return &imageRequest
|
||||
}
|
||||
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var aliResponse EmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: response.Output.Choices,
|
||||
Usage: model.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
if len(aliResponse.Output.Choices) == 0 {
|
||||
return nil
|
||||
}
|
||||
aliChoice := aliResponse.Output.Choices[0]
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta = aliChoice.Message
|
||||
if aliChoice.FinishReason != "null" {
|
||||
finishReason := aliChoice.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "qwen",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var usage model.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
//lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse ChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
if response == nil {
|
||||
return true
|
||||
}
|
||||
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
//lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
ctx := c.Request.Context()
|
||||
var aliResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
logger.Debugf(ctx, "response body: %s\n", responseBody)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
fullTextResponse.Model = "qwen"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
//Prompt string `json:"prompt"`
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
type Parameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
ResultFormat string `json:"result_format,omitempty"`
|
||||
Tools []model.Tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input Input `json:"input"`
|
||||
Parameters Parameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
} `json:"input"`
|
||||
Parameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type TaskResponse struct {
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Output struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Results []struct {
|
||||
B64Image string `json:"b64_image,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
} `json:"results,omitempty"`
|
||||
TaskMetrics struct {
|
||||
Total int `json:"TOTAL,omitempty"`
|
||||
Succeeded int `json:"SUCCEEDED,omitempty"`
|
||||
Failed int `json:"FAILED,omitempty"`
|
||||
} `json:"task_metrics,omitempty"`
|
||||
} `json:"output,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
Action string `json:"action,omitempty"`
|
||||
Streaming string `json:"streaming,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
Event string `json:"event,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
Attributes any `json:"attributes,omitempty"`
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Task string `json:"task,omitempty"`
|
||||
TaskGroup string `json:"task_group,omitempty"`
|
||||
Function string `json:"function,omitempty"`
|
||||
Parameters struct {
|
||||
SampleRate int `json:"sample_rate,omitempty"`
|
||||
Rate float64 `json:"rate,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
Input struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
} `json:"input,omitempty"`
|
||||
Usage struct {
|
||||
Characters int `json:"characters,omitempty"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type WSSMessage struct {
|
||||
Header Header `json:"header,omitempty"`
|
||||
Payload Payload `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
//Text string `json:"text"`
|
||||
//FinishReason string `json:"finish_reason"`
|
||||
Choices []openai.TextResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("x-api-key", meta.APIKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
req.Header.Set("anthropic-beta", "messages-2023-12-15")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "anthropic"
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-instant-1.2", "claude-2.0", "claude-2.1",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
}
|
||||
@@ -1,273 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/image"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason *string) string {
|
||||
if reason == nil {
|
||||
return ""
|
||||
}
|
||||
switch *reason {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
default:
|
||||
return *reason
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
||||
claudeRequest := Request{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
}
|
||||
// legacy model name mapping
|
||||
if claudeRequest.Model == "claude-instant-1" {
|
||||
claudeRequest.Model = "claude-instant-1.1"
|
||||
} else if claudeRequest.Model == "claude-2" {
|
||||
claudeRequest.Model = "claude-2.1"
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "system" && claudeRequest.System == "" {
|
||||
claudeRequest.System = message.StringContent()
|
||||
continue
|
||||
}
|
||||
claudeMessage := Message{
|
||||
Role: message.Role,
|
||||
}
|
||||
var content Content
|
||||
if message.IsStringContent() {
|
||||
content.Type = "text"
|
||||
content.Text = message.StringContent()
|
||||
claudeMessage.Content = append(claudeMessage.Content, content)
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
continue
|
||||
}
|
||||
var contents []Content
|
||||
openaiContent := message.ParseContent()
|
||||
for _, part := range openaiContent {
|
||||
var content Content
|
||||
if part.Type == model.ContentTypeText {
|
||||
content.Type = "text"
|
||||
content.Text = part.Text
|
||||
} else if part.Type == model.ContentTypeImageURL {
|
||||
content.Type = "image"
|
||||
content.Source = &ImageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
||||
content.Source.MediaType = mimeType
|
||||
content.Source.Data = data
|
||||
}
|
||||
contents = append(contents, content)
|
||||
}
|
||||
claudeMessage.Content = contents
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
}
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/claude/reference/messages-streaming
|
||||
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
|
||||
var response *Response
|
||||
var responseText string
|
||||
var stopReason string
|
||||
switch claudeResponse.Type {
|
||||
case "message_start":
|
||||
return nil, claudeResponse.Message
|
||||
case "content_block_start":
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
responseText = claudeResponse.ContentBlock.Text
|
||||
}
|
||||
case "content_block_delta":
|
||||
if claudeResponse.Delta != nil {
|
||||
responseText = claudeResponse.Delta.Text
|
||||
}
|
||||
case "message_delta":
|
||||
if claudeResponse.Usage != nil {
|
||||
response = &Response{
|
||||
Usage: *claudeResponse.Usage,
|
||||
}
|
||||
}
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
stopReason = *claudeResponse.Delta.StopReason
|
||||
}
|
||||
}
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = responseText
|
||||
choice.Delta.Role = "assistant"
|
||||
finishReason := stopReasonClaude2OpenAI(&stopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var openaiResponse openai.ChatCompletionsStreamResponse
|
||||
openaiResponse.Object = "chat.completion.chunk"
|
||||
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &openaiResponse, response
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
|
||||
var responseText string
|
||||
if len(claudeResponse.Content) > 0 {
|
||||
responseText = claudeResponse.Content[0].Text
|
||||
}
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: responseText,
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
|
||||
Model: claudeResponse.Model,
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
createdTime := helper.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(data, "data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
var usage model.Usage
|
||||
var modelName string
|
||||
var id string
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse StreamResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
|
||||
if meta != nil {
|
||||
usage.PromptTokens += meta.Usage.InputTokens
|
||||
usage.CompletionTokens += meta.Usage.OutputTokens
|
||||
modelName = meta.Model
|
||||
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
||||
return true
|
||||
}
|
||||
if response == nil {
|
||||
return true
|
||||
}
|
||||
response.Id = id
|
||||
response.Model = modelName
|
||||
response.Created = createdTime
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
_ = resp.Body.Close()
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var claudeResponse Response
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
||||
fullTextResponse.Model = modelName
|
||||
usage := model.Usage{
|
||||
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||||
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
// https://docs.anthropic.com/claude/reference/messages_post
|
||||
|
||||
type Metadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content []Content `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []Content `json:"content"`
|
||||
Model string `json:"model"`
|
||||
StopReason *string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence"`
|
||||
Usage Usage `json:"usage"`
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
StopReason *string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence"`
|
||||
}
|
||||
|
||||
type StreamResponse struct {
|
||||
Type string `json:"type"`
|
||||
Message *Response `json:"message"`
|
||||
Index int `json:"index"`
|
||||
ContentBlock *Content `json:"content_block"`
|
||||
Delta *Delta `json:"delta"`
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
)
|
||||
|
||||
func GetAPIVersion(c *gin.Context) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package baichuan
|
||||
|
||||
var ModelList = []string{
|
||||
"Baichuan2-Turbo",
|
||||
"Baichuan2-Turbo-192k",
|
||||
"Baichuan-Text-Embedding",
|
||||
}
|
||||
@@ -1,143 +0,0 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
||||
suffix := "chat/"
|
||||
if strings.HasPrefix(meta.ActualModelName, "Embedding") {
|
||||
suffix = "embeddings/"
|
||||
}
|
||||
if strings.HasPrefix(meta.ActualModelName, "bge-large") {
|
||||
suffix = "embeddings/"
|
||||
}
|
||||
if strings.HasPrefix(meta.ActualModelName, "tao-8k") {
|
||||
suffix = "embeddings/"
|
||||
}
|
||||
switch meta.ActualModelName {
|
||||
case "ERNIE-4.0":
|
||||
suffix += "completions_pro"
|
||||
case "ERNIE-Bot-4":
|
||||
suffix += "completions_pro"
|
||||
case "ERNIE-Bot":
|
||||
suffix += "completions"
|
||||
case "ERNIE-Bot-turbo":
|
||||
suffix += "eb-instant"
|
||||
case "ERNIE-Speed":
|
||||
suffix += "ernie_speed"
|
||||
case "ERNIE-4.0-8K":
|
||||
suffix += "completions_pro"
|
||||
case "ERNIE-3.5-8K":
|
||||
suffix += "completions"
|
||||
case "ERNIE-3.5-8K-0205":
|
||||
suffix += "ernie-3.5-8k-0205"
|
||||
case "ERNIE-3.5-8K-1222":
|
||||
suffix += "ernie-3.5-8k-1222"
|
||||
case "ERNIE-Bot-8K":
|
||||
suffix += "ernie_bot_8k"
|
||||
case "ERNIE-3.5-4K-0205":
|
||||
suffix += "ernie-3.5-4k-0205"
|
||||
case "ERNIE-Speed-8K":
|
||||
suffix += "ernie_speed"
|
||||
case "ERNIE-Speed-128K":
|
||||
suffix += "ernie-speed-128k"
|
||||
case "ERNIE-Lite-8K-0922":
|
||||
suffix += "eb-instant"
|
||||
case "ERNIE-Lite-8K-0308":
|
||||
suffix += "ernie-lite-8k"
|
||||
case "ERNIE-Tiny-8K":
|
||||
suffix += "ernie-tiny-8k"
|
||||
case "BLOOMZ-7B":
|
||||
suffix += "bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
suffix += "embedding-v1"
|
||||
case "bge-large-zh":
|
||||
suffix += "bge_large_zh"
|
||||
case "bge-large-en":
|
||||
suffix += "bge_large_en"
|
||||
case "tao-8k":
|
||||
suffix += "tao_8k"
|
||||
default:
|
||||
suffix += strings.ToLower(meta.ActualModelName)
|
||||
}
|
||||
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
|
||||
var accessToken string
|
||||
var err error
|
||||
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fullRequestURL += "?access_token=" + accessToken
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case relaymode.Embeddings:
|
||||
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
default:
|
||||
baiduRequest := ConvertRequest(*request)
|
||||
return baiduRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "baidu"
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package baidu
|
||||
|
||||
var ModelList = []string{
|
||||
"ERNIE-4.0-8K",
|
||||
"ERNIE-3.5-8K",
|
||||
"ERNIE-3.5-8K-0205",
|
||||
"ERNIE-3.5-8K-1222",
|
||||
"ERNIE-Bot-8K",
|
||||
"ERNIE-3.5-4K-0205",
|
||||
"ERNIE-Speed-8K",
|
||||
"ERNIE-Speed-128K",
|
||||
"ERNIE-Lite-8K-0922",
|
||||
"ERNIE-Lite-8K-0308",
|
||||
"ERNIE-Tiny-8K",
|
||||
"BLOOMZ-7B",
|
||||
"Embedding-V1",
|
||||
"bge-large-zh",
|
||||
"bge-large-en",
|
||||
"tao-8k",
|
||||
}
|
||||
@@ -1,328 +0,0 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
|
||||
type TokenResponse struct {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
baiduRequest := ChatRequest{
|
||||
Messages: make([]Message, 0, len(request.Messages)),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
PenaltyScore: request.FrequencyPenalty,
|
||||
Stream: request.Stream,
|
||||
DisableSearch: false,
|
||||
EnableCitation: false,
|
||||
MaxOutputTokens: request.MaxTokens,
|
||||
UserId: request.User,
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
baiduRequest.System = message.StringContent()
|
||||
} else {
|
||||
baiduRequest.Messages = append(baiduRequest.Messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &baiduRequest
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Result,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: response.Created,
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: response.Usage,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: "ernie-bot",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "baidu-embedding",
|
||||
Usage: response.Usage,
|
||||
}
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var usage model.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse ChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var baiduResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
fullTextResponse.Model = "ernie-bot"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var baiduResponse EmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func GetAccessToken(apiKey string) (string, error) {
|
||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||
var accessToken AccessToken
|
||||
if accessToken, ok = val.(AccessToken); ok {
|
||||
// soon this will expire
|
||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||
go func() {
|
||||
_, _ = getBaiduAccessTokenHelper(apiKey)
|
||||
}()
|
||||
}
|
||||
return accessToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if accessToken == nil {
|
||||
return "", errors.New("GetAccessToken return a nil token")
|
||||
}
|
||||
return (*accessToken).AccessToken, nil
|
||||
}
|
||||
|
||||
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid baidu apikey")
|
||||
}
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
||||
parts[0], parts[1]), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
res, err := util.ImpatientHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var accessToken AccessToken
|
||||
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accessToken.Error != "" {
|
||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||
}
|
||||
if accessToken.AccessToken == "" {
|
||||
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
||||
}
|
||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
||||
baiduTokenStore.Store(apiKey, accessToken)
|
||||
return &accessToken, nil
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage model.Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type ChatStreamResponse struct {
|
||||
ChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []EmbeddingData `json:"data"`
|
||||
Usage model.Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.GetRequestURL(meta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, req, meta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := DoRequest(c, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
resp, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
return resp, nil
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
channelhelper "github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
version := helper.AssignOrDefault(meta.APIVersion, "v1")
|
||||
action := "generateContent"
|
||||
if meta.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channelhelper.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("x-goog-api-key", meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
} else {
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "google gemini"
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package gemini
|
||||
|
||||
// https://ai.google.dev/models/gemini
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
|
||||
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/image"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
|
||||
|
||||
const (
|
||||
VisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
geminiRequest := ChatRequest{
|
||||
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []ChatSafetySettings{
|
||||
{
|
||||
Category: "HARM_CATEGORY_HARASSMENT",
|
||||
Threshold: config.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||
Threshold: config.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
Threshold: config.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
Threshold: config.GeminiSafetySetting,
|
||||
},
|
||||
},
|
||||
GenerationConfig: ChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
},
|
||||
}
|
||||
if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []ChatTools{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
content := ChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []Part{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
}
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []Part
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == model.ContentTypeText {
|
||||
parts = append(parts, Part{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == model.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
if imageNum > VisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
||||
parts = append(parts, Part{
|
||||
InlineData: &InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
if content.Role == "system" {
|
||||
content.Role = "user"
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
|
||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
if shouldAddDummyModelMessage {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
|
||||
Role: "model",
|
||||
Parts: []Part{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Candidates []ChatCandidate `json:"candidates"`
|
||||
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
|
||||
func (g *ChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
}
|
||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
return g.Candidates[0].Content.Parts[0].Text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ChatCandidate struct {
|
||||
Content ChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type ChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type ChatPromptFeedback struct {
|
||||
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: constant.StopFinishReason,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiResponse ChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse.Model = modelName
|
||||
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
|
||||
usage := model.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package gemini
|
||||
|
||||
type ChatRequest struct {
|
||||
Contents []ChatContent `json:"contents"`
|
||||
SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []ChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type InlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type Part struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type ChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []Part `json:"parts"`
|
||||
}
|
||||
|
||||
type ChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type ChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type ChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package groq
|
||||
|
||||
// https://console.groq.com/docs/models
|
||||
|
||||
var ModelList = []string{
|
||||
"gemma-7b-it",
|
||||
"llama2-7b-2048",
|
||||
"llama2-70b-4096",
|
||||
"mixtral-8x7b-32768",
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor interface {
|
||||
Init(meta *meta.Meta)
|
||||
GetRequestURL(meta *meta.Meta) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
||||
ConvertImageRequest(request *model.ImageRequest) (any, error)
|
||||
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package lingyiwanwu
|
||||
|
||||
// https://platform.lingyiwanwu.com/docs
|
||||
|
||||
var ModelList = []string{
|
||||
"yi-34b-chat-0205",
|
||||
"yi-34b-chat-200k",
|
||||
"yi-vl-plus",
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package minimax
|
||||
|
||||
var ModelList = []string{
|
||||
"abab5.5s-chat",
|
||||
"abab5.5-chat",
|
||||
"abab6-chat",
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package minimax
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
func GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
if meta.Mode == relaymode.ChatCompletions {
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay relaymode %d for minimax", meta.Mode)
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package mistral
|
||||
|
||||
var ModelList = []string{
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
"mistral-embed",
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package moonshot
|
||||
|
||||
var ModelList = []string{
|
||||
"moonshot-v1-8k",
|
||||
"moonshot-v1-32k",
|
||||
"moonshot-v1-128k",
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
// https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
|
||||
if meta.Mode == relaymode.Embeddings {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
|
||||
}
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case relaymode.Embeddings:
|
||||
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||
return ollamaEmbeddingRequest, nil
|
||||
default:
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "ollama"
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package ollama
|
||||
|
||||
var ModelList = []string{
|
||||
"qwen:0.5b-chat",
|
||||
}
|
||||
@@ -1,238 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
ollamaRequest := ChatRequest{
|
||||
Model: request.Model,
|
||||
Options: &Options{
|
||||
Seed: int(request.Seed),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
},
|
||||
Stream: request.Stream,
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
return &ollamaRequest
|
||||
}
|
||||
|
||||
func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: response.Message.Role,
|
||||
Content: response.Message.Content,
|
||||
},
|
||||
}
|
||||
if response.Done {
|
||||
choice.FinishReason = "stop"
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: model.Usage{
|
||||
PromptTokens: response.PromptEvalCount,
|
||||
CompletionTokens: response.EvalCount,
|
||||
TotalTokens: response.PromptEvalCount + response.EvalCount,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Role = ollamaResponse.Message.Role
|
||||
choice.Delta.Content = ollamaResponse.Message.Content
|
||||
if ollamaResponse.Done {
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: ollamaResponse.Model,
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var usage model.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "}\n"); i >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := strings.TrimPrefix(scanner.Text(), "}")
|
||||
dataChan <- data + "}"
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var ollamaResponse ChatResponse
|
||||
err := json.Unmarshal([]byte(data), &ollamaResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if ollamaResponse.EvalCount != 0 {
|
||||
usage.PromptTokens = ollamaResponse.PromptEvalCount
|
||||
usage.CompletionTokens = ollamaResponse.EvalCount
|
||||
usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
|
||||
}
|
||||
response := streamResponseOllama2OpenAI(&ollamaResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Prompt: strings.Join(request.ParseInput(), " "),
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var ollamaResponse EmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if ollamaResponse.Error != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: ollamaResponse.Error,
|
||||
Type: "ollama_error",
|
||||
Param: "",
|
||||
Code: "ollama_error",
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, 1),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: model.Usage{TotalTokens: 0},
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: 0,
|
||||
Embedding: response.Embedding,
|
||||
})
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
ctx := context.TODO()
|
||||
var ollamaResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
logger.Debugf(ctx, "ollama response: %s", string(responseBody))
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &ollamaResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if ollamaResponse.Error != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: ollamaResponse.Error,
|
||||
Type: "ollama_error",
|
||||
Param: "",
|
||||
Code: "ollama_error",
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package ollama
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
Message Message `json:"message,omitempty"`
|
||||
Response string `json:"response,omitempty"` // for stream response
|
||||
Done bool `json:"done,omitempty"`
|
||||
TotalDuration int `json:"total_duration,omitempty"`
|
||||
LoadDuration int `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Embedding []float64 `json:"embedding,omitempty"`
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/minimax"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
a.ChannelType = meta.ChannelType
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
switch meta.ChannelType {
|
||||
case channeltype.Azure:
|
||||
if meta.Mode == relaymode.ImagesGenerations {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
|
||||
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
model_ := meta.ActualModelName
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
//https://github.com/songquanpeng/one-api/issues/1191
|
||||
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
|
||||
case channeltype.Minimax:
|
||||
return minimax.GetRequestURL(meta)
|
||||
default:
|
||||
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
if meta.ChannelType == channeltype.Azure {
|
||||
req.Header.Set("api-key", meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
if meta.ChannelType == channeltype.OpenRouter {
|
||||
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
req.Header.Set("X-Title", "One API")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
|
||||
if usage == nil {
|
||||
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
}
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
err, _ = ImageHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
_, modelList := GetCompatibleChannelMeta(a.ChannelType)
|
||||
return modelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
channelName, _ := GetCompatibleChannelMeta(a.ChannelType)
|
||||
return channelName
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/ai360"
|
||||
"github.com/songquanpeng/one-api/relay/channel/baichuan"
|
||||
"github.com/songquanpeng/one-api/relay/channel/groq"
|
||||
"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
|
||||
"github.com/songquanpeng/one-api/relay/channel/minimax"
|
||||
"github.com/songquanpeng/one-api/relay/channel/mistral"
|
||||
"github.com/songquanpeng/one-api/relay/channel/moonshot"
|
||||
"github.com/songquanpeng/one-api/relay/channel/stepfun"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
)
|
||||
|
||||
var CompatibleChannels = []int{
|
||||
channeltype.Azure,
|
||||
channeltype.AI360,
|
||||
channeltype.Moonshot,
|
||||
channeltype.Baichuan,
|
||||
channeltype.Minimax,
|
||||
channeltype.Mistral,
|
||||
channeltype.Groq,
|
||||
channeltype.LingYiWanWu,
|
||||
channeltype.StepFun,
|
||||
}
|
||||
|
||||
func GetCompatibleChannelMeta(channelType int) (string, []string) {
|
||||
switch channelType {
|
||||
case channeltype.Azure:
|
||||
return "azure", ModelList
|
||||
case channeltype.AI360:
|
||||
return "360", ai360.ModelList
|
||||
case channeltype.Moonshot:
|
||||
return "moonshot", moonshot.ModelList
|
||||
case channeltype.Baichuan:
|
||||
return "baichuan", baichuan.ModelList
|
||||
case channeltype.Minimax:
|
||||
return "minimax", minimax.ModelList
|
||||
case channeltype.Mistral:
|
||||
return "mistralai", mistral.ModelList
|
||||
case channeltype.Groq:
|
||||
return "groq", groq.ModelList
|
||||
case channeltype.LingYiWanWu:
|
||||
return "lingyiwanwu", lingyiwanwu.ModelList
|
||||
case channeltype.StepFun:
|
||||
return "stepfun", stepfun.ModelList
|
||||
default:
|
||||
return "openai", ModelList
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package openai
|
||||
|
||||
var ModelList = []string{
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-2", "dall-e-3",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/model"
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
|
||||
usage := &model.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = CountTokenText(responseText, modeName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var imageResponse ImageResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &imageResponse)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/conv"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
var usage *model.Usage
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
switch relayMode {
|
||||
case relaymode.ChatCompletions:
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue // just ignore the error
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += conv.AsString(choice.Delta.Content)
|
||||
}
|
||||
if streamResponse.Usage != nil {
|
||||
usage = streamResponse.Usage
|
||||
}
|
||||
case relaymode.Completions:
|
||||
var streamResponse CompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += choice.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
c.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
|
||||
}
|
||||
return nil, responseText, usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var textResponse SlimTextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the HTTPClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
|
||||
}
|
||||
textResponse.Usage = model.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
return nil, &textResponse.Usage
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/model"
|
||||
|
||||
type TextContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type ImageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
ImageURL *model.ImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []model.Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []model.Message `json:"messages"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
//Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type WhisperJSONResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type WhisperVerboseJSONResponse struct {
|
||||
Task string `json:"task,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Segments []Segment `json:"segments,omitempty"`
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
Id int `json:"id"`
|
||||
Seek int `json:"seek"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
Tokens []int `json:"tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
AvgLogprob float64 `json:"avg_logprob"`
|
||||
CompressionRatio float64 `json:"compression_ratio"`
|
||||
NoSpeechProb float64 `json:"no_speech_prob"`
|
||||
}
|
||||
|
||||
type TextToSpeechRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Input string `json:"input" binding:"required"`
|
||||
Voice string `json:"voice" binding:"required"`
|
||||
Speed float64 `json:"speed"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
}
|
||||
|
||||
type UsageOrResponseText struct {
|
||||
*model.Usage
|
||||
ResponseText string
|
||||
}
|
||||
|
||||
type SlimTextResponse struct {
|
||||
Choices []TextResponseChoice `json:"choices"`
|
||||
model.Usage `json:"usage"`
|
||||
Error model.Error `json:"error"`
|
||||
}
|
||||
|
||||
type TextResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
model.Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []TextResponseChoice `json:"choices"`
|
||||
model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type EmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []EmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
//model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta model.Message `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
Usage *model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
Choices []struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
@@ -1,208 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/image"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"math"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
logger.SysLog("initializing token encoders")
|
||||
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||
if err != nil {
|
||||
logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||
}
|
||||
defaultTokenEncoder = gpt35TokenEncoder
|
||||
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||
if err != nil {
|
||||
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
for model := range billingratio.ModelRatio {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = nil
|
||||
}
|
||||
}
|
||||
logger.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
return tokenEncoder
|
||||
}
|
||||
if ok {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = defaultTokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
if config.ApproximateTokenEnabled {
|
||||
return int(float64(len(text)) * 0.38)
|
||||
}
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []model.Message, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
switch v := message.Content.(type) {
|
||||
case string:
|
||||
tokenNum += getTokenNum(tokenEncoder, v)
|
||||
case []any:
|
||||
for _, it := range v {
|
||||
m := it.(map[string]any)
|
||||
switch m["type"] {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
|
||||
case "image_url":
|
||||
imageUrl, ok := m["image_url"].(map[string]any)
|
||||
if ok {
|
||||
url := imageUrl["url"].(string)
|
||||
detail := ""
|
||||
if imageUrl["detail"] != nil {
|
||||
detail = imageUrl["detail"].(string)
|
||||
}
|
||||
imageTokens, err := countImageTokens(url, detail)
|
||||
if err != nil {
|
||||
logger.SysError("error counting image tokens: " + err.Error())
|
||||
} else {
|
||||
tokenNum += imageTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum
|
||||
}
|
||||
|
||||
const (
|
||||
lowDetailCost = 85
|
||||
highDetailCostPerTile = 170
|
||||
additionalCost = 85
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
|
||||
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
func countImageTokens(url string, detail string) (_ int, err error) {
|
||||
var fetchSize = true
|
||||
var width, height int
|
||||
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
|
||||
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
|
||||
// According to the official guide, "low" disable the high-res model,
|
||||
// and only receive low-res 512px x 512px version of the image, indicating
|
||||
// that image is treated as low-res when size is smaller than 512px x 512px,
|
||||
// then we can assume that image size larger than 512px x 512px is treated
|
||||
// as high-res. Then we have the following logic:
|
||||
// if detail == "" || detail == "auto" {
|
||||
// width, height, err = image.GetImageSize(url)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// fetchSize = false
|
||||
// // not sure if this is correct
|
||||
// if width > 512 || height > 512 {
|
||||
// detail = "high"
|
||||
// } else {
|
||||
// detail = "low"
|
||||
// }
|
||||
// }
|
||||
|
||||
// However, in my test, it seems to be always the same as "high".
|
||||
// The following image, which is 125x50, is still treated as high-res, taken
|
||||
// 255 tokens in the response of non-stream chat completion api.
|
||||
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
|
||||
if detail == "" || detail == "auto" {
|
||||
// assume by test, not sure if this is correct
|
||||
detail = "high"
|
||||
}
|
||||
switch detail {
|
||||
case "low":
|
||||
return lowDetailCost, nil
|
||||
case "high":
|
||||
if fetchSize {
|
||||
width, height, err = image.GetImageSize(url)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if width > 2048 || height > 2048 { // max(width, height) > 2048
|
||||
ratio := float64(2048) / math.Max(float64(width), float64(height))
|
||||
width = int(float64(width) * ratio)
|
||||
height = int(float64(height) * ratio)
|
||||
}
|
||||
if width > 768 && height > 768 { // min(width, height) > 768
|
||||
ratio := float64(768) / math.Min(float64(width), float64(height))
|
||||
width = int(float64(width) * ratio)
|
||||
height = int(float64(height) * ratio)
|
||||
}
|
||||
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
|
||||
result := numSquares*highDetailCostPerTile + additionalCost
|
||||
return result, nil
|
||||
default:
|
||||
return 0, errors.New("invalid detail option")
|
||||
}
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return CountTokenText(text, model)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func CountTokenText(text string, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/model"
|
||||
|
||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
||||
Error := model.Error{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
Code: code,
|
||||
}
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: Error,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("x-goog-api-key", meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
} else {
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "google palm"
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package palm
|
||||
|
||||
var ModelList = []string{
|
||||
"PaLM-2",
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type ChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Filter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Prompt struct {
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Prompt Prompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Candidates []ChatMessage `json:"candidates"`
|
||||
Messages []model.Message `json:"messages"`
|
||||
Filters []Filter `json:"filters"`
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
palmRequest := ChatRequest{
|
||||
Prompt: Prompt{
|
||||
Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
|
||||
},
|
||||
Temperature: textRequest.Temperature,
|
||||
CandidateCount: textRequest.N,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.MaxTokens,
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := ChatMessage{
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
} else {
|
||||
palmMessage.Author = "1"
|
||||
}
|
||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||
}
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: candidate.Content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
|
||||
createdTime := helper.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.SysError("error reading stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
logger.SysError("error closing stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
var palmResponse ChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
|
||||
fullTextResponse.Id = responseId
|
||||
fullTextResponse.Created = createdTime
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
responseText = palmResponse.Candidates[0].Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var palmResponse ChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
Code: palmResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
fullTextResponse.Model = modelName
|
||||
completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName)
|
||||
usage := model.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package stepfun
|
||||
|
||||
var ModelList = []string{
|
||||
"step-1-32k",
|
||||
"step-1v-32k",
|
||||
"step-1-200k",
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/api/1729/101837
|
||||
|
||||
type Adaptor struct {
|
||||
Sign string
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("Authorization", a.Sign)
|
||||
req.Header.Set("X-TC-Action", meta.ActualModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := ParseConfig(apiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tencentRequest := ConvertRequest(*request)
|
||||
tencentRequest.AppId = appId
|
||||
tencentRequest.SecretId = secretId
|
||||
// we have to calculate the sign here
|
||||
a.Sign = GetSign(*tencentRequest, secretKey)
|
||||
return tencentRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
} else {
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "tencent"
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package tencent
|
||||
|
||||
var ModelList = []string{
|
||||
"ChatPro",
|
||||
"ChatStd",
|
||||
"hunyuan",
|
||||
}
|
||||
@@ -1,230 +0,0 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/conv"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
messages = append(messages, Message{
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
stream := 0
|
||||
if request.Stream {
|
||||
stream = 1
|
||||
}
|
||||
return &ChatRequest{
|
||||
Timestamp: helper.GetTimestamp(),
|
||||
Expired: helper.GetTimestamp() + 24*60*60,
|
||||
QueryID: random.GetUUID(),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Stream: stream,
|
||||
Messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Usage: response.Usage,
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Choices[0].Messages.Content,
|
||||
},
|
||||
FinishReason: response.Choices[0].FinishReason,
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "tencent-hunyuan",
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var TencentResponse ChatResponse
|
||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += conv.AsString(response.Choices[0].Delta.Content)
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var TencentResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||
fullTextResponse.Model = "hunyuan"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||
parts := strings.Split(config, "|")
|
||||
if len(parts) != 3 {
|
||||
err = errors.New("invalid tencent config")
|
||||
return
|
||||
}
|
||||
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||
secretId = parts[1]
|
||||
secretKey = parts[2]
|
||||
return
|
||||
}
|
||||
|
||||
func GetSign(req ChatRequest, secretKey string) string {
|
||||
params := make([]string, 0)
|
||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||
params = append(params, "secret_id="+req.SecretId)
|
||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||
params = append(params, "query_id="+req.QueryID)
|
||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||
|
||||
var messageStr string
|
||||
for _, msg := range req.Messages {
|
||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||
}
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Strings(params)
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
mac.Write([]byte(signURL))
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type ResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages Message `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta Message `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Choices []ResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage model.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
request *model.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
// check DoResponse for auth part
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
a.request = request
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
dummyResp.StatusCode = http.StatusOK
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
splits := strings.Split(meta.APIKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
if a.request == nil {
|
||||
return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
|
||||
}
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "xunfei"
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package xunfei
|
||||
|
||||
var ModelList = []string{
|
||||
"SparkDesk",
|
||||
"SparkDesk-v1.1",
|
||||
"SparkDesk-v2.1",
|
||||
"SparkDesk-v3.1",
|
||||
"SparkDesk-v3.5",
|
||||
}
|
||||
@@ -1,311 +0,0 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://console.xfyun.cn/services/cbm
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
|
||||
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
var lastToolCalls []model.Tool
|
||||
for _, message := range request.Messages {
|
||||
if message.ToolCalls != nil {
|
||||
lastToolCalls = message.ToolCalls
|
||||
}
|
||||
messages = append(messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
xunfeiRequest := ChatRequest{}
|
||||
xunfeiRequest.Header.AppId = xunfeiAppId
|
||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||
xunfeiRequest.Payload.Message.Text = messages
|
||||
if len(lastToolCalls) != 0 {
|
||||
for _, toolCall := range lastToolCalls {
|
||||
xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
|
||||
}
|
||||
}
|
||||
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
func getToolCalls(response *ChatResponse) []model.Tool {
|
||||
var toolCalls []model.Tool
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
return toolCalls
|
||||
}
|
||||
item := response.Payload.Choices.Text[0]
|
||||
if item.FunctionCall == nil {
|
||||
return toolCalls
|
||||
}
|
||||
toolCall := model.Tool{
|
||||
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
||||
Type: "function",
|
||||
Function: *item.FunctionCall,
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
response.Payload.Choices.Text = []ChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Payload.Choices.Text[0].Content,
|
||||
ToolCalls: getToolCalls(response),
|
||||
},
|
||||
FinishReason: constant.StopFinishReason,
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: response.Payload.Usage.Text,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "SparkDesk",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
encodeData := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(encodeData)
|
||||
}
|
||||
ul, err := url.Parse(hostUrl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
date := time.Now().UTC().Format(time.RFC1123)
|
||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||
sign := strings.Join(signString, "\n")
|
||||
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||
"hmac-sha256", "host date request-line", sha)
|
||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||
v := url.Values{}
|
||||
v.Add("host", ul.Host)
|
||||
v.Add("date", date)
|
||||
v.Add("authorization", authorization)
|
||||
callUrl := hostUrl + "?" + v.Encode()
|
||||
return callUrl
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.SetEventStreamHeaders(c)
|
||||
var usage model.Usage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var usage model.Usage
|
||||
var content string
|
||||
var xunfeiResponse ChatResponse
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case xunfeiResponse = <-dataChan:
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil
|
||||
}
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
conn, resp, err := d.Dial(authUrl, nil)
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return nil, nil, err
|
||||
}
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dataChan := make(chan ChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
if msg == nil {
|
||||
_, msg, err = conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.SysError("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
var response ChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
msg = nil
|
||||
dataChan <- response
|
||||
if response.Payload.Choices.Status == 2 {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
logger.SysError("error closing websocket connection: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
|
||||
return dataChan, stopChan, nil
|
||||
}
|
||||
|
||||
func getAPIVersion(c *gin.Context, modelName string) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion != "" {
|
||||
return apiVersion
|
||||
}
|
||||
parts := strings.Split(modelName, "-")
|
||||
if len(parts) == 2 {
|
||||
apiVersion = parts[1]
|
||||
return apiVersion
|
||||
|
||||
}
|
||||
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
|
||||
if apiVersion != "" {
|
||||
return apiVersion
|
||||
}
|
||||
apiVersion = "v1.1"
|
||||
logger.SysLog("api_version not found, using default: " + apiVersion)
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
|
||||
func apiVersion2domain(apiVersion string) string {
|
||||
switch apiVersion {
|
||||
case "v1.1":
|
||||
return "general"
|
||||
case "v2.1":
|
||||
return "generalv2"
|
||||
case "v3.1":
|
||||
return "generalv3"
|
||||
case "v3.5":
|
||||
return "generalv3.5"
|
||||
}
|
||||
return "general" + apiVersion
|
||||
}
|
||||
|
||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
|
||||
apiVersion := getAPIVersion(c, modelName)
|
||||
domain := apiVersion2domain(apiVersion)
|
||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []Message `json:"text"`
|
||||
} `json:"message"`
|
||||
Functions struct {
|
||||
Text []model.Function `json:"text,omitempty"`
|
||||
} `json:"functions,omitempty"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type ChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
ContentType string `json:"content_type"`
|
||||
FunctionCall *model.Function `json:"function_call"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []ChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text model.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/channel"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
APIVersion string
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetVersionByModeName(modelName string) {
|
||||
if strings.HasPrefix(modelName, "glm-") {
|
||||
a.APIVersion = "v4"
|
||||
} else {
|
||||
a.APIVersion = "v3"
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
|
||||
case relaymode.Embeddings:
|
||||
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
|
||||
}
|
||||
a.SetVersionByModeName(meta.ActualModelName)
|
||||
if a.APIVersion == "v4" {
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
|
||||
}
|
||||
method := "invoke"
|
||||
if meta.IsStream {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
channel.SetupCommonRequestHeader(c, req, meta)
|
||||
token := GetToken(meta.APIKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case relaymode.Embeddings:
|
||||
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
default:
|
||||
// TopP (0.0, 1.0)
|
||||
request.TopP = math.Min(0.99, request.TopP)
|
||||
request.TopP = math.Max(0.01, request.TopP)
|
||||
|
||||
// Temperature (0.0, 1.0)
|
||||
request.Temperature = math.Min(0.99, request.Temperature)
|
||||
request.Temperature = math.Max(0.01, request.Temperature)
|
||||
a.SetVersionByModeName(request.Model)
|
||||
if a.APIVersion == "v4" {
|
||||
return request, nil
|
||||
}
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
newRequest := ImageRequest{
|
||||
Model: request.Model,
|
||||
Prompt: request.Prompt,
|
||||
UserId: request.User,
|
||||
}
|
||||
return newRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
|
||||
} else {
|
||||
err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = EmbeddingsHandler(c, resp)
|
||||
return
|
||||
case relaymode.ImagesGenerations:
|
||||
err, usage = openai.ImageHandler(c, resp)
|
||||
return
|
||||
}
|
||||
if a.APIVersion == "v4" {
|
||||
return a.DoResponseV4(c, resp, meta)
|
||||
}
|
||||
if meta.IsStream {
|
||||
err, usage = StreamHandler(c, resp)
|
||||
} else {
|
||||
if meta.Mode == relaymode.Embeddings {
|
||||
err, usage = EmbeddingsHandler(c, resp)
|
||||
} else {
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Model: "embedding-2",
|
||||
Input: request.Input.(string),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "zhipu"
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package zhipu
|
||||
|
||||
var ModelList = []string{
|
||||
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
|
||||
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
|
||||
"cogview-3",
|
||||
}
|
||||
@@ -1,303 +0,0 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||
// chatglm_std, chatglm_lite
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
func GetToken(apikey string) string {
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(tokenData)
|
||||
if time.Now().Before(tokenData.ExpiryTime) {
|
||||
return tokenData.Token
|
||||
}
|
||||
}
|
||||
|
||||
split := strings.Split(apikey, ".")
|
||||
if len(split) != 2 {
|
||||
logger.SysError("invalid zhipu key: " + apikey)
|
||||
return ""
|
||||
}
|
||||
|
||||
id := split[0]
|
||||
secret := split[1]
|
||||
|
||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||
|
||||
timestamp := time.Now().UnixNano() / 1e6
|
||||
|
||||
payload := jwt.MapClaims{
|
||||
"api_key": id,
|
||||
"exp": expMillis,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
|
||||
token.Header["alg"] = "HS256"
|
||||
token.Header["sign_type"] = "SIGN"
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
zhipuTokens.Store(apikey, tokenData{
|
||||
Token: tokenString,
|
||||
ExpiryTime: expiryTime,
|
||||
})
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
return &Request{
|
||||
Prompt: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Incremental: false,
|
||||
}
|
||||
}
|
||||
|
||||
func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.Data.TaskId,
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
|
||||
Usage: response.Data.Usage,
|
||||
}
|
||||
for i, choice := range response.Data.Choices {
|
||||
openaiChoice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: model.Message{
|
||||
Role: choice.Role,
|
||||
Content: strings.Trim(choice.Content, "\""),
|
||||
},
|
||||
FinishReason: "",
|
||||
}
|
||||
if i == len(response.Data.Choices)-1 {
|
||||
openaiChoice.FinishReason = "stop"
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var usage *model.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
metaChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
lines := strings.Split(data, "\n")
|
||||
for i, line := range lines {
|
||||
if len(line) < 5 {
|
||||
continue
|
||||
}
|
||||
if line[:5] == "data:" {
|
||||
dataChan <- line[5:]
|
||||
if i != len(lines)-1 {
|
||||
dataChan <- "\n"
|
||||
}
|
||||
} else if line[:5] == "meta:" {
|
||||
metaChan <- line[5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
response := streamResponseZhipu2OpenAI(data)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case data := <-metaChan:
|
||||
var zhipuResponse StreamMetaResponse
|
||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
usage = zhipuUsage
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var zhipuResponse Response
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
Code: zhipuResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
||||
fullTextResponse.Model = "chatglm"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var zhipuResponse EmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
|
||||
Model: response.Model,
|
||||
Usage: model.Usage{
|
||||
PromptTokens: response.PromptTokens,
|
||||
CompletionTokens: response.CompletionTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, item := range response.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Prompt []Message `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []Message `json:"choices"`
|
||||
model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type StreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type tokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Embeddings []EmbeddingData `json:"data"`
|
||||
model.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user