mirror of https://github.com/songquanpeng/one-api.git
synced 2025-12-25 17:25:56 +08:00
Merge commit '2369025842b828ac38f4427fd1ebab8d03b1fe7f'
@@ -1,41 +1,46 @@
package relay

import (
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
    "github.com/songquanpeng/one-api/relay/adaptor/aws"
    "github.com/songquanpeng/one-api/relay/adaptor/gemini"
    "github.com/songquanpeng/one-api/relay/adaptor/ollama"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/adaptor/palm"
    "github.com/songquanpeng/one-api/relay/apitype"
    "github.com/Laisky/one-api/relay/adaptor"
    "github.com/Laisky/one-api/relay/adaptor/aiproxy"
    "github.com/Laisky/one-api/relay/adaptor/ali"
    "github.com/Laisky/one-api/relay/adaptor/anthropic"
    "github.com/Laisky/one-api/relay/adaptor/aws"
    "github.com/Laisky/one-api/relay/adaptor/baidu"
    "github.com/Laisky/one-api/relay/adaptor/gemini"
    "github.com/Laisky/one-api/relay/adaptor/ollama"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/adaptor/palm"
    "github.com/Laisky/one-api/relay/adaptor/tencent"
    "github.com/Laisky/one-api/relay/adaptor/xunfei"
    "github.com/Laisky/one-api/relay/adaptor/zhipu"
    "github.com/Laisky/one-api/relay/apitype"
)

func GetAdaptor(apiType int) adaptor.Adaptor {
    switch apiType {
    case apitype.AIProxyLibrary:
        return &aiproxy.Adaptor{}
    // case apitype.Ali:
    //     return &ali.Adaptor{}
    case apitype.Ali:
        return &ali.Adaptor{}
    case apitype.Anthropic:
        return &anthropic.Adaptor{}
    case apitype.AwsClaude:
        return &aws.Adaptor{}
    // case apitype.Baidu:
    //     return &baidu.Adaptor{}
    case apitype.Baidu:
        return &baidu.Adaptor{}
    case apitype.Gemini:
        return &gemini.Adaptor{}
    case apitype.OpenAI:
        return &openai.Adaptor{}
    case apitype.PaLM:
        return &palm.Adaptor{}
    // case apitype.Tencent:
    //     return &tencent.Adaptor{}
    // case apitype.Xunfei:
    //     return &xunfei.Adaptor{}
    // case apitype.Zhipu:
    //     return &zhipu.Adaptor{}
    case apitype.Tencent:
        return &tencent.Adaptor{}
    case apitype.Xunfei:
        return &xunfei.Adaptor{}
    case apitype.Zhipu:
        return &zhipu.Adaptor{}
    case apitype.Ollama:
        return &ollama.Adaptor{}
    }
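GetAdaptor is the factory the relay layer uses to map a channel's API type to its adaptor implementation. A minimal usage sketch, assuming the merged (Laisky) import paths; apiType values without a case fall through the switch, so the nil guard matters (GetChannelName being part of the adaptor interface is inferred from the implementations later in this diff):

    package main

    import (
        "fmt"

        "github.com/Laisky/one-api/relay"
        "github.com/Laisky/one-api/relay/apitype"
    )

    func main() {
        // Resolve the adaptor for an Ali (DashScope) channel.
        a := relay.GetAdaptor(apitype.Ali)
        if a == nil {
            fmt.Println("no adaptor registered for this apiType")
            return
        }
        fmt.Println(a.GetChannelName()) // prints "ali"
    }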
@@ -3,11 +3,11 @@ package aiproxy
import (
    "fmt"
    "github.com/Laisky/errors/v2"
    "github.com/Laisky/one-api/common/config"
    "github.com/Laisky/one-api/relay/adaptor"
    "github.com/Laisky/one-api/relay/meta"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
)
@@ -1,6 +1,6 @@
package aiproxy

import "github.com/songquanpeng/one-api/relay/adaptor/openai"
import "github.com/Laisky/one-api/relay/adaptor/openai"

var ModelList = []string{""}
@@ -9,14 +9,14 @@ import (
    "strconv"
    "strings"

    "github.com/Laisky/one-api/common"
    "github.com/Laisky/one-api/common/helper"
    "github.com/Laisky/one-api/common/logger"
    "github.com/Laisky/one-api/common/random"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/constant"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/common/random"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
)

// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
@@ -1,105 +1,106 @@
package ali

// import (
//     "github.com/Laisky/errors/v2"
//     "fmt"
//     "github.com/gin-gonic/gin"
//     "github.com/songquanpeng/one-api/common/config"
//     "github.com/songquanpeng/one-api/relay/adaptor"
//     "github.com/songquanpeng/one-api/relay/meta"
//     "github.com/songquanpeng/one-api/relay/model"
//     "github.com/songquanpeng/one-api/relay/relaymode"
//     "io"
//     "net/http"
// )
import (
    "errors"
    "fmt"
    "io"
    "net/http"

    // // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
    "github.com/Laisky/one-api/common/config"
    "github.com/Laisky/one-api/relay/adaptor"
    "github.com/Laisky/one-api/relay/meta"
    "github.com/Laisky/one-api/relay/model"
    "github.com/Laisky/one-api/relay/relaymode"
    "github.com/gin-gonic/gin"
)

// type Adaptor struct {
// }
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details

// func (a *Adaptor) Init(meta *meta.Meta) {
type Adaptor struct {
}

// }
func (a *Adaptor) Init(meta *meta.Meta) {

// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
//     fullRequestURL := ""
//     switch meta.Mode {
//     case relaymode.Embeddings:
//         fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
//     case relaymode.ImagesGenerations:
//         fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
//     default:
//         fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
//     }
}

//     return fullRequestURL, nil
// }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    fullRequestURL := ""
    switch meta.Mode {
    case relaymode.Embeddings:
        fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
    case relaymode.ImagesGenerations:
        fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
    default:
        fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
    }

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
//     adaptor.SetupCommonRequestHeader(c, req, meta)
//     if meta.IsStream {
//         req.Header.Set("Accept", "text/event-stream")
//         req.Header.Set("X-DashScope-SSE", "enable")
//     }
//     req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return fullRequestURL, nil
}

//     if meta.Mode == relaymode.ImagesGenerations {
//         req.Header.Set("X-DashScope-Async", "enable")
//     }
//     if c.GetString(config.KeyPlugin) != "" {
//         req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
//     }
//     return nil
// }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    if meta.IsStream {
        req.Header.Set("Accept", "text/event-stream")
        req.Header.Set("X-DashScope-SSE", "enable")
    }
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
//     if request == nil {
//         return nil, errors.New("request is nil")
//     }
//     switch relayMode {
//     case relaymode.Embeddings:
//         aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
//         return aliEmbeddingRequest, nil
//     default:
//         aliRequest := ConvertRequest(*request)
//         return aliRequest, nil
//     }
// }
    if meta.Mode == relaymode.ImagesGenerations {
        req.Header.Set("X-DashScope-Async", "enable")
    }
    if c.GetString(config.KeyPlugin) != "" {
        req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
    }
    return nil
}

// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
//     if request == nil {
//         return nil, errors.New("request is nil")
//     }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    switch relayMode {
    case relaymode.Embeddings:
        aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
        return aliEmbeddingRequest, nil
    default:
        aliRequest := ConvertRequest(*request)
        return aliRequest, nil
    }
}

//     aliRequest := ConvertImageRequest(*request)
//     return aliRequest, nil
// }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
//     return adaptor.DoRequestHelper(a, c, meta, requestBody)
// }
    aliRequest := ConvertImageRequest(*request)
    return aliRequest, nil
}

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
//     if meta.IsStream {
//         err, usage = StreamHandler(c, resp)
//     } else {
//         switch meta.Mode {
//         case relaymode.Embeddings:
//             err, usage = EmbeddingHandler(c, resp)
//         case relaymode.ImagesGenerations:
//             err, usage = ImageHandler(c, resp)
//         default:
//             err, usage = Handler(c, resp)
//         }
//     }
//     return
// }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

// func (a *Adaptor) GetModelList() []string {
//     return ModelList
// }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        switch meta.Mode {
        case relaymode.Embeddings:
            err, usage = EmbeddingHandler(c, resp)
        case relaymode.ImagesGenerations:
            err, usage = ImageHandler(c, resp)
        default:
            err, usage = Handler(c, resp)
        }
    }
    return
}

// func (a *Adaptor) GetChannelName() string {
//     return "ali"
// }
func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "ali"
}
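The uncommented GetRequestURL above picks the DashScope endpoint from meta.Mode. A quick worked sketch (the BaseURL value is illustrative):

    // Embeddings mode resolves to the text-embedding service:
    m := &meta.Meta{BaseURL: "https://dashscope.aliyuncs.com", Mode: relaymode.Embeddings}
    u, _ := (&ali.Adaptor{}).GetRequestURL(m)
    // u == "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
    // ImagesGenerations maps to .../aigc/text2image/image-synthesis,
    // and every other mode to .../aigc/text-generation/generation.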
@@ -3,17 +3,18 @@ package ali
import (
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "github.com/Laisky/errors/v2"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
    "strings"
    "time"

    "github.com/Laisky/one-api/common/helper"
    "github.com/Laisky/one-api/common/logger"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
)

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
@@ -1,323 +1,279 @@
package ali

// import (
//     "github.com/songquanpeng/one-api/common"
// )
import (
    "bufio"
    "encoding/json"
    "io"
    "net/http"
    "strings"

    // // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
    "github.com/Laisky/one-api/common"
    "github.com/Laisky/one-api/common/helper"
    "github.com/Laisky/one-api/common/logger"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
)

// type AliMessage struct {
//     Content string `json:"content"`
//     Role    string `json:"role"`
// }
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r

// type AliInput struct {
//     //Prompt string `json:"prompt"`
//     Messages []AliMessage `json:"messages"`
// }
const EnableSearchModelSuffix = "-internet"

// type AliParameters struct {
//     TopP         float64 `json:"top_p,omitempty"`
//     TopK         int     `json:"top_k,omitempty"`
//     Seed         uint64  `json:"seed,omitempty"`
//     EnableSearch bool    `json:"enable_search,omitempty"`
// }
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
    messages := make([]Message, 0, len(request.Messages))
    for i := 0; i < len(request.Messages); i++ {
        message := request.Messages[i]
        messages = append(messages, Message{
            Content: message.StringContent(),
            Role:    strings.ToLower(message.Role),
        })
    }
    enableSearch := false
    aliModel := request.Model
    if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
        enableSearch = true
        aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
    }
    if request.TopP >= 1 {
        request.TopP = 0.9999
    }
    return &ChatRequest{
        Model: aliModel,
        Input: Input{
            Messages: messages,
        },
        Parameters: Parameters{
            EnableSearch:      enableSearch,
            IncrementalOutput: request.Stream,
            Seed:              uint64(request.Seed),
            MaxTokens:         request.MaxTokens,
            Temperature:       request.Temperature,
            TopP:              request.TopP,
            TopK:              request.TopK,
            ResultFormat:      "message",
            Tools:             request.Tools,
        },
    }
}
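The "-internet" suffix above lets a single OpenAI-style model name opt into DashScope's web search. A hedged sketch (the model name is illustrative):

    // "qwen-max-internet" goes upstream as model "qwen-max" with
    // Parameters.EnableSearch = true; TopP is clamped to 0.9999
    // because values >= 1 are rewritten before the request is built.
    req := ali.ConvertRequest(model.GeneralOpenAIRequest{
        Model: "qwen-max-internet",
        TopP:  1.0,
    })
    // req.Model == "qwen-max"
    // req.Parameters.EnableSearch == true
    // req.Parameters.TopP == 0.9999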
// type AliChatRequest struct {
//     Model      string        `json:"model"`
//     Input      AliInput      `json:"input"`
//     Parameters AliParameters `json:"parameters,omitempty"`
// }
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
    return &EmbeddingRequest{
        Model: "text-embedding-v1",
        Input: struct {
            Texts []string `json:"texts"`
        }{
            Texts: request.ParseInput(),
        },
    }
}

// type AliEmbeddingRequest struct {
//     Model string `json:"model"`
//     Input struct {
//         Texts []string `json:"texts"`
//     } `json:"input"`
//     Parameters *struct {
//         TextType string `json:"text_type,omitempty"`
//     } `json:"parameters,omitempty"`
// }
func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
    var imageRequest ImageRequest
    imageRequest.Input.Prompt = request.Prompt
    imageRequest.Model = request.Model
    imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
    imageRequest.Parameters.N = request.N
    imageRequest.ResponseFormat = request.ResponseFormat

// type AliEmbedding struct {
//     Embedding []float64 `json:"embedding"`
//     TextIndex int       `json:"text_index"`
// }
    return &imageRequest
}

// type AliEmbeddingResponse struct {
//     Output struct {
//         Embeddings []AliEmbedding `json:"embeddings"`
//     } `json:"output"`
//     Usage AliUsage `json:"usage"`
//     AliError
// }
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var aliResponse EmbeddingResponse
    err := json.NewDecoder(resp.Body).Decode(&aliResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }

// type AliError struct {
//     Code      string `json:"code"`
//     Message   string `json:"message"`
//     RequestId string `json:"request_id"`
// }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }

// type AliUsage struct {
//     InputTokens  int `json:"input_tokens"`
//     OutputTokens int `json:"output_tokens"`
//     TotalTokens  int `json:"total_tokens"`
// }
    if aliResponse.Code != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: aliResponse.Message,
                Type:    aliResponse.Code,
                Param:   aliResponse.RequestId,
                Code:    aliResponse.Code,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }

// type AliOutput struct {
//     Text         string `json:"text"`
//     FinishReason string `json:"finish_reason"`
// }
    fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}

// type AliChatResponse struct {
//     Output AliOutput `json:"output"`
//     Usage  AliUsage  `json:"usage"`
//     AliError
// }
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
    openAIEmbeddingResponse := openai.EmbeddingResponse{
        Object: "list",
        Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
        Model:  "text-embedding-v1",
        Usage:  model.Usage{TotalTokens: response.Usage.TotalTokens},
    }

// func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
//     messages := make([]AliMessage, 0, len(request.Messages))
//     prompt := ""
//     for i := 0; i < len(request.Messages); i++ {
//         message := request.Messages[i]
//         if message.Role == "system" {
//             messages = append(messages, AliMessage{
//                 User: message.Content,
//                 Bot:  "Okay",
//             })
//             continue
//         } else {
//             if i == len(request.Messages)-1 {
//                 prompt = message.Content
//                 break
//             }
//             messages = append(messages, AliMessage{
//                 User: message.Content,
//                 Bot:  request.Messages[i+1].Content,
//             })
//             i++
//         }
//     }
//     return &AliChatRequest{
//         Model: request.Model,
//         Input: AliInput{
//             Prompt:  prompt,
//             History: messages,
//         },
//         //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
//         //    TopP: request.TopP,
//         //    TopK: 50,
//         //    //Seed: 0,
//         //    //EnableSearch: false,
//         //},
//     }
// }
    for _, item := range response.Output.Embeddings {
        openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
            Object:    `embedding`,
            Index:     item.TextIndex,
            Embedding: item.Embedding,
        })
    }
    return &openAIEmbeddingResponse
}

// func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
//     return &AliEmbeddingRequest{
//         Model: "text-embedding-v1",
//         Input: struct {
//             Texts []string `json:"texts"`
//         }{
//             Texts: request.ParseInput(),
//         },
//     }
// }
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
    fullTextResponse := openai.TextResponse{
        Id:      response.RequestId,
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: response.Output.Choices,
        Usage: model.Usage{
            PromptTokens:     response.Usage.InputTokens,
            CompletionTokens: response.Usage.OutputTokens,
            TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
        },
    }
    return &fullTextResponse
}

// func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
//     var aliResponse AliEmbeddingResponse
//     err := json.NewDecoder(resp.Body).Decode(&aliResponse)
//     if err != nil {
//         return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
//     }
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
    if len(aliResponse.Output.Choices) == 0 {
        return nil
    }
    aliChoice := aliResponse.Output.Choices[0]
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta = aliChoice.Message
    if aliChoice.FinishReason != "null" {
        finishReason := aliChoice.FinishReason
        choice.FinishReason = &finishReason
    }
    response := openai.ChatCompletionsStreamResponse{
        Id:      aliResponse.RequestId,
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   "qwen",
        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
    }
    return &response
}

// err = resp.Body.Close()
// if err != nil {
//     return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var usage model.Usage
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
        if atEOF && len(data) == 0 {
            return 0, nil, nil
        }
        if i := strings.Index(string(data), "\n"); i >= 0 {
            return i + 1, data[0:i], nil
        }
        if atEOF {
            return len(data), data, nil
        }
        return 0, nil, nil
    })
    dataChan := make(chan string)
    stopChan := make(chan bool)
    go func() {
        for scanner.Scan() {
            data := scanner.Text()
            if len(data) < 5 { // ignore blank line or wrong format
                continue
            }
            if data[:5] != "data:" {
                continue
            }
            data = data[5:]
            dataChan <- data
        }
        stopChan <- true
    }()
    common.SetEventStreamHeaders(c)
    //lastResponseText := ""
    c.Stream(func(w io.Writer) bool {
        select {
        case data := <-dataChan:
            var aliResponse ChatResponse
            err := json.Unmarshal([]byte(data), &aliResponse)
            if err != nil {
                logger.SysError("error unmarshalling stream response: " + err.Error())
                return true
            }
            if aliResponse.Usage.OutputTokens != 0 {
                usage.PromptTokens = aliResponse.Usage.InputTokens
                usage.CompletionTokens = aliResponse.Usage.OutputTokens
                usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
            }
            response := streamResponseAli2OpenAI(&aliResponse)
            if response == nil {
                return true
            }
            //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
            //lastResponseText = aliResponse.Output.Text
            jsonResponse, err := json.Marshal(response)
            if err != nil {
                logger.SysError("error marshalling stream response: " + err.Error())
                return true
            }
            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
            return true
        case <-stopChan:
            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
            return false
        }
    })
    err := resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    return nil, &usage
}
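The custom bufio.SplitFunc in StreamHandler above tokenizes the upstream SSE body line by line, and only "data:" payloads reach dataChan. A self-contained sketch of that split fed with a fabricated SSE body:

    package main

    import (
        "bufio"
        "fmt"
        "strings"
    )

    func main() {
        body := "data:{\"output\":{}}\n\ndata:{\"usage\":{}}\n"
        scanner := bufio.NewScanner(strings.NewReader(body))
        scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
            if atEOF && len(data) == 0 {
                return 0, nil, nil
            }
            if i := strings.Index(string(data), "\n"); i >= 0 {
                return i + 1, data[0:i], nil
            }
            if atEOF {
                return len(data), data, nil
            }
            return 0, nil, nil
        })
        for scanner.Scan() {
            line := scanner.Text()
            if len(line) < 5 || line[:5] != "data:" {
                continue // blank keep-alive lines are skipped
            }
            fmt.Println(line[5:]) // the JSON payload handed to dataChan
        }
    }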
// if aliResponse.Code != "" {
//     return &OpenAIErrorWithStatusCode{
//         OpenAIError: OpenAIError{
//             Message: aliResponse.Message,
//             Type:    aliResponse.Code,
//             Param:   aliResponse.RequestId,
//             Code:    aliResponse.Code,
//         },
//         StatusCode: resp.StatusCode,
//     }, nil
// }

// fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
//     return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }

// func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
//     openAIEmbeddingResponse := OpenAIEmbeddingResponse{
//         Object: "list",
//         Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
//         Model:  "text-embedding-v1",
//         Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
//     }

//     for _, item := range response.Output.Embeddings {
//         openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
//             Object:    `embedding`,
//             Index:     item.TextIndex,
//             Embedding: item.Embedding,
//         })
//     }
//     return &openAIEmbeddingResponse
// }

// func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
//     choice := OpenAITextResponseChoice{
//         Index: 0,
//         Message: Message{
//             Role:    "assistant",
//             Content: response.Output.Text,
//         },
//         FinishReason: response.Output.FinishReason,
//     }
//     fullTextResponse := OpenAITextResponse{
//         Id:      response.RequestId,
//         Object:  "chat.completion",
//         Created: common.GetTimestamp(),
//         Choices: []OpenAITextResponseChoice{choice},
//         Usage: Usage{
//             PromptTokens:     response.Usage.InputTokens,
//             CompletionTokens: response.Usage.OutputTokens,
//             TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
//         },
//     }
//     return &fullTextResponse
// }

// func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
//     var choice ChatCompletionsStreamResponseChoice
//     choice.Delta.Content = aliResponse.Output.Text
//     if aliResponse.Output.FinishReason != "null" {
//         finishReason := aliResponse.Output.FinishReason
//         choice.FinishReason = &finishReason
//     }
//     response := ChatCompletionsStreamResponse{
//         Id:      aliResponse.RequestId,
//         Object:  "chat.completion.chunk",
//         Created: common.GetTimestamp(),
//         Model:   "ernie-bot",
//         Choices: []ChatCompletionsStreamResponseChoice{choice},
//     }
//     return &response
// }

// func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
//     var usage Usage
//     scanner := bufio.NewScanner(resp.Body)
//     scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
//         if atEOF && len(data) == 0 {
//             return 0, nil, nil
//         }
//         if i := strings.Index(string(data), "\n"); i >= 0 {
//             return i + 1, data[0:i], nil
//         }
//         if atEOF {
//             return len(data), data, nil
//         }
//         return 0, nil, nil
//     })
//     dataChan := make(chan string)
//     stopChan := make(chan bool)
//     go func() {
//         for scanner.Scan() {
//             data := scanner.Text()
//             if len(data) < 5 { // ignore blank line or wrong format
//                 continue
//             }
//             if data[:5] != "data:" {
//                 continue
//             }
//             data = data[5:]
//             dataChan <- data
//         }
//         stopChan <- true
//     }()
//     setEventStreamHeaders(c)
//     lastResponseText := ""
//     c.Stream(func(w io.Writer) bool {
//         select {
//         case data := <-dataChan:
//             var aliResponse AliChatResponse
//             err := json.Unmarshal([]byte(data), &aliResponse)
//             if err != nil {
//                 common.SysError("error unmarshalling stream response: " + err.Error())
//                 return true
//             }
//             if aliResponse.Usage.OutputTokens != 0 {
//                 usage.PromptTokens = aliResponse.Usage.InputTokens
//                 usage.CompletionTokens = aliResponse.Usage.OutputTokens
//                 usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
//             }
//             response := streamResponseAli2OpenAI(&aliResponse)
//             response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//             lastResponseText = aliResponse.Output.Text
//             jsonResponse, err := json.Marshal(response)
//             if err != nil {
//                 common.SysError("error marshalling stream response: " + err.Error())
//                 return true
//             }
//             c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
//             return true
//         case <-stopChan:
//             c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
//             return false
//         }
//     })
//     err := resp.Body.Close()
//     if err != nil {
//         return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
//     }
//     return nil, &usage
// }

// func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
//     var aliResponse AliChatResponse
//     responseBody, err := io.ReadAll(resp.Body)
//     if err != nil {
//         return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
//     }
//     err = resp.Body.Close()
//     if err != nil {
//         return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
//     }
//     err = json.Unmarshal(responseBody, &aliResponse)
//     if err != nil {
//         return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
//     }
//     if aliResponse.Code != "" {
//         return &OpenAIErrorWithStatusCode{
//             OpenAIError: OpenAIError{
//                 Message: aliResponse.Message,
//                 Type:    aliResponse.Code,
//                 Param:   aliResponse.RequestId,
//                 Code:    aliResponse.Code,
//             },
//             StatusCode: resp.StatusCode,
//         }, nil
//     }
//     fullTextResponse := responseAli2OpenAI(&aliResponse)
//     jsonResponse, err := json.Marshal(fullTextResponse)
//     if err != nil {
//         return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
//     }
//     c.Writer.Header().Set("Content-Type", "application/json")
//     c.Writer.WriteHeader(resp.StatusCode)
//     _, err = c.Writer.Write(jsonResponse)
//     return nil, &fullTextResponse.Usage
// }
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    ctx := c.Request.Context()
    var aliResponse ChatResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    logger.Debugf(ctx, "response body: %s\n", responseBody)
    err = json.Unmarshal(responseBody, &aliResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if aliResponse.Code != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: aliResponse.Message,
                Type:    aliResponse.Code,
                Param:   aliResponse.RequestId,
                Code:    aliResponse.Code,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := responseAli2OpenAI(&aliResponse)
    fullTextResponse.Model = "qwen"
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}
@@ -1,8 +1,8 @@
package ali

import (
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/model"
)

type Message struct {
@@ -6,10 +6,10 @@ import (
    "net/http"

    "github.com/Laisky/errors/v2"
    "github.com/Laisky/one-api/relay/adaptor"
    "github.com/Laisky/one-api/relay/meta"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {
@@ -8,13 +8,13 @@ import (
    "net/http"
    "strings"

    "github.com/Laisky/one-api/common"
    "github.com/Laisky/one-api/common/helper"
    "github.com/Laisky/one-api/common/image"
    "github.com/Laisky/one-api/common/logger"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/image"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
)

func stopReasonClaude2OpenAI(reason *string) string {
@@ -4,13 +4,13 @@ import (
    "io"
    "net/http"

    "github.com/Laisky/errors/v2"
    "github.com/Laisky/one-api/common/ctxkey"
    "github.com/Laisky/one-api/relay/adaptor"
    "github.com/Laisky/one-api/relay/adaptor/anthropic"
    "github.com/Laisky/one-api/relay/meta"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/pkg/errors"
)

var _ adaptor.Adaptor = new(Adaptor)
@@ -36,9 +36,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    }

    claudeReq := anthropic.ConvertRequest(*request)
    c.Set(common.CtxKeyRequestModel, request.Model)
    c.Set(common.CtxKeyRawRequest, request)
    c.Set(common.CtxKeyConvertedRequest, claudeReq)
    c.Set(ctxkey.RequestModel, request.Model)
    c.Set(ctxkey.ConvertedRequest, claudeReq)
    return claudeReq, nil
}

@@ -7,8 +7,14 @@ import (
    "fmt"
    "io"
    "net/http"
    "strings"

    "github.com/Laisky/one-api/common"
    "github.com/Laisky/one-api/common/config"
    "github.com/Laisky/one-api/common/ctxkey"
    "github.com/Laisky/one-api/common/helper"
    "github.com/Laisky/one-api/common/logger"
    "github.com/Laisky/one-api/relay/adaptor/anthropic"
    relaymodel "github.com/Laisky/one-api/relay/model"
    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/credentials"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
@@ -16,23 +22,14 @@ import (
    "github.com/gin-gonic/gin"
    "github.com/jinzhu/copier"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/model"
    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
    relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) {
    ks := strings.Split(channel.Key, "\n")
    if len(ks) != 2 {
        return nil, errors.New("invalid key")
    }
    ak, sk := ks[0], ks[1]

func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) {
    ak := c.GetString(config.KeyAK)
    sk := c.GetString(config.KeySK)
    region := c.GetString(config.KeyRegion)
    client := bedrockruntime.New(bedrockruntime.Options{
        Region: *channel.BaseURL,
        Region: region,
        Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
    })
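The merged newAwsClient no longer splits channel.Key for credentials; it reads the AK, SK, and region that earlier middleware is expected to have placed on the gin context under config.KeyAK, config.KeySK, and config.KeyRegion. A sketch of that contract (the credential values are placeholders, and the priming normally happens in middleware rather than at the call site):

    // Prime the context the way newAwsClient(c) expects, then build the client.
    c.Set(config.KeyAK, "AKIA...")       // access key id (placeholder)
    c.Set(config.KeySK, "secret...")     // secret access key (placeholder)
    c.Set(config.KeyRegion, "us-east-1") // Bedrock region (placeholder)
    awsCli, err := newAwsClient(c)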
@@ -43,7 +40,7 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
    return &relaymodel.ErrorWithStatusCode{
        StatusCode: http.StatusInternalServerError,
        Error: relaymodel.Error{
            Message: fmt.Sprintf("%+v", err),
            Message: fmt.Sprintf("%s", err.Error()),
        },
    }
}
@@ -67,19 +64,12 @@ func awsModelID(requestModel string) (string, error) {
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
    var channel *model.Channel
    if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
        return wrapErr(errors.New("channel not found")), nil
    } else {
        channel = channeli.(*model.Channel)
    }

    awsCli, err := newAwsClient(channel)
    awsCli, err := newAwsClient(c)
    if err != nil {
        return wrapErr(errors.Wrap(err, "newAwsClient")), nil
    }

    awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel))
    awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
    if err != nil {
        return wrapErr(errors.Wrap(err, "awsModelID")), nil
    }
@@ -90,11 +80,11 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
        ContentType: aws.String("application/json"),
    }

    claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
    claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
    if !ok {
        return wrapErr(errors.New("request not found")), nil
    }
    claudeReq := claudeReqi.(*anthropic.Request)
    claudeReq := claudeReq_.(*anthropic.Request)
    awsClaudeReq := &Request{
        AnthropicVersion: "bedrock-2023-05-31",
    }
@@ -133,20 +123,12 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st

func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
    createdTime := helper.GetTimestamp()

    var channel *model.Channel
    if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
        return wrapErr(errors.New("channel not found")), nil
    } else {
        channel = channeli.(*model.Channel)
    }

    awsCli, err := newAwsClient(channel)
    awsCli, err := newAwsClient(c)
    if err != nil {
        return wrapErr(errors.Wrap(err, "newAwsClient")), nil
    }

    awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel))
    awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
    if err != nil {
        return wrapErr(errors.Wrap(err, "awsModelID")), nil
    }
@@ -157,11 +139,11 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
        ContentType: aws.String("application/json"),
    }

    claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
    claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
    if !ok {
        return wrapErr(errors.New("request not found")), nil
    }
    claudeReq := claudeReqi.(*anthropic.Request)
    claudeReq := claudeReq_.(*anthropic.Request)

    awsClaudeReq := &Request{
        AnthropicVersion: "bedrock-2023-05-31",
@@ -211,7 +193,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
            return true
        }
        response.Id = id
        response.Model = c.GetString(common.CtxKeyOriginModel)
        response.Model = c.GetString(ctxkey.OriginalModel)
        response.Created = createdTime
        jsonStr, err := json.Marshal(response)
        if err != nil {
@@ -1,6 +1,6 @@
package aws

import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
import "github.com/Laisky/one-api/relay/adaptor/anthropic"

// Request is the request to AWS Claude
//
@@ -1,8 +1,8 @@
package azure

import (
    "github.com/Laisky/one-api/common/config"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common/config"
)

func GetAPIVersion(c *gin.Context) string {
@@ -1,93 +1,143 @@
package baidu

// import (
//     "github.com/Laisky/errors/v2"
//     "github.com/gin-gonic/gin"
//     "github.com/songquanpeng/one-api/relay/channel"
//     "github.com/songquanpeng/one-api/relay/constant"
//     "github.com/songquanpeng/one-api/relay/model"
//     "github.com/songquanpeng/one-api/relay/util"
//     "io"
//     "net/http"
// )
import (
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"

    // // type Adaptor struct {
    // // }
    "github.com/Laisky/one-api/relay/adaptor"
    "github.com/Laisky/one-api/relay/meta"
    "github.com/Laisky/one-api/relay/model"
    "github.com/Laisky/one-api/relay/relaymode"
    "github.com/gin-gonic/gin"
)

// func (a *Adaptor) Init(meta *util.RelayMeta) {
type Adaptor struct {
}

// }
func (a *Adaptor) Init(meta *meta.Meta) {

// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
//     // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
//     var fullRequestURL string
//     switch meta.ActualModelName {
//     case "ERNIE-Bot-4":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
//     case "ERNIE-Bot-8K":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
//     case "ERNIE-Bot":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
//     case "ERNIE-Speed":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
//     case "ERNIE-Bot-turbo":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
//     case "BLOOMZ-7B":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
//     case "Embedding-V1":
//         fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
//     }
//     var accessToken string
//     var err error
//     if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
//         return "", err
//     }
//     fullRequestURL += "?access_token=" + accessToken
//     return fullRequestURL, nil
// }
}

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
//     channel.SetupCommonRequestHeader(c, req, meta)
//     req.Header.Set("Authorization", "Bearer "+meta.APIKey)
//     return nil
// }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
    suffix := "chat/"
    if strings.HasPrefix(meta.ActualModelName, "Embedding") {
        suffix = "embeddings/"
    }
    if strings.HasPrefix(meta.ActualModelName, "bge-large") {
        suffix = "embeddings/"
    }
    if strings.HasPrefix(meta.ActualModelName, "tao-8k") {
        suffix = "embeddings/"
    }
    switch meta.ActualModelName {
    case "ERNIE-4.0":
        suffix += "completions_pro"
    case "ERNIE-Bot-4":
        suffix += "completions_pro"
    case "ERNIE-Bot":
        suffix += "completions"
    case "ERNIE-Bot-turbo":
        suffix += "eb-instant"
    case "ERNIE-Speed":
        suffix += "ernie_speed"
    case "ERNIE-4.0-8K":
        suffix += "completions_pro"
    case "ERNIE-3.5-8K":
        suffix += "completions"
    case "ERNIE-3.5-8K-0205":
        suffix += "ernie-3.5-8k-0205"
    case "ERNIE-3.5-8K-1222":
        suffix += "ernie-3.5-8k-1222"
    case "ERNIE-Bot-8K":
        suffix += "ernie_bot_8k"
    case "ERNIE-3.5-4K-0205":
        suffix += "ernie-3.5-4k-0205"
    case "ERNIE-Speed-8K":
        suffix += "ernie_speed"
    case "ERNIE-Speed-128K":
        suffix += "ernie-speed-128k"
    case "ERNIE-Lite-8K-0922":
        suffix += "eb-instant"
    case "ERNIE-Lite-8K-0308":
        suffix += "ernie-lite-8k"
    case "ERNIE-Tiny-8K":
        suffix += "ernie-tiny-8k"
    case "BLOOMZ-7B":
        suffix += "bloomz_7b1"
    case "Embedding-V1":
        suffix += "embedding-v1"
    case "bge-large-zh":
        suffix += "bge_large_zh"
    case "bge-large-en":
        suffix += "bge_large_en"
    case "tao-8k":
        suffix += "tao_8k"
    default:
        suffix += strings.ToLower(meta.ActualModelName)
    }
    fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
    var accessToken string
    var err error
    if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
        return "", err
    }
    fullRequestURL += "?access_token=" + accessToken
    return fullRequestURL, nil
}

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
//     if request == nil {
//         return nil, errors.New("request is nil")
//     }
//     switch relayMode {
//     case constant.RelayModeEmbeddings:
//         baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
//         return baiduEmbeddingRequest, nil
//     default:
//         baiduRequest := ConvertRequest(*request)
//         return baiduRequest, nil
//     }
// }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}

// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
//     return channel.DoRequestHelper(a, c, meta, requestBody)
// }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    switch relayMode {
    case relaymode.Embeddings:
        baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
        return baiduEmbeddingRequest, nil
    default:
        baiduRequest := ConvertRequest(*request)
        return baiduRequest, nil
    }
}

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
//     if meta.IsStream {
//         err, usage = StreamHandler(c, resp)
//     } else {
//         switch meta.Mode {
//         case constant.RelayModeEmbeddings:
//             err, usage = EmbeddingHandler(c, resp)
//         default:
//             err, usage = Handler(c, resp)
//         }
//     }
//     return
// }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

// func (a *Adaptor) GetModelList() []string {
//     return ModelList
// }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

// func (a *Adaptor) GetChannelName() string {
//     return "baidu"
// }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        switch meta.Mode {
        case relaymode.Embeddings:
            err, usage = EmbeddingHandler(c, resp)
        default:
            err, usage = Handler(c, resp)
        }
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "baidu"
}
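To make the suffix logic above concrete (the BaseURL value is illustrative): "ERNIE-Bot" resolves to https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=<token>, while "Embedding-V1" takes the embeddings/ prefix and resolves to https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1?access_token=<token>; model names without a case fall back to chat/<lowercased model name>.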
@@ -1,359 +1,329 @@
package baidu

// import (
//     "bufio"
//     "encoding/json"
//     "github.com/Laisky/errors/v2"
//     "fmt"
//     "github.com/gin-gonic/gin"
//     "io"
//     "net/http"
//     "one-api/common"
//     "strings"
//     "sync"
//     "time"
// )
import (
    "bufio"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"
    "sync"
    "time"

    // // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
    "github.com/Laisky/one-api/common"
    "github.com/Laisky/one-api/common/logger"
    "github.com/Laisky/one-api/relay/adaptor/openai"
    "github.com/Laisky/one-api/relay/client"
    "github.com/Laisky/one-api/relay/constant"
    "github.com/Laisky/one-api/relay/model"
    "github.com/gin-gonic/gin"
)

// type BaiduTokenResponse struct {
//     ExpiresIn   int    `json:"expires_in"`
//     AccessToken string `json:"access_token"`
// }
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

// type BaiduMessage struct {
//     Role    string `json:"role"`
//     Content string `json:"content"`
// }
type TokenResponse struct {
    ExpiresIn   int    `json:"expires_in"`
    AccessToken string `json:"access_token"`
}

// type BaiduChatRequest struct {
//     Messages []BaiduMessage `json:"messages"`
//     Stream   bool           `json:"stream"`
//     UserId   string         `json:"user_id,omitempty"`
// }
type Message struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

// type BaiduError struct {
//     ErrorCode int    `json:"error_code"`
//     ErrorMsg  string `json:"error_msg"`
// }
type ChatRequest struct {
    Messages        []Message `json:"messages"`
    Temperature     float64   `json:"temperature,omitempty"`
    TopP            float64   `json:"top_p,omitempty"`
    PenaltyScore    float64   `json:"penalty_score,omitempty"`
    Stream          bool      `json:"stream,omitempty"`
    System          string    `json:"system,omitempty"`
    DisableSearch   bool      `json:"disable_search,omitempty"`
    EnableCitation  bool      `json:"enable_citation,omitempty"`
    MaxOutputTokens int       `json:"max_output_tokens,omitempty"`
    UserId          string    `json:"user_id,omitempty"`
}

// type BaiduChatResponse struct {
//     Id               string `json:"id"`
//     Object           string `json:"object"`
//     Created          int64  `json:"created"`
//     Result           string `json:"result"`
//     IsTruncated      bool   `json:"is_truncated"`
//     NeedClearHistory bool   `json:"need_clear_history"`
//     Usage            Usage  `json:"usage"`
//     BaiduError
// }
type Error struct {
    ErrorCode int    `json:"error_code"`
    ErrorMsg  string `json:"error_msg"`
}

// type BaiduChatStreamResponse struct {
//     BaiduChatResponse
//     SentenceId int  `json:"sentence_id"`
//     IsEnd      bool `json:"is_end"`
// }
var baiduTokenStore sync.Map

// type BaiduEmbeddingRequest struct {
//     Input []string `json:"input"`
// }
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
    baiduRequest := ChatRequest{
        Messages:        make([]Message, 0, len(request.Messages)),
        Temperature:     request.Temperature,
        TopP:            request.TopP,
        PenaltyScore:    request.FrequencyPenalty,
        Stream:          request.Stream,
        DisableSearch:   false,
        EnableCitation:  false,
        MaxOutputTokens: request.MaxTokens,
        UserId:          request.User,
    }
    for _, message := range request.Messages {
        if message.Role == "system" {
            baiduRequest.System = message.StringContent()
        } else {
            baiduRequest.Messages = append(baiduRequest.Messages, Message{
                Role:    message.Role,
                Content: message.StringContent(),
            })
        }
    }
    return &baiduRequest
}
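ConvertRequest above lifts an OpenAI "system" message into the dedicated System field rather than the Messages array. A hedged sketch (the message text is illustrative; field shapes are inferred from their usage in this file):

    req := baidu.ConvertRequest(model.GeneralOpenAIRequest{
        Messages: []model.Message{
            {Role: "system", Content: "You are concise."},
            {Role: "user", Content: "Hi"},
        },
    })
    // req.System == "You are concise."
    // req.Messages holds only the user message.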
// type BaiduEmbeddingData struct {
|
||||
// Object string `json:"object"`
|
||||
// Embedding []float64 `json:"embedding"`
|
||||
// Index int `json:"index"`
|
||||
// }
|
||||
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Result,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: response.Created,
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: response.Usage,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
// type BaiduEmbeddingResponse struct {
|
||||
// Id string `json:"id"`
|
||||
// Object string `json:"object"`
|
||||
// Created int64 `json:"created"`
|
||||
// Data []BaiduEmbeddingData `json:"data"`
|
||||
// Usage Usage `json:"usage"`
|
||||
// BaiduError
|
||||
// }
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: "ernie-bot",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
// type BaiduAccessToken struct {
|
||||
// AccessToken string `json:"access_token"`
|
||||
// Error string `json:"error,omitempty"`
|
||||
// ErrorDescription string `json:"error_description,omitempty"`
|
||||
// ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
// ExpiresAt time.Time `json:"-"`
|
||||
// }
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
// var baiduTokenStore sync.Map
|
||||
func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "baidu-embedding",
|
||||
Usage: response.Usage,
|
||||
}
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
// func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
// messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
// for _, message := range request.Messages {
|
||||
// if message.Role == "system" {
|
||||
// messages = append(messages, BaiduMessage{
|
||||
// Role: "user",
|
||||
// Content: message.Content,
|
||||
// })
|
||||
// messages = append(messages, BaiduMessage{
|
||||
// Role: "assistant",
|
||||
// Content: "Okay",
|
||||
// })
|
||||
// } else {
|
||||
// messages = append(messages, BaiduMessage{
|
||||
// Role: message.Role,
|
||||
// Content: message.Content,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// return &BaiduChatRequest{
|
||||
// Messages: messages,
|
||||
// Stream: request.Stream,
|
||||
// }
|
||||
// }
|
||||
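// StreamHandler relays Baidu's SSE stream to the client, stripping the leading
// "data: " prefix (6 bytes) from each event line, re-emitting every chunk in OpenAI's
// chat.completion.chunk format, and accumulating token usage as it goes.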
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}

// func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
// choice := OpenAITextResponseChoice{
// Index: 0,
// Message: Message{
// Role: "assistant",
// Content: response.Result,
// },
// FinishReason: "stop",
// }
// fullTextResponse := OpenAITextResponse{
// Id: response.Id,
// Object: "chat.completion",
// Created: response.Created,
// Choices: []OpenAITextResponseChoice{choice},
// Usage: response.Usage,
// }
// return &fullTextResponse
// }
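// Handler converts a non-streaming Baidu chat completion into an OpenAI-compatible
// response, surfacing Baidu-level errors (error_code/error_msg) as relay errors.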
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

// func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = baiduResponse.Result
// if baiduResponse.IsEnd {
// choice.FinishReason = &stopFinishReason
// }
// response := ChatCompletionsStreamResponse{
// Id: baiduResponse.Id,
// Object: "chat.completion.chunk",
// Created: baiduResponse.Created,
// Model: "ernie-bot",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response
// }
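// EmbeddingHandler converts Baidu's embedding response into OpenAI's embeddings
// format, again mapping error_code/error_msg into a relay error.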
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

// func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
// return &BaiduEmbeddingRequest{
// Input: request.ParseInput(),
// }
// }
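// GetAccessToken returns a cached Baidu OAuth access token for the given API key,
// kicking off an asynchronous refresh when the cached token is within an hour of
// expiry; on a cache miss it fetches a fresh token synchronously.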
func GetAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken AccessToken
if accessToken, ok = val.(AccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = getBaiduAccessTokenHelper(apiKey)
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := getBaiduAccessTokenHelper(apiKey)
if err != nil {
return "", err
}
if accessToken == nil {
return "", errors.New("GetAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}

// func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
// openAIEmbeddingResponse := OpenAIEmbeddingResponse{
// Object: "list",
// Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
// Model: "baidu-embedding",
// Usage: response.Usage,
// }
// for _, item := range response.Data {
// openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
// Object: item.Object,
// Index: item.Index,
// Embedding: item.Embedding,
// })
// }
// return &openAIEmbeddingResponse
// }
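// getBaiduAccessTokenHelper exchanges an API key of the form "client_id|client_secret"
// for an OAuth access token via Baidu's token endpoint and caches the result in
// baiduTokenStore, stamping it with a computed ExpiresAt.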
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
parts[0], parts[1]), nil)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()

// func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var usage Usage
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n"); i >= 0 {
// return i + 1, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// if len(data) < 6 { // ignore blank line or wrong format
// continue
// }
// data = data[6:]
// dataChan <- data
// }
// stopChan <- true
// }()
// setEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// var baiduResponse BaiduChatStreamResponse
// err := json.Unmarshal([]byte(data), &baiduResponse)
// if err != nil {
// common.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// if baiduResponse.Usage.TotalTokens != 0 {
// usage.TotalTokens = baiduResponse.Usage.TotalTokens
// usage.PromptTokens = baiduResponse.Usage.PromptTokens
// usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
// }
// response := streamResponseBaidu2OpenAI(&baiduResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, &usage
// }

// func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var baiduResponse BaiduChatResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &baiduResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if baiduResponse.ErrorMsg != "" {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: baiduResponse.ErrorMsg,
// Type: "baidu_error",
// Param: "",
// Code: baiduResponse.ErrorCode,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }

// func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var baiduResponse BaiduEmbeddingResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &baiduResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if baiduResponse.ErrorMsg != "" {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: baiduResponse.ErrorMsg,
// Type: "baidu_error",
// Param: "",
// Code: baiduResponse.ErrorCode,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }

// func getBaiduAccessToken(apiKey string) (string, error) {
// if val, ok := baiduTokenStore.Load(apiKey); ok {
// var accessToken BaiduAccessToken
// if accessToken, ok = val.(BaiduAccessToken); ok {
// // soon this will expire
// if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
// go func() {
// _, _ = getBaiduAccessTokenHelper(apiKey)
// }()
// }
// return accessToken.AccessToken, nil
// }
// }
// accessToken, err := getBaiduAccessTokenHelper(apiKey)
// if err != nil {
// return "", err
// }
// if accessToken == nil {
// return "", errors.New("getBaiduAccessToken return a nil token")
// }
// return (*accessToken).AccessToken, nil
// }

// func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
// parts := strings.Split(apiKey, "|")
// if len(parts) != 2 {
// return nil, errors.New("invalid baidu apikey")
// }
// req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
// parts[0], parts[1]), nil)
// if err != nil {
// return nil, err
// }
// req.Header.Add("Content-Type", "application/json")
// req.Header.Add("Accept", "application/json")
// res, err := impatientHTTPClient.Do(req)
// if err != nil {
// return nil, err
// }
// defer res.Body.Close()

// var accessToken BaiduAccessToken
// err = json.NewDecoder(res.Body).Decode(&accessToken)
// if err != nil {
// return nil, err
// }
// if accessToken.Error != "" {
// return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
// }
// if accessToken.AccessToken == "" {
// return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
// }
// accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
// baiduTokenStore.Store(apiKey, accessToken)
// return &accessToken, nil
// }
var accessToken AccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}

@@ -1,50 +1,51 @@
package baidu

// import (
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "time"
// )
import (
"time"

// type ChatResponse struct {
// Id string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Result string `json:"result"`
// IsTruncated bool `json:"is_truncated"`
// NeedClearHistory bool `json:"need_clear_history"`
// Usage openai.Usage `json:"usage"`
// Error
// }
"github.com/Laisky/one-api/relay/model"
)

// type ChatStreamResponse struct {
// ChatResponse
// SentenceId int `json:"sentence_id"`
// IsEnd bool `json:"is_end"`
// }
type ChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage model.Usage `json:"usage"`
Error
}

// type EmbeddingRequest struct {
// Input []string `json:"input"`
// }
type ChatStreamResponse struct {
ChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}

// type EmbeddingData struct {
// Object string `json:"object"`
// Embedding []float64 `json:"embedding"`
// Index int `json:"index"`
// }
type EmbeddingRequest struct {
Input []string `json:"input"`
}

// type EmbeddingResponse struct {
// Id string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Data []EmbeddingData `json:"data"`
// Usage openai.Usage `json:"usage"`
// Error
// }
type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}

// type AccessToken struct {
// AccessToken string `json:"access_token"`
// Error string `json:"error,omitempty"`
// ErrorDescription string `json:"error_description,omitempty"`
// ExpiresIn int64 `json:"expires_in,omitempty"`
// ExpiresAt time.Time `json:"-"`
// }
type EmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
Usage model.Usage `json:"usage"`
Error
}

type AccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}

@@ -5,9 +5,9 @@ import (
"net/http"

"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/relay/client"
"github.com/Laisky/one-api/relay/meta"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
)

func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {

@@ -6,13 +6,13 @@ import (
"net/http"

"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/common/config"
"github.com/Laisky/one-api/common/helper"
channelhelper "github.com/Laisky/one-api/relay/adaptor"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {

@@ -8,15 +8,15 @@ import (
"net/http"
"strings"

"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/config"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/image"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/common/random"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/constant"
"github.com/Laisky/one-api/relay/model"

"github.com/gin-gonic/gin"
)

@@ -1,9 +1,9 @@
package adaptor

import (
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

@@ -2,8 +2,8 @@ package minimax

import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/relaymode"
)

func GetRequestURL(meta *meta.Meta) (string, error) {

@@ -6,11 +6,11 @@ import (
"net/http"

"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/relay/adaptor"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
"github.com/Laisky/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {

@@ -5,18 +5,18 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/random"
"io"
"net/http"
"strings"

"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/constant"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
)

func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {

@@ -3,13 +3,13 @@ package openai
import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/relay/adaptor"
"github.com/Laisky/one-api/relay/adaptor/minimax"
"github.com/Laisky/one-api/relay/channeltype"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
"github.com/Laisky/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -39,7 +39,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := meta.ActualModelName
model_ = strings.Replace(model_, ".", "", -1)
//https://github.com/songquanpeng/one-api/issues/1191
//https://github.com/Laisky/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
@@ -58,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("HTTP-Referer", "https://github.com/Laisky/one-api")
req.Header.Set("X-Title", "One API")
}
return nil

@@ -1,15 +1,15 @@
package openai

import (
"github.com/songquanpeng/one-api/relay/adaptor/ai360"
"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
"github.com/songquanpeng/one-api/relay/adaptor/groq"
"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/Laisky/one-api/relay/adaptor/ai360"
"github.com/Laisky/one-api/relay/adaptor/baichuan"
"github.com/Laisky/one-api/relay/adaptor/groq"
"github.com/Laisky/one-api/relay/adaptor/lingyiwanwu"
"github.com/Laisky/one-api/relay/adaptor/minimax"
"github.com/Laisky/one-api/relay/adaptor/mistral"
"github.com/Laisky/one-api/relay/adaptor/moonshot"
"github.com/Laisky/one-api/relay/adaptor/stepfun"
"github.com/Laisky/one-api/relay/channeltype"
)

var CompatibleChannels = []int{

@@ -2,8 +2,8 @@ package openai

import (
"fmt"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"github.com/Laisky/one-api/relay/channeltype"
"github.com/Laisky/one-api/relay/model"
"strings"
)

@@ -3,8 +3,8 @@ package openai
import (
"bytes"
"encoding/json"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

@@ -4,15 +4,16 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"

"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/conv"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/relay/model"
"github.com/Laisky/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
)

func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {

@@ -1,6 +1,6 @@
package openai

import "github.com/songquanpeng/one-api/relay/model"
import "github.com/Laisky/one-api/relay/model"

type TextContent struct {
Type string `json:"type,omitempty"`

@@ -3,12 +3,12 @@ package openai
import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/common/config"
"github.com/Laisky/one-api/common/image"
"github.com/Laisky/one-api/common/logger"
billingratio "github.com/Laisky/one-api/relay/billing/ratio"
"github.com/Laisky/one-api/relay/model"
"github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
)

@@ -1,6 +1,6 @@
package openai

import "github.com/songquanpeng/one-api/relay/model"
import "github.com/Laisky/one-api/relay/model"

func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
Error := model.Error{

@@ -3,11 +3,11 @@ package palm
import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/relay/adaptor"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

@@ -1,7 +1,7 @@
package palm

import (
"github.com/songquanpeng/one-api/relay/model"
"github.com/Laisky/one-api/relay/model"
)

type ChatMessage struct {

@@ -3,14 +3,14 @@ package palm
import (
"encoding/json"
"fmt"
"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/common/random"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/constant"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

@@ -1,76 +1,84 @@
package tencent

// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/channel"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/util"
// "io"
// "net/http"
// "strings"
// )
import (
"errors"
"fmt"
"io"
"net/http"
"strings"

// // https://cloud.tencent.com/document/api/1729/101837
"github.com/Laisky/one-api/relay/adaptor"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
)

// type Adaptor struct {
// Sign string
// }
// https://cloud.tencent.com/document/api/1729/101837

// func (a *Adaptor) Init(meta *util.RelayMeta) {
type Adaptor struct {
Sign string
}

// }
func (a *Adaptor) Init(meta *meta.Meta) {

// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
// }
}

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
// channel.SetupCommonRequestHeader(c, req, meta)
// req.Header.Set("Authorization", a.Sign)
// req.Header.Set("X-TC-Action", meta.ActualModelName)
// return nil
// }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
}

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// apiKey := c.Request.Header.Get("Authorization")
// apiKey = strings.TrimPrefix(apiKey, "Bearer ")
// appId, secretId, secretKey, err := ParseConfig(apiKey)
// if err != nil {
// return nil, err
// }
// tencentRequest := ConvertRequest(*request)
// tencentRequest.AppId = appId
// tencentRequest.SecretId = secretId
// // we have to calculate the sign here
// a.Sign = GetSign(*tencentRequest, secretKey)
// return tencentRequest, nil
// }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", meta.ActualModelName)
return nil
}

// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// return channel.DoRequestHelper(a, c, meta, requestBody)
// }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := ParseConfig(apiKey)
if err != nil {
return nil, err
}
tencentRequest := ConvertRequest(*request)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
// we have to calculate the sign here
a.Sign = GetSign(*tencentRequest, secretKey)
return tencentRequest, nil
}

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// var responseText string
// err, responseText = StreamHandler(c, resp)
// usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
// } else {
// err, usage = Handler(c, resp)
// }
// return
// }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}

// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

// func (a *Adaptor) GetChannelName() string {
// return "tencent"
// }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp)
}
return
}

func (a *Adaptor) GetModelList() []string {
return ModelList
}

func (a *Adaptor) GetChannelName() string {
return "tencent"
}

@@ -1,238 +1,231 @@
package tencent

// import (
// "bufio"
// "crypto/hmac"
// "crypto/sha1"
// "encoding/base64"
// "encoding/json"
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common"
// "github.com/songquanpeng/one-api/common/helper"
// "github.com/songquanpeng/one-api/common/logger"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "github.com/songquanpeng/one-api/relay/constant"
// "github.com/songquanpeng/one-api/relay/model"
// "io"
// "net/http"
// "sort"
// "strconv"
// "strings"
// )
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strconv"
"strings"

// // https://cloud.tencent.com/document/product/1729/97732
"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/conv"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/common/random"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/constant"
"github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
)

// func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
// messages := make([]Message, 0, len(request.Messages))
// for i := 0; i < len(request.Messages); i++ {
// message := request.Messages[i]
// if message.Role == "system" {
// messages = append(messages, Message{
// Role: "user",
// Content: message.StringContent(),
// })
// messages = append(messages, Message{
// Role: "assistant",
// Content: "Okay",
// })
// continue
// }
// messages = append(messages, Message{
// Content: message.StringContent(),
// Role: message.Role,
// })
// }
// stream := 0
// if request.Stream {
// stream = 1
// }
// return &ChatRequest{
// Timestamp: helper.GetTimestamp(),
// Expired: helper.GetTimestamp() + 24*60*60,
// QueryID: helper.GetUUID(),
// Temperature: request.Temperature,
// TopP: request.TopP,
// Stream: stream,
// Messages: messages,
// }
// }
// https://cloud.tencent.com/document/product/1729/97732

// func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
// fullTextResponse := openai.TextResponse{
// Object: "chat.completion",
// Created: helper.GetTimestamp(),
// Usage: response.Usage,
// }
// if len(response.Choices) > 0 {
// choice := openai.TextResponseChoice{
// Index: 0,
// Message: model.Message{
// Role: "assistant",
// Content: response.Choices[0].Messages.Content,
// },
// FinishReason: response.Choices[0].FinishReason,
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
// }
// return &fullTextResponse
// }
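// ConvertRequest maps an OpenAI-style chat request onto Tencent Hunyuan's format,
// copying the messages verbatim and stamping the request with the current timestamp,
// a 24-hour signature expiry, and a random query ID.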
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, Message{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &ChatRequest{
Timestamp: helper.GetTimestamp(),
Expired: helper.GetTimestamp() + 24*60*60,
QueryID: random.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}

// func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
// response := openai.ChatCompletionsStreamResponse{
// Object: "chat.completion.chunk",
// Created: helper.GetTimestamp(),
// Model: "tencent-hunyuan",
// }
// if len(TencentResponse.Choices) > 0 {
// var choice openai.ChatCompletionsStreamResponseChoice
// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
// if TencentResponse.Choices[0].FinishReason == "stop" {
// choice.FinishReason = &constant.StopFinishReason
// }
// response.Choices = append(response.Choices, choice)
// }
// return &response
// }
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}

// func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
// var responseText string
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n"); i >= 0 {
// return i + 1, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// if len(data) < 5 { // ignore blank line or wrong format
// continue
// }
// if data[:5] != "data:" {
// continue
// }
// data = data[5:]
// dataChan <- data
// }
// stopChan <- true
// }()
// common.SetEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// var TencentResponse ChatResponse
// err := json.Unmarshal([]byte(data), &TencentResponse)
// if err != nil {
// logger.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// response := streamResponseTencent2OpenAI(&TencentResponse)
// if len(response.Choices) != 0 {
// responseText += response.Choices[0].Delta.Content
// }
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// logger.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
// }
// return nil, responseText
// }
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &constant.StopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}

// func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// var TencentResponse ChatResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &TencentResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if TencentResponse.Error.Code != 0 {
// return &model.ErrorWithStatusCode{
// Error: model.Error{
// Message: TencentResponse.Error.Message,
// Code: TencentResponse.Error.Code,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseTencent2OpenAI(&TencentResponse)
// fullTextResponse.Model = "hunyuan"
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, &fullTextResponse.Usage
// }
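// StreamHandler relays Tencent's SSE stream: each "data:"-prefixed line is decoded,
// re-emitted as an OpenAI chat.completion.chunk event, and the delta text is
// accumulated so the caller can derive usage afterwards.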
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}

// func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
// parts := strings.Split(config, "|")
// if len(parts) != 3 {
// err = errors.New("invalid tencent config")
// return
// }
// appId, err = strconv.ParseInt(parts[0], 10, 64)
// secretId = parts[1]
// secretKey = parts[2]
// return
// }
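// Handler converts a non-streaming Tencent chat completion into an OpenAI-compatible
// response, surfacing non-zero Tencent error codes as relay errors.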
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var TencentResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse.Model = "hunyuan"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

// func GetSign(req ChatRequest, secretKey string) string {
// params := make([]string, 0)
// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
// params = append(params, "secret_id="+req.SecretId)
// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
// params = append(params, "query_id="+req.QueryID)
// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
// params = append(params, "stream="+strconv.Itoa(req.Stream))
// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
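// ParseConfig splits a Tencent channel key of the form "appId|secretId|secretKey"
// into its three parts.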
func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}

// var messageStr string
// for _, msg := range req.Messages {
// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
// }
// messageStr = strings.TrimSuffix(messageStr, ",")
// params = append(params, "messages=["+messageStr+"]")
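// GetSign builds the hyllm request signature: the request fields (including a
// hand-serialized messages array) are rendered as query parameters, sorted, appended
// to the endpoint host and path, and signed with HMAC-SHA1 under the channel's
// secretKey, then base64-encoded.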
func GetSign(req ChatRequest, secretKey string) string {
|
||||
params := make([]string, 0)
|
||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||
params = append(params, "secret_id="+req.SecretId)
|
||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||
params = append(params, "query_id="+req.QueryID)
|
||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||
|
||||
// sort.Strings(params)
|
||||
// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
// mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
// signURL := url
|
||||
// mac.Write([]byte(signURL))
|
||||
// sign := mac.Sum([]byte(nil))
|
||||
// return base64.StdEncoding.EncodeToString(sign)
|
||||
// }
|
||||
var messageStr string
|
||||
for _, msg := range req.Messages {
|
||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||
}
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Strings(params)
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
mac.Write([]byte(signURL))
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
}
|
||||
|
||||
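// --- Editor's sketch (not part of the upstream diff): how ParseConfig and GetSign
// above fit together. The channel key value and the 24h expiry are illustrative
// assumptions; ParseConfig only requires the "appId|secretId|secretKey" layout,
// and the sketch assumes a `time` import in this file.
func exampleTencentSign() {
	appId, secretId, secretKey, err := ParseConfig("1234567|hypothetical-secret-id|hypothetical-secret-key")
	if err != nil {
		return // "invalid tencent config": the key must have exactly three "|"-separated parts
	}
	req := ChatRequest{AppId: appId, SecretId: secretId}
	req.Timestamp = time.Now().Unix()
	req.Expired = req.Timestamp + 24*3600 // must exceed Timestamp by less than 90 days
	_ = GetSign(req, secretKey)           // base64(HMAC-SHA1) over the sorted query string, incl. serialized messages
}
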
@@ -1,63 +1,63 @@
package tencent

// import (
// "github.com/songquanpeng/one-api/relay/model"
// )
import (
	"github.com/Laisky/one-api/relay/model"
)

// type Message struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// type ChatRequest struct {
// AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
// SecretId string `json:"secret_id"` // SecretId from the web console
// // Timestamp is the current UNIX timestamp in seconds, i.e. when the API request is made.
// // e.g. 1529223702; a value too far from the current time triggers a signature-expired error
// Timestamp int64 `json:"timestamp"`
// // Expired is the signature expiry, a UNIX-epoch timestamp in seconds;
// // Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days
// Expired int64 `json:"expired"`
// QueryID string `json:"query_id"` // request ID, used for troubleshooting
// // Temperature: higher values make the output more random, lower values more focused and deterministic.
// // Default 1.0, range [0.0, 2.0]; avoid setting it unless necessary, unreasonable values hurt quality.
// // Set only one of this and top_p; do not change both at the same time
// Temperature float64 `json:"temperature"`
// // TopP controls output diversity: the larger the value, the more diverse the generated text.
// // Default 1.0, range [0.0, 1.0]; avoid setting it unless necessary, unreasonable values hurt quality.
// // Set only one of this and temperature; do not change both at the same time
// TopP float64 `json:"top_p"`
// // Stream: 0 = synchronous, 1 = streaming (default, protocol: SSE).
// // Synchronous requests time out after 60s; use streaming for long content
// Stream int `json:"stream"`
// // Messages holds the conversation, at most 40 entries, ordered oldest to newest.
// // The input content supports at most 3000 tokens in total.
// Messages []Message `json:"messages"`
// }
type ChatRequest struct {
	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
	SecretId string `json:"secret_id"` // SecretId from the web console
	// Timestamp is the current UNIX timestamp in seconds, i.e. when the API request is made.
	// e.g. 1529223702; a value too far from the current time triggers a signature-expired error
	Timestamp int64 `json:"timestamp"`
	// Expired is the signature expiry, a UNIX-epoch timestamp in seconds;
	// Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days
	Expired int64  `json:"expired"`
	QueryID string `json:"query_id"` // request ID, used for troubleshooting
	// Temperature: higher values make the output more random, lower values more focused and deterministic.
	// Default 1.0, range [0.0, 2.0]; avoid setting it unless necessary, unreasonable values hurt quality.
	// Set only one of this and top_p; do not change both at the same time
	Temperature float64 `json:"temperature"`
	// TopP controls output diversity: the larger the value, the more diverse the generated text.
	// Default 1.0, range [0.0, 1.0]; avoid setting it unless necessary, unreasonable values hurt quality.
	// Set only one of this and temperature; do not change both at the same time
	TopP float64 `json:"top_p"`
	// Stream: 0 = synchronous, 1 = streaming (default, protocol: SSE).
	// Synchronous requests time out after 60s; use streaming for long content
	Stream int `json:"stream"`
	// Messages holds the conversation, at most 40 entries, ordered oldest to newest.
	// The input content supports at most 3000 tokens in total.
	Messages []Message `json:"messages"`
}

// type Error struct {
// Code int `json:"code"`
// Message string `json:"message"`
// }
type Error struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// type Usage struct {
// InputTokens int `json:"input_tokens"`
// OutputTokens int `json:"output_tokens"`
// TotalTokens int `json:"total_tokens"`
// }
type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// type ResponseChoices struct {
// FinishReason string `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
// Messages Message `json:"messages,omitempty"` // content returned in synchronous mode, null in streaming mode; output supports at most 1024 tokens
// Delta Message `json:"delta,omitempty"` // content returned in streaming mode, null in synchronous mode; output supports at most 1024 tokens
// }
type ResponseChoices struct {
	FinishReason string  `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     Message `json:"messages,omitempty"`      // content returned in synchronous mode, null in streaming mode; output supports at most 1024 tokens
	Delta        Message `json:"delta,omitempty"`         // content returned in streaming mode, null in synchronous mode; output supports at most 1024 tokens
}

// type ChatResponse struct {
// Choices []ResponseChoices `json:"choices,omitempty"` // results
// Created string `json:"created,omitempty"` // UNIX timestamp as a string
// Id string `json:"id,omitempty"` // conversation id
// Usage model.Usage `json:"usage,omitempty"` // token counts
// Error Error `json:"error,omitempty"` // error info; note: this field may be null when no valid value is available
// Note string `json:"note,omitempty"` // note
// ReqID string `json:"req_id,omitempty"` // unique request ID, returned with every request; include it when reporting issues
// }
type ChatResponse struct {
	Choices []ResponseChoices `json:"choices,omitempty"` // results
	Created string            `json:"created,omitempty"` // UNIX timestamp as a string
	Id      string            `json:"id,omitempty"`      // conversation id
	Usage   model.Usage       `json:"usage,omitempty"`   // token counts
	Error   Error             `json:"error,omitempty"`   // error info; note: this field may be null when no valid value is available
	Note    string            `json:"note,omitempty"`    // note
	ReqID   string            `json:"req_id,omitempty"`  // unique request ID, returned with every request; include it when reporting issues
}

@@ -1,70 +1,78 @@
package xunfei

// import (
// "github.com/Laisky/errors/v2"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/channel"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/util"
// "io"
// "net/http"
// "strings"
// )
import (
	"errors"
	"io"
	"net/http"
	"strings"

// type Adaptor struct {
// request *model.GeneralOpenAIRequest
// }
	"github.com/Laisky/one-api/relay/adaptor"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	"github.com/Laisky/one-api/relay/meta"
	"github.com/Laisky/one-api/relay/model"
	"github.com/gin-gonic/gin"
)

// func (a *Adaptor) Init(meta *util.RelayMeta) {
type Adaptor struct {
	request *model.GeneralOpenAIRequest
}

// }
func (a *Adaptor) Init(meta *meta.Meta) {

// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// return "", nil
// }
}

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
// channel.SetupCommonRequestHeader(c, req, meta)
// // check DoResponse for auth part
// return nil
// }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	return "", nil
}

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// a.request = request
// return nil, nil
// }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	// check DoResponse for auth part
	return nil
}

// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// // xunfei's request is not http request, so we don't need to do anything here
// dummyResp := &http.Response{}
// dummyResp.StatusCode = http.StatusOK
// return dummyResp, nil
// }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	a.request = request
	return nil, nil
}

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// splits := strings.Split(meta.APIKey, "|")
// if len(splits) != 3 {
// return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
// }
// if a.request == nil {
// return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
// }
// if meta.IsStream {
// err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
// } else {
// err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
// }
// return
// }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	// xunfei's request is not http request, so we don't need to do anything here
	dummyResp := &http.Response{}
	dummyResp.StatusCode = http.StatusOK
	return dummyResp, nil
}

// func (a *Adaptor) GetChannelName() string {
// return "xunfei"
// }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	splits := strings.Split(meta.APIKey, "|")
	if len(splits) != 3 {
		return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
	}
	if a.request == nil {
		return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
	}
	if meta.IsStream {
		err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
	} else {
		err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "xunfei"
}

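// --- Editor's sketch (not part of the upstream diff): DoResponse above expects the
// channel API key in "appId|apiSecret|apiKey" form; the value below is illustrative.
func exampleXunfeiKeySplit() {
	splits := strings.Split("hypothetical-app-id|hypothetical-api-secret|hypothetical-api-key", "|")
	if len(splits) != 3 {
		return // rejected as "invalid auth"
	}
	appId, apiSecret, apiKey := splits[0], splits[1], splits[2]
	_, _, _ = appId, apiSecret, apiKey // forwarded to StreamHandler/Handler in exactly this order
}
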
@@ -1,306 +1,313 @@
package xunfei

// import (
// "crypto/hmac"
// "crypto/sha256"
// "encoding/base64"
// "encoding/json"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/gorilla/websocket"
// "io"
// "net/http"
// "net/url"
// "one-api/common"
// "strings"
// "time"
// )
import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

// // https://console.xfyun.cn/services/cbm
// // https://www.xfyun.cn/doc/spark/Web.html
	"github.com/Laisky/one-api/common"
	"github.com/Laisky/one-api/common/config"
	"github.com/Laisky/one-api/common/helper"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/common/random"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	"github.com/Laisky/one-api/relay/constant"
	"github.com/Laisky/one-api/relay/model"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)

// type XunfeiMessage struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html

// type XunfeiChatRequest struct {
// Header struct {
// AppId string `json:"app_id"`
// } `json:"header"`
// Parameter struct {
// Chat struct {
// Domain string `json:"domain,omitempty"`
// Temperature float64 `json:"temperature,omitempty"`
// TopK int `json:"top_k,omitempty"`
// MaxTokens int `json:"max_tokens,omitempty"`
// Auditing bool `json:"auditing,omitempty"`
// } `json:"chat"`
// } `json:"parameter"`
// Payload struct {
// Message struct {
// Text []XunfeiMessage `json:"text"`
// } `json:"message"`
// } `json:"payload"`
// }
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
	messages := make([]Message, 0, len(request.Messages))
	var lastToolCalls []model.Tool
	for _, message := range request.Messages {
		if message.ToolCalls != nil {
			lastToolCalls = message.ToolCalls
		}
		messages = append(messages, Message{
			Role:    message.Role,
			Content: message.StringContent(),
		})
	}
	xunfeiRequest := ChatRequest{}
	xunfeiRequest.Header.AppId = xunfeiAppId
	xunfeiRequest.Parameter.Chat.Domain = domain
	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
	xunfeiRequest.Parameter.Chat.TopK = request.N
	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
	xunfeiRequest.Payload.Message.Text = messages
	if len(lastToolCalls) != 0 {
		for _, toolCall := range lastToolCalls {
			xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
		}
	}

// type XunfeiChatResponseTextItem struct {
// Content string `json:"content"`
// Role string `json:"role"`
// Index int `json:"index"`
// }
	return &xunfeiRequest
}

// type XunfeiChatResponse struct {
// Header struct {
// Code int `json:"code"`
// Message string `json:"message"`
// Sid string `json:"sid"`
// Status int `json:"status"`
// } `json:"header"`
// Payload struct {
// Choices struct {
// Status int `json:"status"`
// Seq int `json:"seq"`
// Text []XunfeiChatResponseTextItem `json:"text"`
// } `json:"choices"`
// Usage struct {
// //Text struct {
// // QuestionTokens string `json:"question_tokens"`
// // PromptTokens string `json:"prompt_tokens"`
// // CompletionTokens string `json:"completion_tokens"`
// // TotalTokens string `json:"total_tokens"`
// //} `json:"text"`
// Text Usage `json:"text"`
// } `json:"usage"`
// } `json:"payload"`
// }
func getToolCalls(response *ChatResponse) []model.Tool {
	var toolCalls []model.Tool
	if len(response.Payload.Choices.Text) == 0 {
		return toolCalls
	}
	item := response.Payload.Choices.Text[0]
	if item.FunctionCall == nil {
		return toolCalls
	}
	toolCall := model.Tool{
		Id:       fmt.Sprintf("call_%s", random.GetUUID()),
		Type:     "function",
		Function: *item.FunctionCall,
	}
	toolCalls = append(toolCalls, toolCall)
	return toolCalls
}

// func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
// messages := make([]XunfeiMessage, 0, len(request.Messages))
// for _, message := range request.Messages {
// if message.Role == "system" {
// messages = append(messages, XunfeiMessage{
// Role: "user",
// Content: message.Content,
// })
// messages = append(messages, XunfeiMessage{
// Role: "assistant",
// Content: "Okay",
// })
// } else {
// messages = append(messages, XunfeiMessage{
// Role: message.Role,
// Content: message.Content,
// })
// }
// }
// xunfeiRequest := XunfeiChatRequest{}
// xunfeiRequest.Header.AppId = xunfeiAppId
// xunfeiRequest.Parameter.Chat.Domain = domain
// xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
// xunfeiRequest.Parameter.Chat.TopK = request.N
// xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
// xunfeiRequest.Payload.Message.Text = messages
// return &xunfeiRequest
// }
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
	if len(response.Payload.Choices.Text) == 0 {
		response.Payload.Choices.Text = []ChatResponseTextItem{
			{
				Content: "",
			},
		}
	}
	choice := openai.TextResponseChoice{
		Index: 0,
		Message: model.Message{
			Role:      "assistant",
			Content:   response.Payload.Choices.Text[0].Content,
			ToolCalls: getToolCalls(response),
		},
		FinishReason: constant.StopFinishReason,
	}
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{choice},
		Usage:   response.Payload.Usage.Text,
	}
	return &fullTextResponse
}

// func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
// if len(response.Payload.Choices.Text) == 0 {
// response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
// {
// Content: "",
// },
// }
// }
// choice := OpenAITextResponseChoice{
// Index: 0,
// Message: Message{
// Role: "assistant",
// Content: response.Payload.Choices.Text[0].Content,
// },
// FinishReason: stopFinishReason,
// }
// fullTextResponse := OpenAITextResponse{
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: []OpenAITextResponseChoice{choice},
// Usage: response.Payload.Usage.Text,
// }
// return &fullTextResponse
// }
func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
		xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
			{
				Content: "",
			},
		}
	}
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
	choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
	if xunfeiResponse.Payload.Choices.Status == 2 {
		choice.FinishReason = &constant.StopFinishReason
	}
	response := openai.ChatCompletionsStreamResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: helper.GetTimestamp(),
		Model:   "SparkDesk",
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
	}
	return &response
}

// func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
// if len(xunfeiResponse.Payload.Choices.Text) == 0 {
// xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
// {
// Content: "",
// },
// }
// }
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
// if xunfeiResponse.Payload.Choices.Status == 2 {
// choice.FinishReason = &stopFinishReason
// }
// response := ChatCompletionsStreamResponse{
// Object: "chat.completion.chunk",
// Created: common.GetTimestamp(),
// Model: "SparkDesk",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response
// }
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
	HmacWithShaToBase64 := func(algorithm, data, key string) string {
		mac := hmac.New(sha256.New, []byte(key))
		mac.Write([]byte(data))
		encodeData := mac.Sum(nil)
		return base64.StdEncoding.EncodeToString(encodeData)
	}
	ul, err := url.Parse(hostUrl)
	if err != nil {
		fmt.Println(err)
	}
	date := time.Now().UTC().Format(time.RFC1123)
	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
	sign := strings.Join(signString, "\n")
	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
		"hmac-sha256", "host date request-line", sha)
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
	v := url.Values{}
	v.Add("host", ul.Host)
	v.Add("date", date)
	v.Add("authorization", authorization)
	callUrl := hostUrl + "?" + v.Encode()
	return callUrl
}

// func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
// HmacWithShaToBase64 := func(algorithm, data, key string) string {
// mac := hmac.New(sha256.New, []byte(key))
// mac.Write([]byte(data))
// encodeData := mac.Sum(nil)
// return base64.StdEncoding.EncodeToString(encodeData)
// }
// ul, err := url.Parse(hostUrl)
// if err != nil {
// fmt.Println(err)
// }
// date := time.Now().UTC().Format(time.RFC1123)
// signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
// sign := strings.Join(signString, "\n")
// sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
// authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
// "hmac-sha256", "host date request-line", sha)
// authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
// v := url.Values{}
// v.Add("host", ul.Host)
// v.Add("date", date)
// v.Add("authorization", authorization)
// callUrl := hostUrl + "?" + v.Encode()
// return callUrl
// }
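
// --- Editor's sketch (not part of the upstream diff): what buildXunfeiAuthUrl above
// produces. The signature is HMAC-SHA256 over "host: ...\ndate: ...\nGET <path> HTTP/1.1";
// the base64-encoded signature is embedded in an `hmac username=...` header string,
// which is itself base64-encoded and appended as query parameters (alphabetical order
// via url.Values.Encode). Credentials below are placeholders.
func exampleXunfeiAuthUrl() {
	callUrl := buildXunfeiAuthUrl("wss://spark-api.xf-yun.com/v3.5/chat",
		"hypothetical-api-key", "hypothetical-api-secret")
	_ = callUrl // wss://spark-api.xf-yun.com/v3.5/chat?authorization=...&date=...&host=spark-api.xf-yun.com
}
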
func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
	if err != nil {
		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
	}
	common.SetEventStreamHeaders(c)
	var usage model.Usage
	c.Stream(func(w io.Writer) bool {
		select {
		case xunfeiResponse := <-dataChan:
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	return nil, &usage
}

// func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
// domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
// dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
// if err != nil {
// return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
// }
// setEventStreamHeaders(c)
// var usage Usage
// c.Stream(func(w io.Writer) bool {
// select {
// case xunfeiResponse := <-dataChan:
// usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
// usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
// usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
// response := streamResponseXunfei2OpenAI(&xunfeiResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// return nil, &usage
// }
func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
	if err != nil {
		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
	}
	var usage model.Usage
	var content string
	var xunfeiResponse ChatResponse
	stop := false
	for !stop {
		select {
		case xunfeiResponse = <-dataChan:
			if len(xunfeiResponse.Payload.Choices.Text) == 0 {
				continue
			}
			content += xunfeiResponse.Payload.Choices.Text[0].Content
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
		case stop = <-stopChan:
		}
	}
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
		return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil
	}
	xunfeiResponse.Payload.Choices.Text[0].Content = content

// func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
// domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
// dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
// if err != nil {
// return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
// }
// var usage Usage
// var content string
// var xunfeiResponse XunfeiChatResponse
// stop := false
// for !stop {
// select {
// case xunfeiResponse = <-dataChan:
// if len(xunfeiResponse.Payload.Choices.Text) == 0 {
// continue
// }
// content += xunfeiResponse.Payload.Choices.Text[0].Content
// usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
// usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
// usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
// case stop = <-stopChan:
// }
// }
	response := responseXunfei2OpenAI(&xunfeiResponse)
	jsonResponse, err := json.Marshal(response)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &usage
}

// xunfeiResponse.Payload.Choices.Text[0].Content = content
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
	d := websocket.Dialer{
		HandshakeTimeout: 5 * time.Second,
	}
	conn, resp, err := d.Dial(authUrl, nil)
	if err != nil || resp.StatusCode != 101 {
		return nil, nil, err
	}
	data := requestOpenAI2Xunfei(textRequest, appId, domain)
	err = conn.WriteJSON(data)
	if err != nil {
		return nil, nil, err
	}
	_, msg, err := conn.ReadMessage()
	if err != nil {
		return nil, nil, err
	}

// response := responseXunfei2OpenAI(&xunfeiResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// _, _ = c.Writer.Write(jsonResponse)
// return nil, &usage
// }
	dataChan := make(chan ChatResponse)
	stopChan := make(chan bool)
	go func() {
		for {
			if msg == nil {
				_, msg, err = conn.ReadMessage()
				if err != nil {
					logger.SysError("error reading stream response: " + err.Error())
					break
				}
			}
			var response ChatResponse
			err = json.Unmarshal(msg, &response)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				break
			}
			msg = nil
			dataChan <- response
			if response.Payload.Choices.Status == 2 {
				err := conn.Close()
				if err != nil {
					logger.SysError("error closing websocket connection: " + err.Error())
				}
				break
			}
		}
		stopChan <- true
	}()

// func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
// d := websocket.Dialer{
// HandshakeTimeout: 5 * time.Second,
// }
// conn, resp, err := d.Dial(authUrl, nil)
// if err != nil || resp.StatusCode != 101 {
// return nil, nil, err
// }
// data := requestOpenAI2Xunfei(textRequest, appId, domain)
// err = conn.WriteJSON(data)
// if err != nil {
// return nil, nil, err
// }
	return dataChan, stopChan, nil
}

// dataChan := make(chan XunfeiChatResponse)
// stopChan := make(chan bool)
// go func() {
// for {
// _, msg, err := conn.ReadMessage()
// if err != nil {
// common.SysError("error reading stream response: " + err.Error())
// break
// }
// var response XunfeiChatResponse
// err = json.Unmarshal(msg, &response)
// if err != nil {
// common.SysError("error unmarshalling stream response: " + err.Error())
// break
// }
// dataChan <- response
// if response.Payload.Choices.Status == 2 {
// err := conn.Close()
// if err != nil {
// common.SysError("error closing websocket connection: " + err.Error())
// }
// break
// }
// }
// stopChan <- true
// }()
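
// --- Editor's note (not part of the upstream diff): xunfeiMakeRequest above reads the
// first websocket frame synchronously so handshake/first-frame errors reach the caller;
// the goroutine then re-reads only when msg is nil. A minimal consumer sketch:
func exampleConsumeXunfei(dataChan chan ChatResponse, stopChan chan bool) {
	for {
		select {
		case resp := <-dataChan:
			_ = resp // accumulate content and usage; Payload.Choices.Status == 2 marks the final packet
		case <-stopChan:
			return
		}
	}
}
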
func getAPIVersion(c *gin.Context, modelName string) string {
	query := c.Request.URL.Query()
	apiVersion := query.Get("api-version")
	if apiVersion != "" {
		return apiVersion
	}
	parts := strings.Split(modelName, "-")
	if len(parts) == 2 {
		apiVersion = parts[1]
		return apiVersion

// return dataChan, stopChan, nil
// }
	}
	apiVersion = c.GetString(config.KeyAPIVersion)
	if apiVersion != "" {
		return apiVersion
	}
	apiVersion = "v1.1"
	logger.SysLog("api_version not found, using default: " + apiVersion)
	return apiVersion
}

// func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
// query := c.Request.URL.Query()
// apiVersion := query.Get("api-version")
// if apiVersion == "" {
// apiVersion = c.GetString("api_version")
// }
// if apiVersion == "" {
// apiVersion = "v1.1"
// common.SysLog("api_version not found, use default: " + apiVersion)
// }
// domain := "general"
// if apiVersion != "v1.1" {
// domain += strings.Split(apiVersion, ".")[0]
// }
// authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
// return domain, authUrl
// }
// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
func apiVersion2domain(apiVersion string) string {
	switch apiVersion {
	case "v1.1":
		return "general"
	case "v2.1":
		return "generalv2"
	case "v3.1":
		return "generalv3"
	case "v3.5":
		return "generalv3.5"
	}
	return "general" + apiVersion
}

func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
	apiVersion := getAPIVersion(c, modelName)
	domain := apiVersion2domain(apiVersion)
	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
	return domain, authUrl
}

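// --- Editor's sketch (not part of the upstream diff): the full routing chain for an
// illustrative model name carrying a "-" separated version suffix.
func exampleXunfeiRouting(c *gin.Context) {
	domain, authUrl := getXunfeiAuthUrl(c, "hypothetical-api-key", "hypothetical-api-secret", "SparkDesk-v3.5")
	_ = domain  // "generalv3.5", via getAPIVersion -> "v3.5" -> apiVersion2domain
	_ = authUrl // signed wss://spark-api.xf-yun.com/v3.5/chat?... via buildXunfeiAuthUrl
}
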
@@ -1,61 +1,66 @@
package xunfei

// import (
// "github.com/songquanpeng/one-api/relay/model"
// )
import (
	"github.com/Laisky/one-api/relay/model"
)

// type Message struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// type ChatRequest struct {
// Header struct {
// AppId string `json:"app_id"`
// } `json:"header"`
// Parameter struct {
// Chat struct {
// Domain string `json:"domain,omitempty"`
// Temperature float64 `json:"temperature,omitempty"`
// TopK int `json:"top_k,omitempty"`
// MaxTokens int `json:"max_tokens,omitempty"`
// Auditing bool `json:"auditing,omitempty"`
// } `json:"chat"`
// } `json:"parameter"`
// Payload struct {
// Message struct {
// Text []Message `json:"text"`
// } `json:"message"`
// } `json:"payload"`
// }
type ChatRequest struct {
	Header struct {
		AppId string `json:"app_id"`
	} `json:"header"`
	Parameter struct {
		Chat struct {
			Domain      string  `json:"domain,omitempty"`
			Temperature float64 `json:"temperature,omitempty"`
			TopK        int     `json:"top_k,omitempty"`
			MaxTokens   int     `json:"max_tokens,omitempty"`
			Auditing    bool    `json:"auditing,omitempty"`
		} `json:"chat"`
	} `json:"parameter"`
	Payload struct {
		Message struct {
			Text []Message `json:"text"`
		} `json:"message"`
		Functions struct {
			Text []model.Function `json:"text,omitempty"`
		} `json:"functions,omitempty"`
	} `json:"payload"`
}

// type ChatResponseTextItem struct {
// Content string `json:"content"`
// Role string `json:"role"`
// Index int `json:"index"`
// }
type ChatResponseTextItem struct {
	Content      string          `json:"content"`
	Role         string          `json:"role"`
	Index        int             `json:"index"`
	ContentType  string          `json:"content_type"`
	FunctionCall *model.Function `json:"function_call"`
}

// type ChatResponse struct {
// Header struct {
// Code int `json:"code"`
// Message string `json:"message"`
// Sid string `json:"sid"`
// Status int `json:"status"`
// } `json:"header"`
// Payload struct {
// Choices struct {
// Status int `json:"status"`
// Seq int `json:"seq"`
// Text []ChatResponseTextItem `json:"text"`
// } `json:"choices"`
// Usage struct {
// //Text struct {
// // QuestionTokens string `json:"question_tokens"`
// // PromptTokens string `json:"prompt_tokens"`
// // CompletionTokens string `json:"completion_tokens"`
// // TotalTokens string `json:"total_tokens"`
// //} `json:"text"`
// Text model.Usage `json:"text"`
// } `json:"usage"`
// } `json:"payload"`
// }
type ChatResponse struct {
	Header struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Sid     string `json:"sid"`
		Status  int    `json:"status"`
	} `json:"header"`
	Payload struct {
		Choices struct {
			Status int                    `json:"status"`
			Seq    int                    `json:"seq"`
			Text   []ChatResponseTextItem `json:"text"`
		} `json:"choices"`
		Usage struct {
			//Text struct {
			//	QuestionTokens   string `json:"question_tokens"`
			//	PromptTokens     string `json:"prompt_tokens"`
			//	CompletionTokens string `json:"completion_tokens"`
			//	TotalTokens      string `json:"total_tokens"`
			//} `json:"text"`
			Text model.Usage `json:"text"`
		} `json:"usage"`
	} `json:"payload"`
}

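// --- Editor's note (not part of the upstream diff): an illustrative wire shape for
// ChatRequest above. Caveat: encoding/json ignores `omitempty` on struct-typed fields,
// so an empty "functions": {} object is still serialized when no functions are set.
//
//	{
//	  "header": {"app_id": "hypothetical-app-id"},
//	  "parameter": {"chat": {"domain": "generalv3.5", "max_tokens": 1024}},
//	  "payload": {
//	    "message": {"text": [{"role": "user", "content": "hello"}]},
//	    "functions": {}
//	  }
//	}
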
@@ -1,145 +1,146 @@
package zhipu

// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/adaptor"
// "github.com/songquanpeng/one-api/relay/adaptor/openai"
// "github.com/songquanpeng/one-api/relay/meta"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/relaymode"
// "io"
// "math"
// "net/http"
// "strings"
// )
import (
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"strings"

// type Adaptor struct {
// APIVersion string
// }
	"github.com/Laisky/one-api/relay/adaptor"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	"github.com/Laisky/one-api/relay/meta"
	"github.com/Laisky/one-api/relay/model"
	"github.com/Laisky/one-api/relay/relaymode"
	"github.com/gin-gonic/gin"
)

// func (a *Adaptor) Init(meta *meta.Meta) {
type Adaptor struct {
	APIVersion string
}

// }
func (a *Adaptor) Init(meta *meta.Meta) {

// func (a *Adaptor) SetVersionByModeName(modelName string) {
// if strings.HasPrefix(modelName, "glm-") {
// a.APIVersion = "v4"
// } else {
// a.APIVersion = "v3"
// }
// }
}

// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// switch meta.Mode {
// case relaymode.ImagesGenerations:
// return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
// case relaymode.Embeddings:
// return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
// }
// a.SetVersionByModeName(meta.ActualModelName)
// if a.APIVersion == "v4" {
// return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
// }
// method := "invoke"
// if meta.IsStream {
// method = "sse-invoke"
// }
// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
// }
func (a *Adaptor) SetVersionByModeName(modelName string) {
	if strings.HasPrefix(modelName, "glm-") {
		a.APIVersion = "v4"
	} else {
		a.APIVersion = "v3"
	}
}

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
// adaptor.SetupCommonRequestHeader(c, req, meta)
// token := GetToken(meta.APIKey)
// req.Header.Set("Authorization", token)
// return nil
// }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	switch meta.Mode {
	case relaymode.ImagesGenerations:
		return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
	case relaymode.Embeddings:
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
	}
	a.SetVersionByModeName(meta.ActualModelName)
	if a.APIVersion == "v4" {
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
	}
	method := "invoke"
	if meta.IsStream {
		method = "sse-invoke"
	}
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
}

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// switch relayMode {
// case relaymode.Embeddings:
// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
// return baiduEmbeddingRequest, nil
// default:
// // TopP (0.0, 1.0)
// request.TopP = math.Min(0.99, request.TopP)
// request.TopP = math.Max(0.01, request.TopP)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	token := GetToken(meta.APIKey)
	req.Header.Set("Authorization", token)
	return nil
}

// // Temperature (0.0, 1.0)
// request.Temperature = math.Min(0.99, request.Temperature)
// request.Temperature = math.Max(0.01, request.Temperature)
// a.SetVersionByModeName(request.Model)
// if a.APIVersion == "v4" {
// return request, nil
// }
// return ConvertRequest(*request), nil
// }
// }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	switch relayMode {
	case relaymode.Embeddings:
		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
		return baiduEmbeddingRequest, nil
	default:
		// TopP (0.0, 1.0)
		request.TopP = math.Min(0.99, request.TopP)
		request.TopP = math.Max(0.01, request.TopP)

// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// newRequest := ImageRequest{
// Model: request.Model,
// Prompt: request.Prompt,
// UserId: request.User,
// }
// return newRequest, nil
// }
		// Temperature (0.0, 1.0)
		request.Temperature = math.Min(0.99, request.Temperature)
		request.Temperature = math.Max(0.01, request.Temperature)
		a.SetVersionByModeName(request.Model)
		if a.APIVersion == "v4" {
			return request, nil
		}
		return ConvertRequest(*request), nil
	}
}

// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
// return adaptor.DoRequestHelper(a, c, meta, requestBody)
// }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	newRequest := ImageRequest{
		Model:  request.Model,
		Prompt: request.Prompt,
		UserId: request.User,
	}
	return newRequest, nil
}

// func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
// } else {
// err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
// }
// return
// }
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// switch meta.Mode {
// case relaymode.Embeddings:
// err, usage = EmbeddingsHandler(c, resp)
// return
// case relaymode.ImagesGenerations:
// err, usage = openai.ImageHandler(c, resp)
// return
// }
// if a.APIVersion == "v4" {
// return a.DoResponseV4(c, resp, meta)
// }
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// if meta.Mode == relaymode.Embeddings {
// err, usage = EmbeddingsHandler(c, resp)
// } else {
// err, usage = Handler(c, resp)
// }
// }
// return
// }
func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
	} else {
		err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
	}
	return
}

// func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
// return &EmbeddingRequest{
// Model: "embedding-2",
// Input: request.Input.(string),
// }
// }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	switch meta.Mode {
	case relaymode.Embeddings:
		err, usage = EmbeddingsHandler(c, resp)
		return
	case relaymode.ImagesGenerations:
		err, usage = openai.ImageHandler(c, resp)
		return
	}
	if a.APIVersion == "v4" {
		return a.DoResponseV4(c, resp, meta)
	}
	if meta.IsStream {
		err, usage = StreamHandler(c, resp)
	} else {
		if meta.Mode == relaymode.Embeddings {
			err, usage = EmbeddingsHandler(c, resp)
		} else {
			err, usage = Handler(c, resp)
		}
	}
	return
}

// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
	return &EmbeddingRequest{
		Model: "embedding-2",
		Input: request.Input.(string),
	}
}

// func (a *Adaptor) GetChannelName() string {
// return "zhipu"
// }
func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "zhipu"
}

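// --- Editor's sketch (not part of the upstream diff): URLs that GetRequestURL above
// resolves to, assuming an illustrative meta.BaseURL of https://open.bigmodel.cn:
//
//	glm-4 chat (any streaming)    -> https://open.bigmodel.cn/api/paas/v4/chat/completions
//	chatglm_std, non-streaming    -> https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
//	chatglm_std, streaming        -> https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
//	embeddings (any model)        -> https://open.bigmodel.cn/api/paas/v4/embeddings
//	image generations (any model) -> https://open.bigmodel.cn/api/paas/v4/images/generations
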
@@ -1,5 +1,7 @@
package zhipu

// var ModelList = []string{
// "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
// }
var ModelList = []string{
	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
	"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
	"cogview-3",
}

@@ -1,301 +1,304 @@
package zhipu

// import (
// "bufio"
// "encoding/json"
// "github.com/gin-gonic/gin"
// "github.com/golang-jwt/jwt"
// "io"
// "net/http"
// "one-api/common"
// "strings"
// "sync"
// "time"
// )
import (
	"bufio"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

// // https://open.bigmodel.cn/doc/api#chatglm_std
// // chatglm_std, chatglm_lite
// // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
	"github.com/Laisky/one-api/common"
	"github.com/Laisky/one-api/common/helper"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	"github.com/Laisky/one-api/relay/constant"
	"github.com/Laisky/one-api/relay/model"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt"
)

// type ZhipuMessage struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
// https://open.bigmodel.cn/doc/api#chatglm_std
// chatglm_std, chatglm_lite
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke

// type ZhipuRequest struct {
// Prompt []ZhipuMessage `json:"prompt"`
// Temperature float64 `json:"temperature,omitempty"`
// TopP float64 `json:"top_p,omitempty"`
// RequestId string `json:"request_id,omitempty"`
// Incremental bool `json:"incremental,omitempty"`
// }
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600

// type ZhipuResponseData struct {
// TaskId string `json:"task_id"`
// RequestId string `json:"request_id"`
// TaskStatus string `json:"task_status"`
// Choices []ZhipuMessage `json:"choices"`
// Usage `json:"usage"`
// }
func GetToken(apikey string) string {
	data, ok := zhipuTokens.Load(apikey)
	if ok {
		tokenData := data.(tokenData)
		if time.Now().Before(tokenData.ExpiryTime) {
			return tokenData.Token
		}
	}

// type ZhipuResponse struct {
// Code int `json:"code"`
// Msg string `json:"msg"`
// Success bool `json:"success"`
// Data ZhipuResponseData `json:"data"`
// }
	split := strings.Split(apikey, ".")
	if len(split) != 2 {
		logger.SysError("invalid zhipu key: " + apikey)
		return ""
	}

// type ZhipuStreamMetaResponse struct {
// RequestId string `json:"request_id"`
// TaskId string `json:"task_id"`
// TaskStatus string `json:"task_status"`
// Usage `json:"usage"`
// }
	id := split[0]
	secret := split[1]

// type zhipuTokenData struct {
// Token string
// ExpiryTime time.Time
// }
	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)

// var zhipuTokens sync.Map
// var expSeconds int64 = 24 * 3600
	timestamp := time.Now().UnixNano() / 1e6

// func getZhipuToken(apikey string) string {
// data, ok := zhipuTokens.Load(apikey)
// if ok {
// tokenData := data.(zhipuTokenData)
// if time.Now().Before(tokenData.ExpiryTime) {
// return tokenData.Token
// }
// }
	payload := jwt.MapClaims{
		"api_key":   id,
		"exp":       expMillis,
		"timestamp": timestamp,
	}

// split := strings.Split(apikey, ".")
// if len(split) != 2 {
// common.SysError("invalid zhipu key: " + apikey)
// return ""
// }
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)

// id := split[0]
// secret := split[1]
	token.Header["alg"] = "HS256"
	token.Header["sign_type"] = "SIGN"

// expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
// expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
	tokenString, err := token.SignedString([]byte(secret))
	if err != nil {
		return ""
	}

// timestamp := time.Now().UnixNano() / 1e6
	zhipuTokens.Store(apikey, tokenData{
		Token:      tokenString,
		ExpiryTime: expiryTime,
	})

// payload := jwt.MapClaims{
// "api_key": id,
// "exp": expMillis,
// "timestamp": timestamp,
// }
	return tokenString
}

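// --- Editor's sketch (not part of the upstream diff): GetToken above expects the
// channel key as "<id>.<secret>" (illustrative value below). The resulting HS256 JWT
// carries api_key/exp/timestamp claims in milliseconds plus a non-standard
// "sign_type": "SIGN" header, and is cached in zhipuTokens until ExpiryTime.
func exampleZhipuToken() {
	token := GetToken("hypothetical-key-id.hypothetical-key-secret")
	_ = token // sent verbatim in the Authorization header (no "Bearer " prefix)
}
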
// token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
return &Request{
|
||||
Prompt: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Incremental: false,
|
||||
}
|
||||
}
|
||||
|
||||
// token.Header["alg"] = "HS256"
|
||||
// token.Header["sign_type"] = "SIGN"
|
||||
func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.Data.TaskId,
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
|
||||
Usage: response.Data.Usage,
|
||||
}
|
||||
for i, choice := range response.Data.Choices {
|
||||
openaiChoice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: model.Message{
|
||||
Role: choice.Role,
|
||||
Content: strings.Trim(choice.Content, "\""),
|
||||
},
|
||||
FinishReason: "",
|
||||
}
|
||||
if i == len(response.Data.Choices)-1 {
|
||||
openaiChoice.FinishReason = "stop"
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
// tokenString, err := token.SignedString([]byte(secret))
|
||||
// if err != nil {
|
||||
// return ""
|
||||
// }
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
// zhipuTokens.Store(apikey, zhipuTokenData{
|
||||
// Token: tokenString,
|
||||
// ExpiryTime: expiryTime,
|
||||
// })
|
||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response, &zhipuResponse.Usage
|
||||
}

// StreamHandler relays a Zhipu SSE stream to the client as OpenAI-style
// chunks, splitting the body into "data:" and "meta:" events.
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var usage *model.Usage
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
			return i + 2, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	metaChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			lines := strings.Split(data, "\n")
			for i, line := range lines {
				if len(line) < 5 {
					continue
				}
				if line[:5] == "data:" {
					dataChan <- line[5:]
					if i != len(lines)-1 {
						dataChan <- "\n"
					}
				} else if line[:5] == "meta:" {
					metaChan <- line[5:]
				}
			}
		}
		stopChan <- true
	}()
	common.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			response := streamResponseZhipu2OpenAI(data)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case data := <-metaChan:
			var zhipuResponse StreamMetaResponse
			err := json.Unmarshal([]byte(data), &zhipuResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				logger.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			usage = zhipuUsage
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, usage
}
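// A rough sketch of what the custom Split function above does; the payloads
// are invented for illustration. A raw body such as
//
//	"data: hello\n\ndata: world\nmeta: {\"usage\":{}}\n\n"
//
// is cut on blank lines into the tokens "data: hello" and
// "data: world\nmeta: {\"usage\":{}}", and each token's lines are then routed
// to dataChan or metaChan by their "data:"/"meta:" prefix.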

// func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
// 	messages := make([]ZhipuMessage, 0, len(request.Messages))
// 	for _, message := range request.Messages {
// 		if message.Role == "system" {
// 			messages = append(messages, ZhipuMessage{
// 				Role:    "system",
// 				Content: message.Content,
// 			})
// 			messages = append(messages, ZhipuMessage{
// 				Role:    "user",
// 				Content: "Okay",
// 			})
// 		} else {
// 			messages = append(messages, ZhipuMessage{
// 				Role:    message.Role,
// 				Content: message.Content,
// 			})
// 		}
// 	}
// 	return &ZhipuRequest{
// 		Prompt:      messages,
// 		Temperature: request.Temperature,
// 		TopP:        request.TopP,
// 		Incremental: false,
// 	}
// }

// Handler relays a non-streaming Zhipu response, translating Zhipu's
// success/error envelope into the OpenAI response format.
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var zhipuResponse Response
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &zhipuResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if !zhipuResponse.Success {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: zhipuResponse.Msg,
				Type:    "zhipu_error",
				Param:   "",
				Code:    zhipuResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
	fullTextResponse.Model = "chatglm"
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
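// Illustrative failure envelope (field values invented) that the Success
// check above converts into an OpenAI-style "zhipu_error":
//
//	{"code": 1002, "msg": "invalid token", "success": false}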

// func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
// 	fullTextResponse := OpenAITextResponse{
// 		Id:      response.Data.TaskId,
// 		Object:  "chat.completion",
// 		Created: common.GetTimestamp(),
// 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
// 		Usage:   response.Data.Usage,
// 	}
// 	for i, choice := range response.Data.Choices {
// 		openaiChoice := OpenAITextResponseChoice{
// 			Index: i,
// 			Message: Message{
// 				Role:    choice.Role,
// 				Content: strings.Trim(choice.Content, "\""),
// 			},
// 			FinishReason: "",
// 		}
// 		if i == len(response.Data.Choices)-1 {
// 			openaiChoice.FinishReason = "stop"
// 		}
// 		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
// 	}
// 	return &fullTextResponse
// }

// EmbeddingsHandler relays a Zhipu embedding response in OpenAI format.
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var zhipuResponse EmbeddingResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &zhipuResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
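// Note: as in Handler above, the error from c.Writer.Write is assigned but
// never checked, so a failed write to the client is silently dropped.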

// func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
// 	var choice ChatCompletionsStreamResponseChoice
// 	choice.Delta.Content = zhipuResponse
// 	response := ChatCompletionsStreamResponse{
// 		Object:  "chat.completion.chunk",
// 		Created: common.GetTimestamp(),
// 		Model:   "chatglm",
// 		Choices: []ChatCompletionsStreamResponseChoice{choice},
// 	}
// 	return &response
// }

// func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
// 	var choice ChatCompletionsStreamResponseChoice
// 	choice.Delta.Content = ""
// 	choice.FinishReason = &stopFinishReason
// 	response := ChatCompletionsStreamResponse{
// 		Id:      zhipuResponse.RequestId,
// 		Object:  "chat.completion.chunk",
// 		Created: common.GetTimestamp(),
// 		Model:   "chatglm",
// 		Choices: []ChatCompletionsStreamResponseChoice{choice},
// 	}
// 	return &response, &zhipuResponse.Usage
// }

// func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// 	var usage *Usage
// 	scanner := bufio.NewScanner(resp.Body)
// 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// 		if atEOF && len(data) == 0 {
// 			return 0, nil, nil
// 		}
// 		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
// 			return i + 2, data[0:i], nil
// 		}
// 		if atEOF {
// 			return len(data), data, nil
// 		}
// 		return 0, nil, nil
// 	})
// 	dataChan := make(chan string)
// 	metaChan := make(chan string)
// 	stopChan := make(chan bool)
// 	go func() {
// 		for scanner.Scan() {
// 			data := scanner.Text()
// 			lines := strings.Split(data, "\n")
// 			for i, line := range lines {
// 				if len(line) < 5 {
// 					continue
// 				}
// 				if line[:5] == "data:" {
// 					dataChan <- line[5:]
// 					if i != len(lines)-1 {
// 						dataChan <- "\n"
// 					}
// 				} else if line[:5] == "meta:" {
// 					metaChan <- line[5:]
// 				}
// 			}
// 		}
// 		stopChan <- true
// 	}()
// 	setEventStreamHeaders(c)
// 	c.Stream(func(w io.Writer) bool {
// 		select {
// 		case data := <-dataChan:
// 			response := streamResponseZhipu2OpenAI(data)
// 			jsonResponse, err := json.Marshal(response)
// 			if err != nil {
// 				common.SysError("error marshalling stream response: " + err.Error())
// 				return true
// 			}
// 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// 			return true
// 		case data := <-metaChan:
// 			var zhipuResponse ZhipuStreamMetaResponse
// 			err := json.Unmarshal([]byte(data), &zhipuResponse)
// 			if err != nil {
// 				common.SysError("error unmarshalling stream response: " + err.Error())
// 				return true
// 			}
// 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
// 			jsonResponse, err := json.Marshal(response)
// 			if err != nil {
// 				common.SysError("error marshalling stream response: " + err.Error())
// 				return true
// 			}
// 			usage = zhipuUsage
// 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// 			return true
// 		case <-stopChan:
// 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// 			return false
// 		}
// 	})
// 	err := resp.Body.Close()
// 	if err != nil {
// 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	return nil, usage
// }

// func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// 	var zhipuResponse ZhipuResponse
// 	responseBody, err := io.ReadAll(resp.Body)
// 	if err != nil {
// 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	err = resp.Body.Close()
// 	if err != nil {
// 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	err = json.Unmarshal(responseBody, &zhipuResponse)
// 	if err != nil {
// 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	if !zhipuResponse.Success {
// 		return &OpenAIErrorWithStatusCode{
// 			OpenAIError: OpenAIError{
// 				Message: zhipuResponse.Msg,
// 				Type:    "zhipu_error",
// 				Param:   "",
// 				Code:    zhipuResponse.Code,
// 			},
// 			StatusCode: resp.StatusCode,
// 		}, nil
// 	}
// 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
// 	jsonResponse, err := json.Marshal(fullTextResponse)
// 	if err != nil {
// 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	c.Writer.Header().Set("Content-Type", "application/json")
// 	c.Writer.WriteHeader(resp.StatusCode)
// 	_, err = c.Writer.Write(jsonResponse)
// 	return nil, &fullTextResponse.Usage
// }

// embeddingResponseZhipu2OpenAI converts a Zhipu embedding response into the
// OpenAI embedding list format.
func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
		Model:  response.Model,
		Usage: model.Usage{
			PromptTokens:     response.PromptTokens,
			CompletionTokens: response.CompletionTokens,
			TotalTokens:      response.Usage.TotalTokens,
		},
	}
	for _, item := range response.Embeddings {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
			Object:    `embedding`,
			Index:     item.Index,
			Embedding: item.Embedding,
		})
	}
	return &openAIEmbeddingResponse
}
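// Shape only, values invented: a Zhipu embedding payload such as
//
//	{"model": "m", "data": [{"index": 0, "embedding": [0.1, 0.2]}], "usage": {"total_tokens": 3}}
//
// is returned to the client as
//
//	{"object": "list", "model": "m",
//	 "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2]}],
//	 "usage": {"total_tokens": 3}}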
@@ -1,46 +1,71 @@
package zhipu

// import (
// 	"github.com/songquanpeng/one-api/relay/model"
// 	"time"
// )
import (
	"time"

	"github.com/Laisky/one-api/relay/model"
)

// type Message struct {
// 	Role    string `json:"role"`
// 	Content string `json:"content"`
// }
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// type Request struct {
// 	Prompt      []Message `json:"prompt"`
// 	Temperature float64   `json:"temperature,omitempty"`
// 	TopP        float64   `json:"top_p,omitempty"`
// 	RequestId   string    `json:"request_id,omitempty"`
// 	Incremental bool      `json:"incremental,omitempty"`
// }
type Request struct {
	Prompt      []Message `json:"prompt"`
	Temperature float64   `json:"temperature,omitempty"`
	TopP        float64   `json:"top_p,omitempty"`
	RequestId   string    `json:"request_id,omitempty"`
	Incremental bool      `json:"incremental,omitempty"`
}

// type ResponseData struct {
// 	TaskId      string    `json:"task_id"`
// 	RequestId   string    `json:"request_id"`
// 	TaskStatus  string    `json:"task_status"`
// 	Choices     []Message `json:"choices"`
// 	model.Usage `json:"usage"`
// }
type ResponseData struct {
	TaskId      string    `json:"task_id"`
	RequestId   string    `json:"request_id"`
	TaskStatus  string    `json:"task_status"`
	Choices     []Message `json:"choices"`
	model.Usage `json:"usage"`
}

// type Response struct {
// 	Code    int          `json:"code"`
// 	Msg     string       `json:"msg"`
// 	Success bool         `json:"success"`
// 	Data    ResponseData `json:"data"`
// }
type Response struct {
	Code    int          `json:"code"`
	Msg     string       `json:"msg"`
	Success bool         `json:"success"`
	Data    ResponseData `json:"data"`
}

// type StreamMetaResponse struct {
// 	RequestId   string `json:"request_id"`
// 	TaskId      string `json:"task_id"`
// 	TaskStatus  string `json:"task_status"`
// 	model.Usage `json:"usage"`
// }
type StreamMetaResponse struct {
	RequestId   string `json:"request_id"`
	TaskId      string `json:"task_id"`
	TaskStatus  string `json:"task_status"`
	model.Usage `json:"usage"`
}

// type tokenData struct {
// 	Token      string
// 	ExpiryTime time.Time
// }
type tokenData struct {
	Token      string
	ExpiryTime time.Time
}

type EmbeddingRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

type EmbeddingResponse struct {
	Model       string          `json:"model"`
	Object      string          `json:"object"`
	Embeddings  []EmbeddingData `json:"data"`
	model.Usage `json:"usage"`
}

type EmbeddingData struct {
	Index     int       `json:"index"`
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
}

type ImageRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	UserId string `json:"user_id,omitempty"`
}
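// Example request bodies these types serialize to (model names and values
// are illustrative, not taken from the Zhipu API docs):
//
//	EmbeddingRequest: {"model": "some-embedding-model", "input": "hello"}
//	ImageRequest:     {"model": "some-image-model", "prompt": "a red fox", "user_id": "u-123"}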
@@ -3,8 +3,8 @@ package billing
import (
	"context"
	"fmt"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/model"
)

func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
@@ -2,7 +2,7 @@ package ratio

import (
	"encoding/json"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/Laisky/one-api/common/logger"
)

var GroupRatio = map[string]float64{
@@ -4,7 +4,7 @@ import (
	"encoding/json"
	"strings"

	"github.com/songquanpeng/one-api/common/logger"
	"github.com/Laisky/one-api/common/logger"
)

const (
@@ -9,12 +9,12 @@ package tencent
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common"
// "github.com/songquanpeng/one-api/common/helper"
// "github.com/songquanpeng/one-api/common/logger"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "github.com/songquanpeng/one-api/relay/constant"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/Laisky/one-api/common"
// "github.com/Laisky/one-api/common/helper"
// "github.com/Laisky/one-api/common/logger"
// "github.com/Laisky/one-api/relay/channel/openai"
// "github.com/Laisky/one-api/relay/constant"
// "github.com/Laisky/one-api/relay/model"
// "io"
// "net/http"
// "sort"
@@ -1,6 +1,6 @@
package channeltype

import "github.com/songquanpeng/one-api/relay/apitype"
import "github.com/Laisky/one-api/relay/apitype"

func ToAPIType(channelType int) int {
	apiType := apitype.OpenAI
@@ -6,7 +6,7 @@ import (
	"time"

	gutils "github.com/Laisky/go-utils/v4"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/Laisky/one-api/common/config"
)

var HTTPClient *http.Client
@@ -7,19 +7,19 @@ import (
	"encoding/json"
	"fmt"
	"github.com/Laisky/errors/v2"
	"github.com/Laisky/one-api/common"
	"github.com/Laisky/one-api/common/config"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/model"
	"github.com/Laisky/one-api/relay/adaptor/azure"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	"github.com/Laisky/one-api/relay/billing"
	billingratio "github.com/Laisky/one-api/relay/billing/ratio"
	"github.com/Laisky/one-api/relay/channeltype"
	"github.com/Laisky/one-api/relay/client"
	relaymodel "github.com/Laisky/one-api/relay/model"
	"github.com/Laisky/one-api/relay/relaymode"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/billing"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/client"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"io"
	"net/http"
	"strings"
@@ -3,9 +3,9 @@ package controller
import (
	"encoding/json"
	"fmt"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/Laisky/one-api/common/config"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/relay/model"
	"io"
	"net/http"
	"strconv"
@@ -4,18 +4,18 @@ import (
	"context"
	"fmt"
	"github.com/Laisky/errors/v2"
	"github.com/Laisky/one-api/common"
	"github.com/Laisky/one-api/common/config"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/model"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	billingratio "github.com/Laisky/one-api/relay/billing/ratio"
	"github.com/Laisky/one-api/relay/channeltype"
	"github.com/Laisky/one-api/relay/controller/validator"
	"github.com/Laisky/one-api/relay/meta"
	relaymodel "github.com/Laisky/one-api/relay/model"
	"github.com/Laisky/one-api/relay/relaymode"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/controller/validator"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"math"
	"net/http"
)
@@ -9,15 +9,15 @@ import (
	"net/http"

	"github.com/Laisky/errors/v2"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/model"
	"github.com/Laisky/one-api/relay"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	billingratio "github.com/Laisky/one-api/relay/billing/ratio"
	"github.com/Laisky/one-api/relay/channeltype"
	"github.com/Laisky/one-api/relay/meta"
	relaymodel "github.com/Laisky/one-api/relay/model"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func isWithinRange(element string, value int) bool {
@@ -3,22 +3,21 @@ package controller
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/Laisky/errors/v2"
	"github.com/Laisky/one-api/common/logger"
	"github.com/Laisky/one-api/relay"
	"github.com/Laisky/one-api/relay/adaptor/openai"
	"github.com/Laisky/one-api/relay/apitype"
	"github.com/Laisky/one-api/relay/billing"
	billingratio "github.com/Laisky/one-api/relay/billing/ratio"
	"github.com/Laisky/one-api/relay/channeltype"
	"github.com/Laisky/one-api/relay/meta"
	"github.com/Laisky/one-api/relay/model"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/apitype"
	"github.com/songquanpeng/one-api/relay/billing"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
@@ -95,14 +94,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}
	if resp != nil {
		errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
		errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json"))
		if errorHappened {
			billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
			logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q",
				resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes)))
			return RelayErrorHandler(resp)
		}
		meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
	}
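	// Note: the errorHappened check above uses strings.HasPrefix because an
	// upstream may send "application/json; charset=utf-8"; an exact match on
	// the Content-Type would miss such an error response in stream mode.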

	// do response

@@ -4,8 +4,8 @@ import (
	"math"

	"github.com/Laisky/errors/v2"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"github.com/Laisky/one-api/relay/model"
	"github.com/Laisky/one-api/relay/relaymode"
)

func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
@@ -1,11 +1,11 @@
package meta

import (
	"github.com/Laisky/one-api/common/config"
	"github.com/Laisky/one-api/relay/adaptor/azure"
	"github.com/Laisky/one-api/relay/channeltype"
	"github.com/Laisky/one-api/relay/relaymode"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"strings"
)