mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 07:56:38 +08:00
merge upstream
Signed-off-by: wozulong <>
This commit is contained in:
commit
0cc7f5cca6
1
.github/workflows/docker-image-arm64.yml
vendored
1
.github/workflows/docker-image-arm64.yml
vendored
@ -4,6 +4,7 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
|
@ -2,6 +2,21 @@
|
||||
|
||||
**简介**:Midjourney Proxy API文档
|
||||
|
||||
## 接口列表
|
||||
支持的接口如下:
|
||||
+ [x] /mj/submit/imagine
|
||||
+ [x] /mj/submit/change
|
||||
+ [x] /mj/submit/blend
|
||||
+ [x] /mj/submit/describe
|
||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||
+ [x] /task/list-by-condition
|
||||
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||
+ [x] /mj/submit/modal
|
||||
+ [x] /mj/submit/shorten
|
||||
+ [x] /mj/task/{id}/image-seed
|
||||
+ [x] /mj/insight-face/swap (InsightFace)
|
||||
|
||||
## 模型列表
|
||||
|
||||
### midjourney-proxy支持
|
||||
|
28
README.md
28
README.md
@ -16,19 +16,7 @@
|
||||
此分叉版本的主要变更如下:
|
||||
|
||||
1. 全新的UI界面(部分界面还待更新)
|
||||
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md),支持的接口如下:
|
||||
+ [x] /mj/submit/imagine
|
||||
+ [x] /mj/submit/change
|
||||
+ [x] /mj/submit/blend
|
||||
+ [x] /mj/submit/describe
|
||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||
+ [x] /task/list-by-condition
|
||||
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||
+ [x] /mj/submit/modal
|
||||
+ [x] /mj/submit/shorten
|
||||
+ [x] /mj/task/{id}/image-seed
|
||||
+ [x] /mj/insight-face/swap (InsightFace)
|
||||
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md)
|
||||
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
||||
+ [x] 易支付
|
||||
4. 支持用key查询使用额度:
|
||||
@ -45,22 +33,21 @@
|
||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md),支持的接口如下:
|
||||
+ [x] /suno/submit/music
|
||||
+ [x] /suno/submit/lyrics
|
||||
+ [x] /suno/fetch
|
||||
+ [x] /suno/fetch/:id
|
||||
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
|
||||
14. 支持Rerank模型,目前仅兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md)
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*, g-*)
|
||||
2. 智谱glm-4v,glm-4v识图
|
||||
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
|
||||
3. Anthropic Claude 3
|
||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||
6. [零一万物](https://platform.lingyiwanwu.com/)
|
||||
7. 自定义渠道,支持填入完整调用地址
|
||||
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
|
||||
9. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md)
|
||||
10. Dify
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
@ -85,7 +72,8 @@
|
||||
|
||||
## 比原版One API多出的配置
|
||||
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
|
||||
|
||||
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`, 可选值为 `true` 和 `false`
|
||||
- `FORCE_STREAM_OPTION`:覆盖客户端stream_options参数,请求上游返回流模式usage,目前仅支持 `OpenAI` 渠道类型
|
||||
## 部署
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
|
62
Rerank.md
Normal file
62
Rerank.md
Normal file
@ -0,0 +1,62 @@
|
||||
# Rerank API文档
|
||||
|
||||
**简介**:Rerank API文档
|
||||
|
||||
## 接入Dify
|
||||
模型供应商选择Jina,按要求填写模型信息即可接入Dify。
|
||||
|
||||
## 请求方式
|
||||
|
||||
Post: /v1/rerank
|
||||
|
||||
Request:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "rerank-multilingual-v3.0",
|
||||
"query": "What is the capital of the United States?",
|
||||
"top_n": 3,
|
||||
"documents": [
|
||||
"Carson City is the capital city of the American state of Nevada.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
|
||||
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
|
||||
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"document": {
|
||||
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
|
||||
},
|
||||
"index": 2,
|
||||
"relevance_score": 0.9999702
|
||||
},
|
||||
{
|
||||
"document": {
|
||||
"text": "Carson City is the capital city of the American state of Nevada."
|
||||
},
|
||||
"index": 0,
|
||||
"relevance_score": 0.67800725
|
||||
},
|
||||
{
|
||||
"document": {
|
||||
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
|
||||
},
|
||||
"index": 3,
|
||||
"relevance_score": 0.02800752
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 158,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 158
|
||||
}
|
||||
}
|
||||
```
|
7
Suno.md
7
Suno.md
@ -2,6 +2,13 @@
|
||||
|
||||
**简介**:Suno API文档
|
||||
|
||||
## 接口列表
|
||||
支持的接口如下:
|
||||
+ [x] /suno/submit/music
|
||||
+ [x] /suno/submit/lyrics
|
||||
+ [x] /suno/fetch
|
||||
+ [x] /suno/fetch/:id
|
||||
|
||||
## 模型列表
|
||||
|
||||
### Suno API支持
|
||||
|
@ -233,36 +233,38 @@ const (
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.cloud.tencent.com", //23
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
@ -276,4 +278,6 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"", //37
|
||||
"https://api.jina.ai", //38
|
||||
}
|
||||
|
@ -24,3 +24,15 @@ func GetEnvOrDefaultString(env string, defaultValue string) string {
|
||||
}
|
||||
return os.Getenv(env)
|
||||
}
|
||||
|
||||
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
b, err := strconv.ParseBool(os.Getenv(env))
|
||||
if err != nil {
|
||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
|
||||
return defaultValue
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
@ -78,18 +78,22 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K
|
||||
"ERNIE-4.0-8K": 8.572, // ¥0.12 / 1k tokens
|
||||
"ERNIE-3.5-8K": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Speed-8K": 0.2858, // ¥0.004 / 1k tokens
|
||||
"ERNIE-Speed-128K": 0.2858, // ¥0.004 / 1k tokens
|
||||
"ERNIE-Lite-8K": 0.2143, // ¥0.003 / 1k tokens
|
||||
"ERNIE-Tiny-8K": 0.0715, // ¥0.001 / 1k tokens
|
||||
"ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens
|
||||
"ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"ERNIE-4.0-8K": 0.120 * RMB,
|
||||
"ERNIE-3.5-8K": 0.012 * RMB,
|
||||
"ERNIE-3.5-8K-0205": 0.024 * RMB,
|
||||
"ERNIE-3.5-8K-1222": 0.012 * RMB,
|
||||
"ERNIE-Bot-8K": 0.024 * RMB,
|
||||
"ERNIE-3.5-4K-0205": 0.012 * RMB,
|
||||
"ERNIE-Speed-8K": 0.004 * RMB,
|
||||
"ERNIE-Speed-128K": 0.004 * RMB,
|
||||
"ERNIE-Lite-8K-0922": 0.008 * RMB,
|
||||
"ERNIE-Lite-8K-0308": 0.003 * RMB,
|
||||
"ERNIE-Tiny-8K": 0.001 * RMB,
|
||||
"BLOOMZ-7B": 0.004 * RMB,
|
||||
"Embedding-V1": 0.002 * RMB,
|
||||
"bge-large-zh": 0.002 * RMB,
|
||||
"bge-large-en": 0.002 * RMB,
|
||||
"tao-8k": 0.002 * RMB,
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
|
@ -5,3 +5,7 @@ import (
|
||||
)
|
||||
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
|
@ -4,6 +4,7 @@ var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
var MjActionCheckSuccessEnabled = true
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -24,6 +25,7 @@ import (
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
||||
tik := time.Now()
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
@ -65,7 +67,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
if len(channel.GetModels()) > 0 {
|
||||
testModel = channel.GetModels()[0]
|
||||
} else {
|
||||
testModel = "gpt-3.5-turbo"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
modelMapping := *channel.ModelMapping
|
||||
@ -115,11 +121,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
result := w.Result()
|
||||
// print result.Body
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
modelPrice, usePrice := common.GetModelPrice(testModel, false)
|
||||
modelRatio := common.GetModelRatio(testModel)
|
||||
completionRatio := common.GetCompletionRatio(testModel)
|
||||
ratio := modelRatio
|
||||
quota := 0
|
||||
if !usePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
|
||||
quota = int(math.Round(float64(quota) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit)
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
@ -140,7 +164,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
|
||||
}
|
||||
|
||||
func TestChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@ -148,7 +172,7 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err = relay.AudioHelper(c, relayMode)
|
||||
case relayconstant.RelayModeRerank:
|
||||
err = relay.RerankHelper(c, relayMode)
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
|
19
dto/rerank.go
Normal file
19
dto/rerank.go
Normal file
@ -0,0 +1,19 @@
|
||||
package dto
|
||||
|
||||
type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
}
|
||||
|
||||
type RerankResponseDocument struct {
|
||||
Document any `json:"document"`
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
}
|
||||
|
||||
type RerankResponse struct {
|
||||
Results []RerankResponseDocument `json:"results"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
@ -13,7 +13,7 @@ type GeneralOpenAIRequest struct {
|
||||
BestOf int `json:"best_of,omitempty"`
|
||||
Echo bool `json:"echo,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions any `json:"stream_options,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
@ -48,6 +48,10 @@ type OpenAIFunction struct {
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
||||
return int64(r.MaxTokens)
|
||||
}
|
||||
|
@ -102,10 +102,12 @@ type ChatCompletionsStreamResponse struct {
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint *string `json:"system_fingerprint"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseSimple struct {
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@ -33,6 +34,13 @@ type Channel struct {
|
||||
OtherInfo string `json:"other_info"`
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
if channel.Models == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(strings.Trim(channel.Models, ","), ",")
|
||||
}
|
||||
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
|
@ -103,6 +103,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
@ -218,6 +219,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
constant.MjModeClearEnabled = boolValue
|
||||
case "MjForwardUrlEnabled":
|
||||
constant.MjForwardUrlEnabled = boolValue
|
||||
case "MjActionCheckSuccessEnabled":
|
||||
constant.MjActionCheckSuccessEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
|
@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if err != nil {
|
||||
return 0, errors.New("兑换失败," + err.Error())
|
||||
}
|
||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
|
||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
|
||||
return redemption.Quota, nil
|
||||
}
|
||||
|
||||
|
@ -249,11 +249,9 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
|
||||
if userQuota < quota {
|
||||
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
}
|
||||
if !token.UnlimitedQuota {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = DecreaseUserQuota(token.UserId, quota)
|
||||
return userQuota - quota, err
|
||||
@ -271,15 +269,13 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
|
||||
return err
|
||||
}
|
||||
|
||||
if !token.UnlimitedQuota {
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(tokenId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(tokenId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sendEmail {
|
||||
|
@ -315,7 +315,8 @@ func (user *User) ValidateAndFill() (err error) {
|
||||
if user.Username == "" || password == "" {
|
||||
return errors.New("用户名或密码为空")
|
||||
}
|
||||
DB.Where(User{Username: user.Username}).First(user)
|
||||
// find buy username or email
|
||||
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
|
||||
okay := common.ValidatePasswordAndHash(password, user.Password)
|
||||
if !okay || user.Status != common.UserStatusEnabled {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
|
@ -11,9 +11,11 @@ import (
|
||||
type Adaptor interface {
|
||||
// Init IsStream bool
|
||||
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
|
||||
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
|
@ -15,6 +15,9 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
@ -20,6 +20,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
@ -53,13 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a.RequestMode)
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a.RequestMode)
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
relaymodel "one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
event, ok := <-stream.Events()
|
||||
if !ok {
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
|
||||
@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package baidu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -15,44 +16,74 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.UpstreamModelName {
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "ERNIE-Bot-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
|
||||
case "ERNIE-Bot":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Speed":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "ERNIE-4.0-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "ERNIE-3.5-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Speed-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||
case "ERNIE-Character-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
|
||||
case "ERNIE-Functions-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
|
||||
case "ERNIE-Lite-8K-0922":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "Yi-34B-Chat":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
|
||||
case "Embedding-V1":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||
default:
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
||||
suffix := "chat/"
|
||||
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
|
||||
suffix = "embeddings/"
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
|
||||
suffix = "embeddings/"
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
|
||||
suffix = "embeddings/"
|
||||
}
|
||||
switch info.UpstreamModelName {
|
||||
case "ERNIE-4.0":
|
||||
suffix += "completions_pro"
|
||||
case "ERNIE-Bot-4":
|
||||
suffix += "completions_pro"
|
||||
case "ERNIE-Bot":
|
||||
suffix += "completions"
|
||||
case "ERNIE-Bot-turbo":
|
||||
suffix += "eb-instant"
|
||||
case "ERNIE-Speed":
|
||||
suffix += "ernie_speed"
|
||||
case "ERNIE-4.0-8K":
|
||||
suffix += "completions_pro"
|
||||
case "ERNIE-3.5-8K":
|
||||
suffix += "completions"
|
||||
case "ERNIE-3.5-8K-0205":
|
||||
suffix += "ernie-3.5-8k-0205"
|
||||
case "ERNIE-3.5-8K-1222":
|
||||
suffix += "ernie-3.5-8k-1222"
|
||||
case "ERNIE-Bot-8K":
|
||||
suffix += "ernie_bot_8k"
|
||||
case "ERNIE-3.5-4K-0205":
|
||||
suffix += "ernie-3.5-4k-0205"
|
||||
case "ERNIE-Speed-8K":
|
||||
suffix += "ernie_speed"
|
||||
case "ERNIE-Speed-128K":
|
||||
suffix += "ernie-speed-128k"
|
||||
case "ERNIE-Lite-8K-0922":
|
||||
suffix += "eb-instant"
|
||||
case "ERNIE-Lite-8K-0308":
|
||||
suffix += "ernie-lite-8k"
|
||||
case "ERNIE-Tiny-8K":
|
||||
suffix += "ernie-tiny-8k"
|
||||
case "BLOOMZ-7B":
|
||||
suffix += "bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
suffix += "embedding-v1"
|
||||
case "bge-large-zh":
|
||||
suffix += "bge_large_zh"
|
||||
case "bge-large-en":
|
||||
suffix += "bge_large_en"
|
||||
case "tao-8k":
|
||||
suffix += "tao_8k"
|
||||
default:
|
||||
suffix += strings.ToLower(info.UpstreamModelName)
|
||||
}
|
||||
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
|
||||
var accessToken string
|
||||
var err error
|
||||
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
||||
@ -82,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
@ -1,20 +1,22 @@
|
||||
package baidu
|
||||
|
||||
var ModelList = []string{
|
||||
"ERNIE-3.5-8K",
|
||||
"ERNIE-4.0-8K",
|
||||
"ERNIE-3.5-8K",
|
||||
"ERNIE-3.5-8K-0205",
|
||||
"ERNIE-3.5-8K-1222",
|
||||
"ERNIE-Bot-8K",
|
||||
"ERNIE-3.5-4K-0205",
|
||||
"ERNIE-Speed-8K",
|
||||
"ERNIE-Speed-128K",
|
||||
"ERNIE-Lite-8K",
|
||||
"ERNIE-Lite-8K-0922",
|
||||
"ERNIE-Lite-8K-0308",
|
||||
"ERNIE-Tiny-8K",
|
||||
"ERNIE-Character-8K",
|
||||
"ERNIE-Functions-8K",
|
||||
//"ERNIE-Bot-4",
|
||||
//"ERNIE-Bot-8K",
|
||||
//"ERNIE-Bot",
|
||||
//"ERNIE-Speed",
|
||||
//"ERNIE-Bot-turbo",
|
||||
"BLOOMZ-7B",
|
||||
"Embedding-V1",
|
||||
"bge-large-zh",
|
||||
"bge-large-en",
|
||||
"tao-8k",
|
||||
}
|
||||
|
||||
var ChannelName = "baidu"
|
||||
|
@ -11,9 +11,16 @@ type BaiduMessage struct {
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
|
@ -22,17 +22,33 @@ import (
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
baiduRequest := BaiduChatRequest{
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
PenaltyScore: request.FrequencyPenalty,
|
||||
Stream: request.Stream,
|
||||
DisableSearch: false,
|
||||
EnableCitation: false,
|
||||
UserId: request.User,
|
||||
}
|
||||
if request.MaxTokens != 0 {
|
||||
maxTokens := int(request.MaxTokens)
|
||||
if request.MaxTokens == 1 {
|
||||
maxTokens = 2
|
||||
}
|
||||
baiduRequest.MaxOutputTokens = &maxTokens
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
if message.Role == "system" {
|
||||
baiduRequest.System = message.StringContent()
|
||||
} else {
|
||||
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &baiduRequest
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
||||
|
@ -21,6 +21,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
@ -330,22 +330,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
response.Created = createdTime
|
||||
response.Model = info.UpstreamModelName
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
err = service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
@ -356,6 +349,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
|
@ -8,16 +8,24 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
} else {
|
||||
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return requestConvertRerank2Cohere(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = cohereStreamHandler(c, resp, info)
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.IsStream {
|
||||
err, usage = cohereStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package cohere
|
||||
|
||||
var ModelList = []string{
|
||||
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
|
||||
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
|
||||
}
|
||||
|
||||
var ChannelName = "cohere"
|
||||
|
@ -1,5 +1,7 @@
|
||||
package cohere
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type CohereRequest struct {
|
||||
Model string `json:"model"`
|
||||
ChatHistory []ChatHistory `json:"chat_history"`
|
||||
@ -28,6 +30,19 @@ type CohereResponseResult struct {
|
||||
Meta CohereMeta `json:"meta"`
|
||||
}
|
||||
|
||||
type CohereRerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
ReturnDocuments bool `json:"return_documents"`
|
||||
}
|
||||
|
||||
type CohereRerankResponseResult struct {
|
||||
Results []dto.RerankResponseDocument `json:"results"`
|
||||
Meta CohereMeta `json:"meta"`
|
||||
}
|
||||
|
||||
type CohereMeta struct {
|
||||
//Tokens CohereTokens `json:"tokens"`
|
||||
BilledUnits CohereBilledUnits `json:"billed_units"`
|
||||
|
@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
return &cohereReq
|
||||
}
|
||||
|
||||
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
|
||||
if rerankRequest.TopN == 0 {
|
||||
rerankRequest.TopN = 1
|
||||
}
|
||||
cohereReq := CohereRerankRequest{
|
||||
Query: rerankRequest.Query,
|
||||
Documents: rerankRequest.Documents,
|
||||
Model: rerankRequest.Model,
|
||||
TopN: rerankRequest.TopN,
|
||||
ReturnDocuments: true,
|
||||
}
|
||||
return &cohereReq
|
||||
}
|
||||
|
||||
func stopReasonCohere2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "COMPLETE":
|
||||
@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var cohereResp CohereRerankResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
} else {
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
||||
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
||||
}
|
||||
|
||||
var rerankResp dto.RerankResponse
|
||||
rerankResp.Results = cohereResp.Results
|
||||
rerankResp.Usage = usage
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
65
relay/channel/dify/adaptor.go
Normal file
65
relay/channel/dify/adaptor.go
Normal file
@ -0,0 +1,65 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Dify(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = difyStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = difyHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
5
relay/channel/dify/constants.go
Normal file
5
relay/channel/dify/constants.go
Normal file
@ -0,0 +1,5 @@
|
||||
package dify
|
||||
|
||||
var ModelList []string
|
||||
|
||||
var ChannelName = "dify"
|
35
relay/channel/dify/dto.go
Normal file
35
relay/channel/dify/dto.go
Normal file
@ -0,0 +1,35 @@
|
||||
package dify
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type DifyChatRequest struct {
|
||||
Inputs map[string]interface{} `json:"inputs"`
|
||||
Query string `json:"query"`
|
||||
ResponseMode string `json:"response_mode"`
|
||||
User string `json:"user"`
|
||||
AutoGenerateName bool `json:"auto_generate_name"`
|
||||
}
|
||||
|
||||
type DifyMetaData struct {
|
||||
Usage dto.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type DifyData struct {
|
||||
WorkflowId string `json:"workflow_id"`
|
||||
NodeId string `json:"node_id"`
|
||||
}
|
||||
|
||||
type DifyChatCompletionResponse struct {
|
||||
ConversationId string `json:"conversation_id"`
|
||||
Answer string `json:"answer"`
|
||||
CreateAt int64 `json:"create_at"`
|
||||
MetaData DifyMetaData `json:"metadata"`
|
||||
}
|
||||
|
||||
type DifyChunkChatCompletionResponse struct {
|
||||
Event string `json:"event"`
|
||||
ConversationId string `json:"conversation_id"`
|
||||
Answer string `json:"answer"`
|
||||
Data DifyData `json:"data"`
|
||||
MetaData DifyMetaData `json:"metadata"`
|
||||
}
|
156
relay/channel/dify/relay-dify.go
Normal file
156
relay/channel/dify/relay-dify.go
Normal file
@ -0,0 +1,156 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
|
||||
content := ""
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
content += "SYSTEM: \n" + message.StringContent() + "\n"
|
||||
} else if message.Role == "assistant" {
|
||||
content += "ASSISTANT: \n" + message.StringContent() + "\n"
|
||||
} else {
|
||||
content += "USER: \n" + message.StringContent() + "\n"
|
||||
}
|
||||
}
|
||||
mode := "blocking"
|
||||
if request.Stream {
|
||||
mode = "streaming"
|
||||
}
|
||||
user := request.User
|
||||
if user == "" {
|
||||
user = "api-user"
|
||||
}
|
||||
return &DifyChatRequest{
|
||||
Inputs: make(map[string]interface{}),
|
||||
Query: content,
|
||||
ResponseMode: mode,
|
||||
User: user,
|
||||
AutoGenerateName: false,
|
||||
}
|
||||
}
|
||||
|
||||
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "dify",
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
|
||||
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
|
||||
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
|
||||
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
|
||||
} else if difyResponse.Event == "message" {
|
||||
choice.Delta.SetContentString(difyResponse.Answer)
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
return &response
|
||||
}
|
||||
|
||||
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var responseText string
|
||||
usage := &dto.Usage{}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
}
|
||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||
if difyResponse.Event == "message_end" {
|
||||
usage = &difyResponse.MetaData.Usage
|
||||
break
|
||||
} else if difyResponse.Event == "error" {
|
||||
break
|
||||
} else {
|
||||
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
|
||||
if len(openaiResponse.Choices) != 0 {
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
}
|
||||
}
|
||||
err = service.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
}
|
||||
service.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var difyResponse DifyChatCompletionResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &difyResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: difyResponse.ConversationId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: difyResponse.MetaData.Usage,
|
||||
}
|
||||
content, _ := json.Marshal(difyResponse.Answer)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &difyResponse.MetaData.Usage
|
||||
}
|
@ -9,12 +9,14 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -56,15 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return CovertGemini2OpenAI(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = geminiChatStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage = geminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseText := ""
|
||||
responseJson := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
dataChan := make(chan string, 5)
|
||||
stopChan := make(chan bool, 2)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
responseJson += data
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.SetContentString(dummy.Content)
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Created: createAt,
|
||||
Model: info.UpstreamModelName,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
var geminiChatResponses []GeminiChatResponse
|
||||
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
log.Printf("cannot get gemini usage: %s", err.Error())
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
for _, response := range geminiChatResponses {
|
||||
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
return nil, responseText
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
|
64
relay/channel/jina/adaptor.go
Normal file
64
relay/channel/jina/adaptor.go
Normal file
@ -0,0 +1,64 @@
|
||||
package jina
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = jinaRerankHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
8
relay/channel/jina/constant.go
Normal file
8
relay/channel/jina/constant.go
Normal file
@ -0,0 +1,8 @@
|
||||
package jina
|
||||
|
||||
var ModelList = []string{
|
||||
"jina-clip-v1",
|
||||
"jina-reranker-v2-base-multilingual",
|
||||
}
|
||||
|
||||
var ChannelName = "jina"
|
35
relay/channel/jina/relay-jina.go
Normal file
35
relay/channel/jina/relay-jina.go
Normal file
@ -0,0 +1,35 @@
|
||||
package jina
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var jinaResp dto.RerankResponse
|
||||
err = json.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &jinaResp.Usage
|
||||
}
|
@ -16,6 +16,9 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
@ -52,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
|
@ -22,6 +22,13 @@ type Adaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
@ -82,9 +89,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
var toolCount int
|
||||
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
@ -18,9 +18,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
|
||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
var responseTextBuilder strings.Builder
|
||||
var usage dto.Usage
|
||||
toolCount := 0
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
@ -62,17 +63,24 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
}
|
||||
// 计算token
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
if streamResponse.Usage != nil {
|
||||
if streamResponse.Usage.TotalTokens != 0 {
|
||||
usage = *streamResponse.Usage
|
||||
}
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
@ -89,6 +97,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
if streamResponse.Usage != nil {
|
||||
if streamResponse.Usage.TotalTokens != 0 {
|
||||
usage = *streamResponse.Usage
|
||||
}
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
@ -107,6 +120,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
@ -133,13 +147,19 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
isFirst := true
|
||||
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
defer ticker.Stop()
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
common.LogError(c, "reading data from upstream timeout")
|
||||
return false
|
||||
case data := <-dataChan:
|
||||
if isFirst {
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
}
|
||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
@ -153,10 +173,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
|
||||
}
|
||||
wg.Wait()
|
||||
return nil, responseTextBuilder.String(), toolCount
|
||||
return nil, &usage, responseTextBuilder.String(), toolCount
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
|
@ -15,6 +15,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return requestOpenAI2Perplexity(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
@ -46,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
@ -6,28 +6,44 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
Sign string
|
||||
Sign string
|
||||
AppID int64
|
||||
Action string
|
||||
Version string
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
a.Action = "ChatCompletions"
|
||||
a.Version = "2023-09-01"
|
||||
a.Timestamp = common.GetTimestamp()
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", a.Sign)
|
||||
req.Header.Set("X-TC-Action", info.UpstreamModelName)
|
||||
req.Header.Set("X-TC-Action", a.Action)
|
||||
req.Header.Set("X-TC-Version", a.Version)
|
||||
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -38,17 +54,20 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
a.AppID = appId
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tencentRequest := requestOpenAI2Tencent(*request)
|
||||
tencentRequest.AppId = appId
|
||||
tencentRequest.SecretId = secretId
|
||||
tencentRequest := requestOpenAI2Tencent(a, *request)
|
||||
// we have to calculate the sign here
|
||||
a.Sign = getTencentSign(*tencentRequest, secretKey)
|
||||
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
|
||||
return tencentRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
package tencent
|
||||
|
||||
var ModelList = []string{
|
||||
"ChatPro",
|
||||
"ChatStd",
|
||||
"hunyuan",
|
||||
"hunyuan-lite",
|
||||
"hunyuan-standard",
|
||||
"hunyuan-standard-256K",
|
||||
"hunyuan-pro",
|
||||
}
|
||||
|
||||
var ChannelName = "tencent"
|
||||
|
@ -1,62 +1,75 @@
|
||||
package tencent
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"Role"`
|
||||
Content string `json:"Content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
Model string `json:"model"` // 模型名称
|
||||
// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
|
||||
// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
|
||||
//
|
||||
// 注意:
|
||||
// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
|
||||
Model *string `json:"Model"`
|
||||
// 聊天上下文信息。
|
||||
// 说明:
|
||||
// 1. 长度最多为 40,按对话时间从旧到新在数组中排列。
|
||||
// 2. Message.Role 可选值:system、user、assistant。
|
||||
// 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。
|
||||
// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
|
||||
Messages []*TencentMessage `json:"Messages"`
|
||||
// 流式调用开关。
|
||||
// 说明:
|
||||
// 1. 未传值时默认为非流式调用(false)。
|
||||
// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
|
||||
// 3. 非流式调用时:
|
||||
// 调用方式与普通 HTTP 请求无异。
|
||||
// 接口响应耗时较长,**如需更低时延建议设置为 true**。
|
||||
// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
|
||||
//
|
||||
// 注意:
|
||||
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
|
||||
Stream *bool `json:"Stream,omitempty"`
|
||||
// 说明:
|
||||
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
|
||||
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
|
||||
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||
TopP *float64 `json:"TopP,omitempty"`
|
||||
// 说明:
|
||||
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
|
||||
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
|
||||
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||
Temperature *float64 `json:"Temperature,omitempty"`
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Code int `json:"Code"`
|
||||
Message string `json:"Message"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokens int `json:"PromptTokens"`
|
||||
CompletionTokens int `json:"CompletionTokens"`
|
||||
TotalTokens int `json:"TotalTokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage dto.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
|
||||
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"Id,omitempty"` // 会话 id
|
||||
Usage TencentUsage `json:"Usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"Note,omitempty"` // 注释
|
||||
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
|
||||
type TencentChatResponseSB struct {
|
||||
Response TencentChatResponse `json:"Response,omitempty"`
|
||||
}
|
||||
|
@ -3,8 +3,8 @@ package tencent
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -15,54 +15,46 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
|
||||
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||
messages := make([]*TencentMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
continue
|
||||
}
|
||||
messages = append(messages, TencentMessage{
|
||||
messages = append(messages, &TencentMessage{
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
stream := 0
|
||||
if request.Stream {
|
||||
stream = 1
|
||||
var req = TencentChatRequest{
|
||||
Stream: &request.Stream,
|
||||
Messages: messages,
|
||||
Model: &request.Model,
|
||||
}
|
||||
return &TencentChatRequest{
|
||||
Timestamp: common.GetTimestamp(),
|
||||
Expired: common.GetTimestamp() + 24*60*60,
|
||||
QueryID: common.GetUUID(),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Stream: stream,
|
||||
Messages: messages,
|
||||
Model: request.Model,
|
||||
if request.TopP != 0 {
|
||||
req.TopP = &request.TopP
|
||||
}
|
||||
if request.Temperature != 0 {
|
||||
req.Temperature = &request.Temperature
|
||||
}
|
||||
return &req
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: response.Usage,
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: response.Usage.PromptTokens,
|
||||
CompletionTokens: response.Usage.CompletionTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
||||
@ -99,69 +91,51 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var TencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.GetContentString()
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
})
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
|
||||
var tencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &tencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
response := streamResponseTencent2OpenAI(&tencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.GetContentString()
|
||||
}
|
||||
|
||||
err = service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
}
|
||||
|
||||
service.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var TencentResponse TencentChatResponse
|
||||
var tencentSb TencentChatResponseSB
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@ -170,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
if tencentSb.Response.Error.Code != 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
Message: tencentSb.Response.Error.Message,
|
||||
Code: tencentSb.Response.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@ -206,29 +180,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
|
||||
return
|
||||
}
|
||||
|
||||
func getTencentSign(req TencentChatRequest, secretKey string) string {
|
||||
params := make([]string, 0)
|
||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||
params = append(params, "secret_id="+req.SecretId)
|
||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||
params = append(params, "query_id="+req.QueryID)
|
||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||
|
||||
var messageStr string
|
||||
for _, msg := range req.Messages {
|
||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||
}
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Sort(sort.StringSlice(params))
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
mac.Write([]byte(signURL))
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
func sha256hex(s string) string {
|
||||
b := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func hmacSha256(s, key string) string {
|
||||
hashed := hmac.New(sha256.New, []byte(key))
|
||||
hashed.Write([]byte(s))
|
||||
return string(hashed.Sum(nil))
|
||||
}
|
||||
|
||||
func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
|
||||
// build canonical request string
|
||||
host := "hunyuan.tencentcloudapi.com"
|
||||
httpRequestMethod := "POST"
|
||||
canonicalURI := "/"
|
||||
canonicalQueryString := ""
|
||||
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
|
||||
"application/json", host, strings.ToLower(adaptor.Action))
|
||||
signedHeaders := "content-type;host;x-tc-action"
|
||||
payload, _ := json.Marshal(req)
|
||||
hashedRequestPayload := sha256hex(string(payload))
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
httpRequestMethod,
|
||||
canonicalURI,
|
||||
canonicalQueryString,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
hashedRequestPayload)
|
||||
// build string to sign
|
||||
algorithm := "TC3-HMAC-SHA256"
|
||||
requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
|
||||
timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
|
||||
t := time.Unix(timestamp, 0).UTC()
|
||||
// must be the format 2006-01-02, ref to package time for more info
|
||||
date := t.Format("2006-01-02")
|
||||
credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
|
||||
hashedCanonicalRequest := sha256hex(canonicalRequest)
|
||||
string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
|
||||
algorithm,
|
||||
requestTimestamp,
|
||||
credentialScope,
|
||||
hashedCanonicalRequest)
|
||||
|
||||
// sign string
|
||||
secretDate := hmacSha256(date, "TC3"+secKey)
|
||||
secretService := hmacSha256("hunyuan", secretDate)
|
||||
secretKey := hmacSha256("tc3_request", secretService)
|
||||
signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
|
||||
|
||||
// build authorization
|
||||
authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
algorithm,
|
||||
secId,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature)
|
||||
return authorization
|
||||
}
|
||||
|
@ -16,6 +16,11 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
|
@ -14,6 +14,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return requestOpenAI2Zhipu(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return requestOpenAI2Zhipu(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
@ -48,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
var toolCount int
|
||||
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
@ -9,24 +9,26 @@ import (
|
||||
)
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
ApiType int
|
||||
IsStream bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
ApiType int
|
||||
IsStream bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
@ -65,6 +67,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
|
@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
|
||||
var textResponse dto.TextResponseWithError
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
|
||||
return
|
||||
}
|
||||
OpenAIErrorWithStatusCode.Error = textResponse.Error
|
||||
|
@ -20,6 +20,8 @@ const (
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@ -57,6 +59,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeAws
|
||||
case common.ChannelTypeCohere:
|
||||
apiType = APITypeCohere
|
||||
case common.ChannelTypeDify:
|
||||
apiType = APITypeDify
|
||||
case common.ChannelTypeJina:
|
||||
apiType = APITypeJina
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
@ -32,6 +32,7 @@ const (
|
||||
RelayModeSunoFetch
|
||||
RelayModeSunoFetchByID
|
||||
RelayModeSunoSubmit
|
||||
RelayModeRerank
|
||||
)
|
||||
|
||||
func Path2RelayMode(path string) int {
|
||||
@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
|
||||
relayMode = RelayModeAudioTranscription
|
||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||
relayMode = RelayModeAudioTranslation
|
||||
} else if strings.HasPrefix(path, "/v1/rerank") {
|
||||
relayMode = RelayModeRerank
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
if constant.MjActionCheckSuccessEnabled {
|
||||
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
}
|
||||
}
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||
@ -500,7 +503,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
@ -544,7 +547,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if err != nil {
|
||||
common.SysError("get_channel_null: " + err.Error())
|
||||
}
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 1 {
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
|
||||
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
||||
}
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
// set upstream model name
|
||||
isModelMapped = true
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
@ -130,33 +130,38 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
|
||||
textRequest.StreamOptions = nil
|
||||
} else {
|
||||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||
if constant.ForceStreamOption {
|
||||
textRequest.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
|
||||
relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo, *textRequest)
|
||||
var requestBody io.Reader
|
||||
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
@ -182,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -281,7 +286,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
|
||||
}
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, usePrice bool) {
|
||||
|
||||
@ -290,7 +295,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
completionTokens := usage.CompletionTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
|
||||
quota := 0
|
||||
if !usePrice {
|
||||
@ -316,7 +321,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
} else {
|
||||
//if sensitiveResp != nil {
|
||||
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
@ -336,16 +342,17 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
logModel := textRequest.Model
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
} else if strings.HasPrefix(logModel, "g-") {
|
||||
logModel = "g-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
|
||||
//if quota != 0 {
|
||||
//
|
||||
|
@ -8,7 +8,9 @@ import (
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/cohere"
|
||||
"one-api/relay/channel/dify"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
@ -53,6 +55,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &aws.Adaptor{}
|
||||
case constant.APITypeCohere:
|
||||
return &cohere.Adaptor{}
|
||||
case constant.APITypeDify:
|
||||
return &dify.Adaptor{}
|
||||
case constant.APITypeJina:
|
||||
return &jina.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
104
relay/relay_rerank.go
Normal file
104
relay/relay_rerank.go
Normal file
@ -0,0 +1,104 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
for _, document := range rerankRequest.Documents {
|
||||
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
|
||||
if err == nil {
|
||||
token += tkm
|
||||
}
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
var rerankRequest *dto.RerankRequest
|
||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
if rerankRequest.Query == "" {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
|
||||
}
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||
}
|
||||
relayInfo.UpstreamModelName = rerankRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
if !success {
|
||||
preConsumedTokens := promptToken
|
||||
modelRatio = common.GetModelRatio(rerankRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.InitRerank(relayInfo, *rerankRequest)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
|
||||
return nil
|
||||
}
|
@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
relayV1Router.POST("/rerank", controller.Relay)
|
||||
}
|
||||
|
||||
relayMjRouter := router.Group("/mj")
|
||||
|
@ -1,6 +1,12 @@
|
||||
package service
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
@ -9,3 +15,23 @@ func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func StringData(c *gin.Context, str string) {
|
||||
str = strings.TrimPrefix(str, "data: ")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + str})
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
func ObjectData(c *gin.Context, object interface{}) error {
|
||||
jsonData, err := json.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
StringData(c, string(jsonData))
|
||||
return nil
|
||||
}
|
||||
|
||||
func Done(c *gin.Context) {
|
||||
StringData(c, "[DONE]")
|
||||
}
|
||||
|
@ -24,3 +24,15 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
}
|
||||
|
||||
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
SystemFingerprint: nil,
|
||||
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
|
||||
Usage: &usage,
|
||||
}
|
||||
}
|
||||
|
@ -554,7 +554,7 @@ const ChannelsTable = () => {
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setChannels(data);
|
||||
setChannelFormat(data);
|
||||
setActivePage(1);
|
||||
} else {
|
||||
showError(message);
|
||||
|
@ -42,6 +42,7 @@ const OperationSetting = () => {
|
||||
MjAccountFilterEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
DrawingEnabled: false,
|
||||
DataExportEnabled: false,
|
||||
DataExportDefaultTime: 'hour',
|
||||
|
@ -104,6 +104,8 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
|
||||
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
|
||||
{ key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
|
||||
{ key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
|
||||
{ key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
|
||||
{
|
||||
key: 22,
|
||||
|
@ -160,8 +160,8 @@ export function renderModelPrice(
|
||||
let inputRatioPrice = modelRatio * 2.0;
|
||||
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
|
||||
let price =
|
||||
(inputTokens / 1000000) * inputRatioPrice +
|
||||
(completionTokens / 1000000) * completionRatioPrice;
|
||||
(inputTokens / 1000000) * inputRatioPrice * groupRatio +
|
||||
(completionTokens / 1000000) * completionRatioPrice * groupRatio;
|
||||
return (
|
||||
<>
|
||||
<article>
|
||||
|
@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
|
||||
MjAccountFilterEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
@ -156,6 +157,21 @@ export default function SettingsDrawing(props) {
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={8}>
|
||||
<Form.Switch
|
||||
field={'MjActionCheckSuccessEnabled'}
|
||||
label={<>检测必须等待绘图成功才能进行放大等操作</>}
|
||||
size='large'
|
||||
checkedText='|'
|
||||
uncheckedText='〇'
|
||||
onChange={(value) =>
|
||||
setInputs({
|
||||
...inputs,
|
||||
MjActionCheckSuccessEnabled: value,
|
||||
})
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row>
|
||||
<Button size='large' onClick={onSubmit}>
|
||||
|
Loading…
Reference in New Issue
Block a user