Compare commits

...

15 Commits

Author SHA1 Message Date
CalciumIon
1e536ee7d9 fix: streaming timeout 2024-07-07 01:01:55 +08:00
CalciumIon
8a730cfe12 feat: support jina rerank 2024-07-06 18:42:48 +08:00
CalciumIon
3ed4f2f0a9 Update README.md 2024-07-06 18:13:26 +08:00
CalciumIon
bec18ed82d Update README.md 2024-07-06 17:46:47 +08:00
CalciumIon
bd9bf4b732 chore: remove useless code 2024-07-06 17:29:28 +08:00
CalciumIon
1735e093db fix: fix rerank 2024-07-06 17:28:00 +08:00
CalciumIon
8af4e28f75 feat: support cohere rerank 2024-07-06 17:09:22 +08:00
CalciumIon
afe02c6aa5 fix: midjourney channel auto ban 2024-07-06 01:44:30 +08:00
CalciumIon
e0ed59bfe3 feat: support dify (close #299) 2024-07-06 01:32:40 +08:00
CalciumIon
bd7222118a feat: 记录兑换时兑换码的ID (close #286) 2024-07-05 20:57:32 +08:00
CalciumIon
cf3d894195 feat: 记录渠道测试的消费日志 (close #334) 2024-07-05 20:51:25 +08:00
CalciumIon
7011083201 feat: 统计无限令牌的已用额度 (close #308) 2024-07-05 20:28:17 +08:00
CalciumIon
752048dfb4 feat: 统计无限令牌的已用额度 (close #308) 2024-07-05 20:25:33 +08:00
CalciumIon
eb382d28ab feat: update baidu 2024-07-05 20:22:30 +08:00
CalciumIon
a9e1078bca fix typo 2024-07-05 20:01:25 +08:00
50 changed files with 978 additions and 119 deletions

View File

@@ -2,6 +2,21 @@
**简介**:Midjourney Proxy API文档
## 接口列表
支持的接口如下:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
## 模型列表
### midjourney-proxy支持

View File

@@ -16,19 +16,7 @@
此分叉版本的主要变更如下:
1. 全新的UI界面部分界面还待更新
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md),支持的接口如下:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md)
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
4. 支持用key查询使用额度:
@@ -45,22 +33,21 @@
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
3. 选择你的bot然后输入http(s)://你的网站地址/login
4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md),支持的接口如下:
+ [x] /suno/submit/music
+ [x] /suno/submit/lyrics
+ [x] /suno/fetch
+ [x] /suno/fetch/:id
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
14. 支持Rerank模型目前仅兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
## 模型支持
此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*
2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
3. Anthropic Claude 3
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
6. [零一万物](https://platform.lingyiwanwu.com/)
7. 自定义渠道,支持填入完整调用地址
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。

62
Rerank.md Normal file
View File

@@ -0,0 +1,62 @@
# Rerank API文档
**简介**:Rerank API文档
## 接入Dify
模型供应商选择Jina按要求填写模型信息即可接入Dify。
## 请求方式
Post: /v1/rerank
Request:
```json
{
"model": "rerank-multilingual-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}
```
Response:
```json
{
"results": [
{
"document": {
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
},
"index": 2,
"relevance_score": 0.9999702
},
{
"document": {
"text": "Carson City is the capital city of the American state of Nevada."
},
"index": 0,
"relevance_score": 0.67800725
},
{
"document": {
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
},
"index": 3,
"relevance_score": 0.02800752
}
],
"usage": {
"prompt_tokens": 158,
"completion_tokens": 0,
"total_tokens": 158
}
}
```

View File

@@ -2,6 +2,13 @@
**简介**:Suno API文档
## 接口列表
支持的接口如下:
+ [x] /suno/submit/music
+ [x] /suno/submit/lyrics
+ [x] /suno/fetch
+ [x] /suno/fetch/:id
## 模型列表
### Suno API支持

View File

@@ -210,6 +210,8 @@ const (
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelTypeDummy // this one is only for count, do not add any channel after this
@@ -253,4 +255,6 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.jina.ai", //38
}

View File

@@ -78,19 +78,23 @@ var defaultModelRatio = map[string]float64{
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K
"ERNIE-4.0-8K": 8.572, // ¥0.12 / 1k tokens
"ERNIE-3.5-8K": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Speed-8K": 0.2858, // ¥0.004 / 1k tokens
"ERNIE-Speed-128K": 0.2858, // ¥0.004 / 1k tokens
"ERNIE-Lite-8K": 0.2143, // ¥0.003 / 1k tokens
"ERNIE-Tiny-8K": 0.0715, // ¥0.001 / 1k tokens
"ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens
"ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"net/url"
@@ -24,6 +25,7 @@ import (
)
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
@@ -65,7 +67,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
testModel = adaptor.GetModelList()[0]
if len(adaptor.GetModelList()) > 0 {
testModel = adaptor.GetModelList()[0]
} else {
testModel = "gpt-3.5-turbo"
}
}
} else {
modelMapping := *channel.ModelMapping
@@ -120,6 +126,25 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
if err != nil {
return err, nil
}
modelPrice, usePrice := common.GetModelPrice(testModel, false)
modelRatio := common.GetModelRatio(testModel)
completionRatio := common.GetCompletionRatio(testModel)
ratio := modelRatio
quota := 0
if !usePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit)
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
@@ -140,7 +165,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
channelId, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -148,7 +173,7 @@ func TestChannel(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(id, true)
channel, err := model.GetChannelById(channelId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
fallthrough
case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c, relayMode)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
default:
err = relay.TextHelper(c)
}

View File

@@ -24,7 +24,7 @@ func UpdateTaskBulk() {
//imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
common.SysLog(" 任务进度轮询开始")
common.SysLog("任务进度轮询开始")
ctx := context.TODO()
allTasks := model.GetAllUnFinishSyncTasks(500)
platformTask := make(map[constant.TaskPlatform][]*model.Task)

19
dto/rerank.go Normal file
View File

@@ -0,0 +1,19 @@
package dto
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
}
type RerankResponseDocument struct {
Document any `json:"document"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
}

View File

@@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
return redemption.Quota, nil
}

View File

@@ -250,11 +250,9 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
if !token.UnlimitedQuota {
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return 0, err
}
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return 0, err
}
err = DecreaseUserQuota(token.UserId, quota)
return userQuota - quota, err
@@ -272,15 +270,13 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
}
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
}
if sendEmail {

View File

@@ -11,9 +11,11 @@ import (
type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string

View File

@@ -15,6 +15,9 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -20,6 +20,11 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
@@ -53,6 +58,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return claudeReq, err
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}

View File

@@ -2,6 +2,7 @@ package baidu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -15,44 +16,74 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.UpstreamModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "ERNIE-4.0-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-3.5-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Character-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
case "ERNIE-Functions-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
case "ERNIE-Lite-8K-0922":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "Yi-34B-Chat":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
default:
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
suffix = "embeddings/"
}
switch info.UpstreamModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += strings.ToLower(info.UpstreamModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
@@ -82,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,20 +1,22 @@
package baidu
var ModelList = []string{
"ERNIE-3.5-8K",
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"ERNIE-Character-8K",
"ERNIE-Functions-8K",
//"ERNIE-Bot-4",
//"ERNIE-Bot-8K",
//"ERNIE-Bot",
//"ERNIE-Speed",
//"ERNIE-Bot-turbo",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}
var ChannelName = "baidu"

View File

@@ -11,9 +11,16 @@ type BaiduMessage struct {
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
Messages []BaiduMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {

View File

@@ -22,17 +22,27 @@ import (
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: int(request.MaxTokens),
UserId: request.User,
}
for _, message := range request.Messages {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
if message.Role == "system" {
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &baiduRequest
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {

View File

@@ -21,6 +21,11 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
@@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -8,16 +8,24 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
}
return
}

View File

@@ -2,6 +2,7 @@ package cohere
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}
var ChannelName = "cohere"

View File

@@ -1,5 +1,7 @@
package cohere
import "one-api/dto"
type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
Meta CohereMeta `json:"meta"`
}
type CohereRerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents"`
}
type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`

View File

@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
return &cohereReq
}
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
if rerankRequest.TopN == 0 {
rerankRequest.TopN = 1
}
cohereReq := CohereRerankRequest{
Query: rerankRequest.Query,
Documents: rerankRequest.Documents,
Model: rerankRequest.Model,
TopN: rerankRequest.TopN,
ReturnDocuments: true,
}
return &cohereReq
}
func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
@@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = 0
usage.TotalTokens = info.PromptTokens
} else {
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
}
var rerankResp dto.RerankResponse
rerankResp.Results = cohereResp.Results
rerankResp.Usage = usage
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,65 @@
package dify
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Dify(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = difyStreamHandler(c, resp, info)
} else {
err, usage = difyHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package dify
var ModelList []string
var ChannelName = "dify"

35
relay/channel/dify/dto.go Normal file
View File

@@ -0,0 +1,35 @@
package dify
import "one-api/dto"
type DifyChatRequest struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
ResponseMode string `json:"response_mode"`
User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"`
}
type DifyMetaData struct {
Usage dto.Usage `json:"usage"`
}
type DifyData struct {
WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"`
}
type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answers string `json:"answers"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}
type DifyChunkChatCompletionResponse struct {
Event string `json:"event"`
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
Data DifyData `json:"data"`
MetaData DifyMetaData `json:"metadata"`
}

View File

@@ -0,0 +1,154 @@
package dify
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
content := ""
for _, message := range request.Messages {
if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n"
} else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n"
} else {
content += "USER: \n" + message.StringContent() + "\n"
}
}
mode := "blocking"
if request.Stream {
mode = "streaming"
}
user := request.User
if user == "" {
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
}
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if difyResponse.Event == "workflow_started" {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
return &response
}
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
break
} else if difyResponse.Event == "error" {
break
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
}
}
err = service.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, usage
}
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
content, _ := json.Marshal(difyResponse.Answers)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &difyResponse.MetaData.Usage
}

View File

@@ -15,6 +15,9 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -56,6 +59,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return CovertGemini2OpenAI(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -0,0 +1,64 @@
package jina
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,8 @@
package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
}
var ChannelName = "jina"

View File

@@ -0,0 +1,35 @@
package jina
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -16,6 +16,9 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -22,6 +22,13 @@ type Adaptor struct {
ChannelType int
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
}

View File

@@ -133,13 +133,19 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout))
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
case <-ticker.C:
common.LogError(c, "reading data from upstream timeout")
return false
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
ticker.Reset(time.Duration(constant.StreamingTimeout))
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}

View File

@@ -15,6 +15,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Perplexity(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -22,6 +22,11 @@ type Adaptor struct {
Timestamp int64
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.Action = "ChatCompletions"
a.Version = "2023-09-01"
@@ -57,6 +62,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return tencentRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,9 +1,10 @@
package tencent
var ModelList = []string{
"ChatPro",
"ChatStd",
"hunyuan",
"hunyuan-lite",
"hunyuan-standard",
"hunyuan-standard-256K",
"hunyuan-pro",
}
var ChannelName = "tencent"

View File

@@ -16,6 +16,11 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -14,6 +14,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -20,6 +20,8 @@ const (
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDify
APITypeJina
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -57,6 +59,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
case common.ChannelTypeDify:
apiType = APITypeDify
case common.ChannelTypeJina:
apiType = APITypeJina
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -32,6 +32,7 @@ const (
RelayModeSunoFetch
RelayModeSunoFetchByID
RelayModeSunoSubmit
RelayModeRerank
)
func Path2RelayMode(path string) int {
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
} else if strings.HasPrefix(path, "/v1/rerank") {
relayMode = RelayModeRerank
}
return relayMode
}

View File

@@ -544,7 +544,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if err != nil {
common.SysError("get_channel_null: " + err.Error())
}
if channel.AutoBan != nil && *channel.AutoBan == 1 {
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
}
}

View File

@@ -182,7 +182,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
return nil
}
@@ -272,7 +272,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
}
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool) {
@@ -281,7 +281,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
completionTokens := usage.CompletionTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(textRequest.Model)
completionRatio := common.GetCompletionRatio(modelName)
quota := 0
if !usePrice {
@@ -307,7 +307,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
@@ -327,13 +328,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := textRequest.Model
logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
logContent += fmt.Sprintf(",模型 %s", modelName)
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
//if quota != 0 {
//

View File

@@ -8,7 +8,9 @@ import (
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/cohere"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
"one-api/relay/channel/jina"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
@@ -53,6 +55,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &aws.Adaptor{}
case constant.APITypeCohere:
return &cohere.Adaptor{}
case constant.APITypeDify:
return &dify.Adaptor{}
case constant.APITypeJina:
return &jina.Adaptor{}
}
return nil
}

104
relay/relay_rerank.go Normal file
View File

@@ -0,0 +1,104 @@
package relay
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
for _, document := range rerankRequest.Documents {
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
if err == nil {
token += tkm
}
}
return token
}
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
if rerankRequest.Query == "" {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
}
if len(rerankRequest.Documents) == 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptToken := getRerankPromptToken(*rerankRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(rerankRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.InitRerank(relayInfo, *rerankRequest)
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
return nil
}

View File

@@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
relayV1Router.POST("/rerank", controller.Relay)
}
relayMjRouter := router.Group("/mj")

View File

@@ -104,6 +104,8 @@ export const CHANNEL_OPTIONS = [
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
{ key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
{ key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
{ key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
{
key: 22,