merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-08-26 18:50:07 +08:00
19 changed files with 278 additions and 17 deletions

View File

@@ -139,6 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
claudeMessages := make([]ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages {
if message.Role == "system" {
if message.IsStringContent() {
@@ -154,6 +155,22 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeRequest.System = content
}
} else {
if isFirstMessage {
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{
Role: "user",
Content: []ClaudeMediaMessage{
{
Type: "text",
Text: "...",
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := ClaudeMessage{
Role: message.Role,
}

View File

@@ -0,0 +1,80 @@
package siliconflow
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = siliconflowRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeChatCompletions {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,51 @@
package siliconflow
var ModelList = []string{
"THUDM/glm-4-9b-chat",
//"stabilityai/stable-diffusion-xl-base-1.0",
//"TencentARC/PhotoMaker",
"InstantX/InstantID",
//"stabilityai/stable-diffusion-2-1",
//"stabilityai/sd-turbo",
//"stabilityai/sdxl-turbo",
"ByteDance/SDXL-Lightning",
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
//"stabilityai/stable-diffusion-3-medium",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell",
"iic/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3",
"internlm/internlm2_5-20b-chat",
"Qwen/Qwen2-Math-72B-Instruct",
"netease-youdao/bce-reranker-base_v1",
"BAAI/bge-reranker-v2-m3",
}
var ChannelName = "siliconflow"

View File

@@ -0,0 +1,17 @@
package siliconflow
import "one-api/dto"
type SFTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type SFMeta struct {
Tokens SFTokens `json:"tokens"`
}
type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -0,0 +1,44 @@
package siliconflow
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
}
rerankResp := &dto.RerankResponse{
Results: siliconflowResp.Results,
Usage: *usage,
}
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -23,6 +23,7 @@ const (
APITypeDify
APITypeJina
APITypeCloudflare
APITypeSiliconFlow
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -66,6 +67,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeJina
case common.ChannelCloudflare:
apiType = APITypeCloudflare
case common.ChannelTypeSiliconFlow:
apiType = APITypeSiliconFlow
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -38,9 +38,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
@@ -50,6 +48,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
//if imageRequest.N != 1 {
// return nil, errors.New("n must be 1")
//}

View File

@@ -326,7 +326,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
totalTokens := promptTokens + completionTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f分组倍率 %.2f补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
logContent = fmt.Sprintf("模型倍率 %.2f补全倍率 %.2f分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}

View File

@@ -16,6 +16,7 @@ import (
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
@@ -62,6 +63,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &jina.Adaptor{}
case constant.APITypeCloudflare:
return &cloudflare.Adaptor{}
case constant.APITypeSiliconFlow:
return &siliconflow.Adaptor{}
}
return nil
}

View File

@@ -38,6 +38,23 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if len(rerankRequest.Documents) == 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)