feat: 支持部分渠道的system角色 (close #89)

This commit is contained in:
1808837298@qq.com 2024-03-06 14:16:04 +08:00
parent da73dca9a7
commit 3ab4f145db
6 changed files with 79 additions and 67 deletions

View File

@ -1,8 +1,8 @@
package ali package ali
type AliMessage struct { type AliMessage struct {
User string `json:"user"` Content string `json:"content"`
Bot string `json:"bot"` Role string `json:"role"`
} }
type AliInput struct { type AliInput struct {
@ -15,6 +15,7 @@ type AliParameters struct {
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"` Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"` EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
} }
type AliChatRequest struct { type AliChatRequest struct {

View File

@ -14,28 +14,23 @@ import (
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest { func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages)) messages := make([]AliMessage, 0, len(request.Messages))
prompt := "" prompt := ""
for i := 0; i < len(request.Messages); i++ { for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i] message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.StringContent(), Content: message.StringContent(),
Bot: "Okay", Role: strings.ToLower(message.Role),
}) })
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: string(request.Messages[i+1].Content),
})
i++
} }
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
} }
return &AliChatRequest{ return &AliChatRequest{
Model: request.Model, Model: request.Model,
@ -43,12 +38,11 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
Prompt: prompt, Prompt: prompt,
History: messages, History: messages,
}, },
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's Parameters: AliParameters{
// TopP: request.TopP, IncrementalOutput: request.Stream,
// TopK: 50, Seed: uint64(request.Seed),
// //Seed: 0, EnableSearch: enableSearch,
// //EnableSearch: false, },
//},
} }
} }

View File

@ -24,22 +24,11 @@ var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages)) messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: message.Role, Role: message.Role,
Content: message.StringContent(), Content: message.StringContent(),
}) })
} }
}
return &BaiduChatRequest{ return &BaiduChatRequest{
Messages: messages, Messages: messages,
Stream: request.Stream, Stream: request.Stream,

View File

@ -50,10 +50,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Header.Set("Authorization", "Bearer "+info.ApiKey)
if info.ChannelType == common.ChannelTypeOpenRouter { //if info.ChannelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API") // req.Header.Set("X-Title", "One API")
} //}
return nil return nil
} }

View File

@ -24,8 +24,9 @@ import (
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages)) messages := make([]XunfeiMessage, 0, len(request.Messages))
shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" && shouldCovertSystemMessage {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "user", Role: "user",
Content: message.StringContent(), Content: message.StringContent(),
@ -126,7 +127,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
} }
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
@ -156,7 +157,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
} }
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
@ -235,20 +236,44 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
return dataChan, stopChan, nil return dataChan, stopChan, nil
} }
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { func apiVersion2domain(apiVersion string) string {
query := c.Request.URL.Query() switch apiVersion {
apiVersion := query.Get("api-version") case "v1.1":
if apiVersion == "" { return "general"
apiVersion = c.GetString("api_version") case "v2.1":
return "generalv2"
case "v3.1":
return "generalv3"
case "v3.5":
return "generalv3.5"
} }
if apiVersion == "" { return "general" + apiVersion
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
} }
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
apiVersion := getAPIVersion(c, modelName)
domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl return domain, authUrl
} }
func getAPIVersion(c *gin.Context, modelName string) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion != "" {
return apiVersion
}
parts := strings.Split(modelName, "-")
if len(parts) == 2 {
apiVersion = parts[1]
return apiVersion
}
apiVersion = c.GetString("api_version")
if apiVersion != "" {
return apiVersion
}
apiVersion = "v1.1"
common.SysLog("api_version not found, using default: " + apiVersion)
return apiVersion
}

View File

@ -72,13 +72,13 @@ const EditChannel = (props) => {
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break; break;
case 17: case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; localModels = ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", 'text-embedding-v1'];
break; break;
case 16: case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break; break;
case 18: case 18:
localModels = ['SparkDesk']; localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'];
break; break;
case 19: case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@ -87,7 +87,10 @@ const EditChannel = (props) => {
localModels = ['hunyuan']; localModels = ['hunyuan'];
break; break;
case 24: case 24:
localModels = ['gemini-pro']; localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
case 25:
localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
break; break;
case 26: case 26:
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];