mirror of https://github.com/linux-do/new-api.git
synced 2025-09-17 07:56:38 +08:00

merge upstream

Signed-off-by: wozulong <>
commit 0cc7f5cca6

.github/workflows/docker-image-arm64.yml (vendored, 1 change)
@@ -4,6 +4,7 @@ on:
   push:
     tags:
       - '*'
+      - '!*-alpha*'
   workflow_dispatch:
     inputs:
       name:
Midjourney.md

@@ -2,6 +2,21 @@

 **Intro**: Midjourney Proxy API documentation

+## Endpoint list
+
+The supported endpoints are:
+
++ [x] /mj/submit/imagine
++ [x] /mj/submit/change
++ [x] /mj/submit/blend
++ [x] /mj/submit/describe
++ [x] /mj/image/{id} (fetch images through this endpoint; **the server address must be filled in under system settings!**)
++ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is forwarded through One API)
++ [x] /task/list-by-condition
++ [x] /mj/submit/action (only supported by midjourney-proxy-plus, same below)
++ [x] /mj/submit/modal
++ [x] /mj/submit/shorten
++ [x] /mj/task/{id}/image-seed
++ [x] /mj/insight-face/swap (InsightFace)
+
 ## Model list

 ### Supported by midjourney-proxy
README.md (28)

@@ -16,19 +16,7 @@

 The main changes in this fork are:

 1. A brand-new UI (some pages are still being updated)
-2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md); the supported endpoints are:
-  + [x] /mj/submit/imagine
-  + [x] /mj/submit/change
-  + [x] /mj/submit/blend
-  + [x] /mj/submit/describe
-  + [x] /mj/image/{id} (fetch images through this endpoint; **the server address must be filled in under system settings!**)
-  + [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is forwarded through One API)
-  + [x] /task/list-by-condition
-  + [x] /mj/submit/action (only supported by midjourney-proxy-plus, same below)
-  + [x] /mj/submit/modal
-  + [x] /mj/submit/shorten
-  + [x] /mj/task/{id}/image-seed
-  + [x] /mj/insight-face/swap (InsightFace)
+2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md)
 3. Online top-up support, configurable in system settings; currently supported payment gateways:
   + [x] 易支付 (YiPay)
 4. Query the used quota with a key:
@@ -45,22 +33,21 @@

 2. Send the /setdomain command to [@Botfather](https://t.me/botfather)
 3. Choose your bot, then enter http(s)://your-site-address/login
 4. The Telegram Bot name is the bot username with the leading @ removed
-13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md); the supported endpoints are:
-  + [x] /suno/submit/music
-  + [x] /suno/submit/lyrics
-  + [x] /suno/fetch
-  + [x] /suno/fetch/:id
+13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md)
+14. Rerank model support, currently compatible with Cohere and Jina only; can be connected to Dify, [integration docs](Rerank.md)

 ## Model support

 This version additionally supports the following models:

 1. Third-party model **gps** (gpt-4-gizmo-*, g-*)
 2. Zhipu glm-4v, glm-4v image recognition
-3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
+3. Anthropic Claude 3
 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file); when adding a channel the key can be anything, and the default request address is [http://localhost:11434](http://localhost:11434); change it in the channel if needed
 5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md)
 6. [零一万物](https://platform.lingyiwanwu.com/)
 7. Custom channels, with support for entering a full request URL
 8. [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md)
+9. Rerank models, currently [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [integration docs](Rerank.md)
+10. Dify

 You can add the custom models gpt-4-gizmo-* or g-* in a channel; these are not official OpenAI models but third-party models, so they cannot be called with an official OpenAI key.
@@ -85,7 +72,8 @@

 ## Extra configuration over the original One API

 - `STREAMING_TIMEOUT`: timeout for a single streamed reply, 30 seconds by default
+- `DIFY_DEBUG`: whether Dify channels send workflow and node info to the client, `true` by default; allowed values are `true` and `false`
+- `FORCE_STREAM_OPTION`: override the client's stream_options parameter and ask upstream to return usage in stream mode; currently only the `OpenAI` channel type is supported

 ## Deployment

 ### Deployment requirements

 - Local database (default): SQLite (Docker deployments use SQLite by default and must mount the `/data` directory to the host)
Rerank.md (new file, 62 lines)

# Rerank API Documentation

**Intro**: Rerank API documentation

## Connecting to Dify

Choose Jina as the model provider and fill in the model information as required to connect to Dify.

## Request

POST /v1/rerank

Request:

```json
{
  "model": "rerank-multilingual-v3.0",
  "query": "What is the capital of the United States?",
  "top_n": 3,
  "documents": [
    "Carson City is the capital city of the American state of Nevada.",
    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
    "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
  ]
}
```

Response:

```json
{
  "results": [
    {
      "document": {
        "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
      },
      "index": 2,
      "relevance_score": 0.9999702
    },
    {
      "document": {
        "text": "Carson City is the capital city of the American state of Nevada."
      },
      "index": 0,
      "relevance_score": 0.67800725
    },
    {
      "document": {
        "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
      },
      "index": 3,
      "relevance_score": 0.02800752
    }
  ],
  "usage": {
    "prompt_tokens": 158,
    "completion_tokens": 0,
    "total_tokens": 158
  }
}
```
Suno.md (7)

@@ -2,6 +2,13 @@

 **Intro**: Suno API documentation

+## Endpoint list
+
+The supported endpoints are:
+
++ [x] /suno/submit/music
++ [x] /suno/submit/lyrics
++ [x] /suno/fetch
++ [x] /suno/fetch/:id
+
 ## Model list

 ### Supported by Suno API
@@ -233,36 +233,38 @@ const (
 	ChannelTypeCohere  = 34
 	ChannelTypeMiniMax = 35
 	ChannelTypeSunoAPI = 36
+	ChannelTypeDify    = 37
+	ChannelTypeJina    = 38
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
 
 var ChannelBaseURLs = []string{
 	"",                          // 0
 	"https://api.openai.com",    // 1
 	"https://oa.api2d.net",      // 2
 	"",                          // 3
 	"http://localhost:11434",    // 4
 	"https://api.openai-sb.com", // 5
 	"https://api.openaimax.com", // 6
 	"https://api.ohmygpt.com",   // 7
 	"",                          // 8
 	"https://api.caipacity.com", // 9
 	"https://api.aiproxy.io",    // 10
 	"",                          // 11
 	"https://api.api2gpt.com",   // 12
 	"https://api.aigc2d.com",    // 13
 	"https://api.anthropic.com", // 14
 	"https://aip.baidubce.com",  // 15
 	"https://open.bigmodel.cn",  // 16
 	"https://dashscope.aliyuncs.com", // 17
 	"",                          // 18
 	"https://ai.360.cn",         // 19
 	"https://openrouter.ai/api", // 20
 	"https://api.aiproxy.io",    // 21
 	"https://fastgpt.run/api/openapi", // 22
-	"https://hunyuan.cloud.tencent.com", //23
+	"https://hunyuan.tencentcloudapi.com", //23
 	"https://generativelanguage.googleapis.com", //24
 	"https://api.moonshot.cn",   //25
 	"https://open.bigmodel.cn",  //26

@@ -276,4 +278,6 @@ var ChannelBaseURLs = []string{
 	"https://api.cohere.ai",    //34
 	"https://api.minimax.chat", //35
 	"",                         //36
+	"",                         //37
+	"https://api.jina.ai",      //38
 }
@@ -24,3 +24,15 @@ func GetEnvOrDefaultString(env string, defaultValue string) string {
 	}
 	return os.Getenv(env)
 }
+
+func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	b, err := strconv.ParseBool(os.Getenv(env))
+	if err != nil {
+		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
+		return defaultValue
+	}
+	return b
+}
@@ -78,18 +78,22 @@ var defaultModelRatio = map[string]float64{
 	"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
 	"claude-3-sonnet-20240229":   1.5, // $3 / 1M tokens
 	"claude-3-opus-20240229":     7.5, // $15 / 1M tokens
-	"ERNIE-Bot":          0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
-	"ERNIE-Bot-turbo":    0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
-	"ERNIE-Bot-4":        8.572,  // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K
-	"ERNIE-4.0-8K":       8.572,  // ¥0.12 / 1k tokens
-	"ERNIE-3.5-8K":       0.8572, // ¥0.012 / 1k tokens
-	"ERNIE-Speed-8K":     0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Speed-128K":   0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Lite-8K":      0.2143, // ¥0.003 / 1k tokens
-	"ERNIE-Tiny-8K":      0.0715, // ¥0.001 / 1k tokens
-	"ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens
-	"Embedding-V1":       0.1429, // ¥0.002 / 1k tokens
+	"ERNIE-4.0-8K":       0.120 * RMB,
+	"ERNIE-3.5-8K":       0.012 * RMB,
+	"ERNIE-3.5-8K-0205":  0.024 * RMB,
+	"ERNIE-3.5-8K-1222":  0.012 * RMB,
+	"ERNIE-Bot-8K":       0.024 * RMB,
+	"ERNIE-3.5-4K-0205":  0.012 * RMB,
+	"ERNIE-Speed-8K":     0.004 * RMB,
+	"ERNIE-Speed-128K":   0.004 * RMB,
+	"ERNIE-Lite-8K-0922": 0.008 * RMB,
+	"ERNIE-Lite-8K-0308": 0.003 * RMB,
+	"ERNIE-Tiny-8K":      0.001 * RMB,
+	"BLOOMZ-7B":          0.004 * RMB,
+	"Embedding-V1":       0.002 * RMB,
+	"bge-large-zh":       0.002 * RMB,
+	"bge-large-en":       0.002 * RMB,
+	"tao-8k":             0.002 * RMB,
 	"PaLM-2":            1,
 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -5,3 +5,7 @@ import (
 )
 
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
+var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
+
+// ForceStreamOption overrides the request parameters to force upstream to return usage info
+var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
@@ -4,6 +4,7 @@ var MjNotifyEnabled = false
 var MjAccountFilterEnabled = false
 var MjModeClearEnabled = false
 var MjForwardUrlEnabled = true
+var MjActionCheckSuccessEnabled = true
 
 const (
 	MjErrorUnknown = 5
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
@@ -24,6 +25,7 @@ import (
 )
 
 func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+	tik := time.Now()
 	if channel.Type == common.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
 	}

@@ -65,7 +67,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 	if channel.TestModel != nil && *channel.TestModel != "" {
 		testModel = *channel.TestModel
 	} else {
-		testModel = adaptor.GetModelList()[0]
+		if len(channel.GetModels()) > 0 {
+			testModel = channel.GetModels()[0]
+		} else {
+			testModel = "gpt-3.5-turbo"
+		}
 	}
 } else {
 	modelMapping := *channel.ModelMapping
@@ -115,11 +121,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		return errors.New("usage is nil"), nil
 	}
 	result := w.Result()
-	// print result.Body
 	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 		return err, nil
 	}
+	modelPrice, usePrice := common.GetModelPrice(testModel, false)
+	modelRatio := common.GetModelRatio(testModel)
+	completionRatio := common.GetCompletionRatio(testModel)
+	ratio := modelRatio
+	quota := 0
+	if !usePrice {
+		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
+		quota = int(math.Round(float64(quota) * ratio))
+		if ratio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(modelPrice * common.QuotaPerUnit)
+	}
+	tok := time.Now()
+	milliseconds := tok.Sub(tik).Milliseconds()
+	consumedTime := float64(milliseconds) / 1000.0
+	other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
+	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }
@@ -140,7 +164,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
 }
 
 func TestChannel(c *gin.Context) {
-	id, err := strconv.Atoi(c.Param("id"))
+	channelId, err := strconv.Atoi(c.Param("id"))
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,

@@ -148,7 +172,7 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	channel, err := model.GetChannelById(id, true)
+	channel, err := model.GetChannelById(channelId, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
|
@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err = relay.AudioHelper(c, relayMode)
|
err = relay.AudioHelper(c, relayMode)
|
||||||
|
case relayconstant.RelayModeRerank:
|
||||||
|
err = relay.RerankHelper(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
|
dto/rerank.go (new file, 19 lines)

package dto

type RerankRequest struct {
	Documents []any  `json:"documents"`
	Query     string `json:"query"`
	Model     string `json:"model"`
	TopN      int    `json:"top_n"`
}

type RerankResponseDocument struct {
	Document       any     `json:"document"`
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
}

type RerankResponse struct {
	Results []RerankResponseDocument `json:"results"`
	Usage   Usage                    `json:"usage"`
}
@@ -13,7 +13,7 @@ type GeneralOpenAIRequest struct {
 	BestOf        int    `json:"best_of,omitempty"`
 	Echo          bool   `json:"echo,omitempty"`
 	Stream        bool   `json:"stream,omitempty"`
-	StreamOptions any    `json:"stream_options,omitempty"`
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 	Suffix        string `json:"suffix,omitempty"`
 	MaxTokens     uint   `json:"max_tokens,omitempty"`
 	Temperature   float64 `json:"temperature,omitempty"`

@@ -48,6 +48,10 @@ type OpenAIFunction struct {
 	Parameters any `json:"parameters,omitempty"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
 	return int64(r.MaxTokens)
 }

@@ -102,10 +102,12 @@ type ChatCompletionsStreamResponse struct {
 	Model             string  `json:"model"`
 	SystemFingerprint *string `json:"system_fingerprint"`
 	Choices           []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage             *Usage  `json:"usage"`
 }
 
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *Usage `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"gorm.io/gorm"
 	"one-api/common"
+	"strings"
 )
 
 type Channel struct {

@@ -33,6 +34,13 @@ type Channel struct {
 	OtherInfo string `json:"other_info"`
 }
 
+func (channel *Channel) GetModels() []string {
+	if channel.Models == "" {
+		return []string{}
+	}
+	return strings.Split(strings.Trim(channel.Models, ","), ",")
+}
+
 func (channel *Channel) GetOtherInfo() map[string]interface{} {
 	otherInfo := make(map[string]interface{})
 	if channel.OtherInfo != "" {
@@ -103,6 +103,7 @@ func InitOptionMap() {
 	common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
 	common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
 	common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
+	common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
 	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
 	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
 	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)

@@ -218,6 +219,8 @@ func updateOptionMap(key string, value string) (err error) {
 		constant.MjModeClearEnabled = boolValue
 	case "MjForwardUrlEnabled":
 		constant.MjForwardUrlEnabled = boolValue
+	case "MjActionCheckSuccessEnabled":
+		constant.MjActionCheckSuccessEnabled = boolValue
 	case "CheckSensitiveEnabled":
 		constant.CheckSensitiveEnabled = boolValue
 	case "CheckSensitiveOnPromptEnabled":
@@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
 	if err != nil {
 		return 0, errors.New("兑换失败," + err.Error())
 	}
-	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
+	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
 	return redemption.Quota, nil
 }
@@ -249,11 +249,9 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
 	if userQuota < quota {
 		return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
 	}
-	if !token.UnlimitedQuota {
-		err = DecreaseTokenQuota(tokenId, quota)
-		if err != nil {
-			return 0, err
-		}
+	err = DecreaseTokenQuota(tokenId, quota)
+	if err != nil {
+		return 0, err
 	}
 	err = DecreaseUserQuota(token.UserId, quota)
 	return userQuota - quota, err

@@ -271,15 +269,13 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
 		return err
 	}
 
-	if !token.UnlimitedQuota {
-		if quota > 0 {
-			err = DecreaseTokenQuota(tokenId, quota)
-		} else {
-			err = IncreaseTokenQuota(tokenId, -quota)
-		}
-		if err != nil {
-			return err
-		}
+	if quota > 0 {
+		err = DecreaseTokenQuota(tokenId, quota)
+	} else {
+		err = IncreaseTokenQuota(tokenId, -quota)
+	}
+	if err != nil {
+		return err
 	}
 
 	if sendEmail {
@@ -315,7 +315,8 @@ func (user *User) ValidateAndFill() (err error) {
 	if user.Username == "" || password == "" {
 		return errors.New("用户名或密码为空")
 	}
-	DB.Where(User{Username: user.Username}).First(user)
+	// find by username or email
+	DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
 	okay := common.ValidatePasswordAndHash(password, user.Password)
 	if !okay || user.Status != common.UserStatusEnabled {
 		return errors.New("用户名或密码错误,或用户已被封禁")
@@ -11,9 +11,11 @@ import (
 type Adaptor interface {
 	// Init IsStream bool
 	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
+	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
@@ -15,6 +15,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }

@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@ -20,6 +20,11 @@ type Adaptor struct {
|
|||||||
RequestMode int
|
RequestMode int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||||
|
//TODO implement me
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||||
a.RequestMode = RequestModeMessage
|
a.RequestMode = RequestModeMessage
|
||||||
@@ -53,13 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return claudeReq, err
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return nil, nil
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = awsStreamHandler(c, info, a.RequestMode)
+		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 		err, usage = awsHandler(c, info, a.RequestMode)
 	}
@@ -13,6 +13,7 @@ import (
 	relaymodel "one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"time"
 
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 	return nil, &usage
 }
 
-func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
+func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	awsCli, err := newAwsClient(c, info)
 	if err != nil {
 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 
@@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
 			return false
 		}
 	})
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
 	return nil, &usage
 }
@@ -2,6 +2,7 @@ package baidu
 
 import (
 	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
@@ -15,44 +16,74 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	var fullRequestURL string
-	switch info.UpstreamModelName {
-	case "ERNIE-Bot-4":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
-	case "ERNIE-Bot-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
-	case "ERNIE-Bot":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
-	case "ERNIE-Speed":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
-	case "ERNIE-Bot-turbo":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
-	case "BLOOMZ-7B":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
-	case "ERNIE-4.0-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
-	case "ERNIE-3.5-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
-	case "ERNIE-Speed-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
-	case "ERNIE-Character-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
-	case "ERNIE-Functions-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
-	case "ERNIE-Lite-8K-0922":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
-	case "Yi-34B-Chat":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
-	case "Embedding-V1":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
-	default:
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
+	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+	suffix := "chat/"
+	if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
+		suffix = "embeddings/"
 	}
+	if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
+		suffix = "embeddings/"
+	}
+	if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
+		suffix = "embeddings/"
+	}
+	switch info.UpstreamModelName {
+	case "ERNIE-4.0":
+		suffix += "completions_pro"
+	case "ERNIE-Bot-4":
+		suffix += "completions_pro"
+	case "ERNIE-Bot":
+		suffix += "completions"
+	case "ERNIE-Bot-turbo":
+		suffix += "eb-instant"
+	case "ERNIE-Speed":
+		suffix += "ernie_speed"
+	case "ERNIE-4.0-8K":
+		suffix += "completions_pro"
+	case "ERNIE-3.5-8K":
+		suffix += "completions"
+	case "ERNIE-3.5-8K-0205":
+		suffix += "ernie-3.5-8k-0205"
+	case "ERNIE-3.5-8K-1222":
+		suffix += "ernie-3.5-8k-1222"
+	case "ERNIE-Bot-8K":
+		suffix += "ernie_bot_8k"
+	case "ERNIE-3.5-4K-0205":
+		suffix += "ernie-3.5-4k-0205"
+	case "ERNIE-Speed-8K":
+		suffix += "ernie_speed"
+	case "ERNIE-Speed-128K":
+		suffix += "ernie-speed-128k"
+	case "ERNIE-Lite-8K-0922":
+		suffix += "eb-instant"
+	case "ERNIE-Lite-8K-0308":
+		suffix += "ernie-lite-8k"
+	case "ERNIE-Tiny-8K":
+		suffix += "ernie-tiny-8k"
+	case "BLOOMZ-7B":
+		suffix += "bloomz_7b1"
+	case "Embedding-V1":
+		suffix += "embedding-v1"
+	case "bge-large-zh":
+		suffix += "bge_large_zh"
+	case "bge-large-en":
+		suffix += "bge_large_en"
+	case "tao-8k":
+		suffix += "tao_8k"
+	default:
+		suffix += strings.ToLower(info.UpstreamModelName)
+	}
+	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
 	var accessToken string
 	var err error
 	if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
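The rewritten `GetRequestURL` above replaces a hard-coded URL table with prefix checks plus a model-to-slug switch. A minimal standalone sketch of the same routing idea (the helper name `buildBaiduURL` is hypothetical, and only a subset of the mappings from the hunk is included):

```go
package main

import (
	"fmt"
	"strings"
)

// buildBaiduURL routes embedding-style models under "embeddings/" and chat
// models under "chat/", mapping known model names to fixed endpoint slugs and
// falling back to the lower-cased model name, as the diff above does.
func buildBaiduURL(baseURL, model string) string {
	suffix := "chat/"
	if strings.HasPrefix(model, "Embedding") ||
		strings.HasPrefix(model, "bge-large") ||
		strings.HasPrefix(model, "tao-8k") {
		suffix = "embeddings/"
	}
	slugs := map[string]string{
		"ERNIE-4.0-8K":   "completions_pro",
		"ERNIE-3.5-8K":   "completions",
		"ERNIE-Speed-8K": "ernie_speed",
		"Embedding-V1":   "embedding-v1",
		"bge-large-zh":   "bge_large_zh",
	}
	slug, ok := slugs[model]
	if !ok {
		slug = strings.ToLower(model)
	}
	return fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s%s", baseURL, suffix, slug)
}

func main() {
	fmt.Println(buildBaiduURL("https://aip.baidubce.com", "ERNIE-4.0-8K"))
	fmt.Println(buildBaiduURL("https://aip.baidubce.com", "Embedding-V1"))
}
```

Unknown models still produce a plausible URL via the lower-cased fallback, which is what lets the channel serve models that are not in the explicit list.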
@@ -82,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -1,20 +1,22 @@
 package baidu
 
 var ModelList = []string{
-	"ERNIE-3.5-8K",
 	"ERNIE-4.0-8K",
+	"ERNIE-3.5-8K",
+	"ERNIE-3.5-8K-0205",
+	"ERNIE-3.5-8K-1222",
+	"ERNIE-Bot-8K",
+	"ERNIE-3.5-4K-0205",
 	"ERNIE-Speed-8K",
 	"ERNIE-Speed-128K",
-	"ERNIE-Lite-8K",
+	"ERNIE-Lite-8K-0922",
+	"ERNIE-Lite-8K-0308",
 	"ERNIE-Tiny-8K",
-	"ERNIE-Character-8K",
-	"ERNIE-Functions-8K",
-	//"ERNIE-Bot-4",
-	//"ERNIE-Bot-8K",
-	//"ERNIE-Bot",
-	//"ERNIE-Speed",
-	//"ERNIE-Bot-turbo",
+	"BLOOMZ-7B",
 	"Embedding-V1",
+	"bge-large-zh",
+	"bge-large-en",
+	"tao-8k",
 }
 
 var ChannelName = "baidu"
@@ -11,9 +11,16 @@ type BaiduMessage struct {
 }
 
 type BaiduChatRequest struct {
 	Messages []BaiduMessage `json:"messages"`
-	Stream   bool           `json:"stream"`
-	UserId   string         `json:"user_id,omitempty"`
+	Temperature     float64 `json:"temperature,omitempty"`
+	TopP            float64 `json:"top_p,omitempty"`
+	PenaltyScore    float64 `json:"penalty_score,omitempty"`
+	Stream          bool    `json:"stream,omitempty"`
+	System          string  `json:"system,omitempty"`
+	DisableSearch   bool    `json:"disable_search,omitempty"`
+	EnableCitation  bool    `json:"enable_citation,omitempty"`
+	MaxOutputTokens *int    `json:"max_output_tokens,omitempty"`
+	UserId          string  `json:"user_id,omitempty"`
 }
 
 type Error struct {
@@ -22,17 +22,33 @@ import (
 var baiduTokenStore sync.Map
 
 func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
-	messages := make([]BaiduMessage, 0, len(request.Messages))
+	baiduRequest := BaiduChatRequest{
+		Temperature:    request.Temperature,
+		TopP:           request.TopP,
+		PenaltyScore:   request.FrequencyPenalty,
+		Stream:         request.Stream,
+		DisableSearch:  false,
+		EnableCitation: false,
+		UserId:         request.User,
+	}
+	if request.MaxTokens != 0 {
+		maxTokens := int(request.MaxTokens)
+		if request.MaxTokens == 1 {
+			maxTokens = 2
+		}
+		baiduRequest.MaxOutputTokens = &maxTokens
+	}
 	for _, message := range request.Messages {
-		messages = append(messages, BaiduMessage{
-			Role:    message.Role,
-			Content: message.StringContent(),
-		})
-	}
-	return &BaiduChatRequest{
-		Messages: messages,
-		Stream:   request.Stream,
+		if message.Role == "system" {
+			baiduRequest.System = message.StringContent()
+		} else {
+			baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
+				Role:    message.Role,
+				Content: message.StringContent(),
+			})
+		}
 	}
+	return &baiduRequest
 }
 
 func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
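The conversion change above lifts system-role messages out of the message list because Baidu's chat API takes the system prompt as a top-level `system` field. A standalone sketch of that split (`Msg` and `splitSystem` are hypothetical stand-ins for the project's `dto.Message` and the loop in `requestOpenAI2Baidu`):

```go
package main

import "fmt"

// Msg is a minimal stand-in for the project's message type.
type Msg struct {
	Role, Content string
}

// splitSystem lifts the system prompt out of the message list, returning it
// separately while forwarding every other message unchanged, mirroring the
// role check introduced in the diff above.
func splitSystem(msgs []Msg) (system string, rest []Msg) {
	for _, m := range msgs {
		if m.Role == "system" {
			system = m.Content
		} else {
			rest = append(rest, m)
		}
	}
	return system, rest
}

func main() {
	sys, rest := splitSystem([]Msg{
		{Role: "system", Content: "be brief"},
		{Role: "user", Content: "hi"},
	})
	fmt.Println(sys, len(rest))
}
```

Note the hunk also clamps `max_output_tokens` to at least 2 when the caller asks for 1, presumably because the upstream API rejects a limit of 1.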
@@ -21,6 +21,11 @@ type Adaptor struct {
 	RequestMode int
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
@@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -330,22 +330,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			response.Created = createdTime
 			response.Model = info.UpstreamModelName
 
-			jsonStr, err := json.Marshal(response)
+			err = service.ObjectData(c, response)
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
+				common.SysError(err.Error())
 			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 			return true
 		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 	})
-	err := resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
 	if requestMode == RequestModeCompletion {
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
@@ -356,6 +349,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
 		}
 	}
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
 	return nil, usage
 }
 
@@ -8,16 +8,24 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 )
 
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+	if info.RelayMode == constant.RelayModeRerank {
+		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+	} else {
+		return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return requestConvertRerank2Cohere(request), nil
+}
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = cohereStreamHandler(c, resp, info)
+	if info.RelayMode == constant.RelayModeRerank {
+		err, usage = cohereRerankHandler(c, resp, info)
 	} else {
-		err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		if info.IsStream {
+			err, usage = cohereStreamHandler(c, resp, info)
+		} else {
+			err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		}
 	}
 	return
 }
@@ -2,6 +2,7 @@ package cohere
 
 var ModelList = []string{
 	"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
+	"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
 }
 
 var ChannelName = "cohere"
@@ -1,5 +1,7 @@
 package cohere
 
+import "one-api/dto"
+
 type CohereRequest struct {
 	Model       string        `json:"model"`
 	ChatHistory []ChatHistory `json:"chat_history"`
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
 	Meta CohereMeta `json:"meta"`
 }
 
+type CohereRerankRequest struct {
+	Documents       []any  `json:"documents"`
+	Query           string `json:"query"`
+	Model           string `json:"model"`
+	TopN            int    `json:"top_n"`
+	ReturnDocuments bool   `json:"return_documents"`
+}
+
+type CohereRerankResponseResult struct {
+	Results []dto.RerankResponseDocument `json:"results"`
+	Meta    CohereMeta                   `json:"meta"`
+}
+
 type CohereMeta struct {
 	//Tokens CohereTokens `json:"tokens"`
 	BilledUnits CohereBilledUnits `json:"billed_units"`
@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 	return &cohereReq
 }
 
+func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
+	if rerankRequest.TopN == 0 {
+		rerankRequest.TopN = 1
+	}
+	cohereReq := CohereRerankRequest{
+		Query:           rerankRequest.Query,
+		Documents:       rerankRequest.Documents,
+		Model:           rerankRequest.Model,
+		TopN:            rerankRequest.TopN,
+		ReturnDocuments: true,
+	}
+	return &cohereReq
+}
+
 func stopReasonCohere2OpenAI(reason string) string {
 	switch reason {
 	case "COMPLETE":
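The `requestConvertRerank2Cohere` conversion above defaults `top_n` to 1 when the caller leaves it unset and always asks Cohere to return the documents. A self-contained sketch of that mapping (the struct and helper here are local copies for illustration, not the project's types):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// cohereRerank mirrors the wire shape of Cohere's rerank request body as
// declared in the dto hunk above.
type cohereRerank struct {
	Documents       []any  `json:"documents"`
	Query           string `json:"query"`
	Model           string `json:"model"`
	TopN            int    `json:"top_n"`
	ReturnDocuments bool   `json:"return_documents"`
}

// rerankToCohere applies the same defaults as the diff: top_n falls back to 1
// and return_documents is always set.
func rerankToCohere(model, query string, docs []any, topN int) cohereRerank {
	if topN == 0 {
		topN = 1
	}
	return cohereRerank{
		Documents:       docs,
		Query:           query,
		Model:           model,
		TopN:            topN,
		ReturnDocuments: true,
	}
}

func main() {
	req := rerankToCohere("rerank-english-v3.0", "best city", []any{"a", "b"}, 0)
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
}
```

Requesting the documents back is what lets the handler fill `results` in the OpenAI-style rerank response without keeping its own copy of the inputs.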
@@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var cohereResp CohereRerankResponseResult
+	err = json.Unmarshal(responseBody, &cohereResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	usage := dto.Usage{}
+	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens = 0
+		usage.TotalTokens = info.PromptTokens
+	} else {
+		usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+		usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+		usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+	}
+
+	var rerankResp dto.RerankResponse
+	rerankResp.Results = cohereResp.Results
+	rerankResp.Usage = usage
+
+	jsonResponse, err := json.Marshal(rerankResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
relay/channel/dify/adaptor.go (new file, 65 lines)
@@ -0,0 +1,65 @@
+package dify
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return requestOpenAI2Dify(*request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = difyStreamHandler(c, resp, info)
+	} else {
+		err, usage = difyHandler(c, resp, info)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
relay/channel/dify/constants.go (new file, 5 lines)
@@ -0,0 +1,5 @@
+package dify
+
+var ModelList []string
+
+var ChannelName = "dify"
relay/channel/dify/dto.go (new file, 35 lines)
@@ -0,0 +1,35 @@
+package dify
+
+import "one-api/dto"
+
+type DifyChatRequest struct {
+	Inputs           map[string]interface{} `json:"inputs"`
+	Query            string                 `json:"query"`
+	ResponseMode     string                 `json:"response_mode"`
+	User             string                 `json:"user"`
+	AutoGenerateName bool                   `json:"auto_generate_name"`
+}
+
+type DifyMetaData struct {
+	Usage dto.Usage `json:"usage"`
+}
+
+type DifyData struct {
+	WorkflowId string `json:"workflow_id"`
+	NodeId     string `json:"node_id"`
+}
+
+type DifyChatCompletionResponse struct {
+	ConversationId string       `json:"conversation_id"`
+	Answer         string       `json:"answer"`
+	CreateAt       int64        `json:"create_at"`
+	MetaData       DifyMetaData `json:"metadata"`
+}
+
+type DifyChunkChatCompletionResponse struct {
+	Event          string       `json:"event"`
+	ConversationId string       `json:"conversation_id"`
+	Answer         string       `json:"answer"`
+	Data           DifyData     `json:"data"`
+	MetaData       DifyMetaData `json:"metadata"`
+}
relay/channel/dify/relay-dify.go (new file, 156 lines)
@@ -0,0 +1,156 @@
+package dify
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
+	content := ""
+	for _, message := range request.Messages {
+		if message.Role == "system" {
+			content += "SYSTEM: \n" + message.StringContent() + "\n"
+		} else if message.Role == "assistant" {
+			content += "ASSISTANT: \n" + message.StringContent() + "\n"
+		} else {
+			content += "USER: \n" + message.StringContent() + "\n"
+		}
+	}
+	mode := "blocking"
+	if request.Stream {
+		mode = "streaming"
+	}
+	user := request.User
+	if user == "" {
+		user = "api-user"
+	}
+	return &DifyChatRequest{
+		Inputs:           make(map[string]interface{}),
+		Query:            content,
+		ResponseMode:     mode,
+		User:             user,
+		AutoGenerateName: false,
+	}
+}
+
+func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
+	response := dto.ChatCompletionsStreamResponse{
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "dify",
+	}
+	var choice dto.ChatCompletionsStreamResponseChoice
+	if constant.DifyDebug && difyResponse.Event == "workflow_started" {
+		choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
+	} else if constant.DifyDebug && difyResponse.Event == "node_started" {
+		choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
+	} else if difyResponse.Event == "message" {
+		choice.Delta.SetContentString(difyResponse.Answer)
+	}
+	response.Choices = append(response.Choices, choice)
+	return &response
+}
+
+func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var responseText string
+	usage := &dto.Usage{}
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data:")
+		var difyResponse DifyChunkChatCompletionResponse
+		err := json.Unmarshal([]byte(data), &difyResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
+		}
+		var openaiResponse dto.ChatCompletionsStreamResponse
+		if difyResponse.Event == "message_end" {
+			usage = &difyResponse.MetaData.Usage
+			break
+		} else if difyResponse.Event == "error" {
+			break
+		} else {
+			openaiResponse = *streamResponseDify2OpenAI(difyResponse)
+			if len(openaiResponse.Choices) != 0 {
+				responseText += openaiResponse.Choices[0].Delta.GetContentString()
+			}
+		}
+		err = service.ObjectData(c, openaiResponse)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		common.SysError("error reading stream: " + err.Error())
+	}
+	service.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		common.SysError("close_response_body_failed: " + err.Error())
+	}
+	if usage.TotalTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+	return nil, usage
+}
+
+func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var difyResponse DifyChatCompletionResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &difyResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	fullTextResponse := dto.OpenAITextResponse{
+		Id:      difyResponse.ConversationId,
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Usage:   difyResponse.MetaData.Usage,
+	}
+	content, _ := json.Marshal(difyResponse.Answer)
+	choice := dto.OpenAITextResponseChoice{
+		Index: 0,
+		Message: dto.Message{
+			Role:    "assistant",
+			Content: content,
+		},
+		FinishReason: "stop",
+	}
+	fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &difyResponse.MetaData.Usage
|
||||||
|
}
|
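difyStreamHandler above drops any scanner line that is shorter than the `data:` prefix or does not start with it before JSON-decoding the remainder. A standalone sketch of that filtering (the helper name is ours, not part of the repository; whitespace trimming is added for readability):

```go
package main

import (
	"bufio"
	"strings"
)

// extractSSEData keeps only the payloads of SSE "data:" lines,
// mirroring the prefix check used in difyStreamHandler.
func extractSSEData(stream string) []string {
	var payloads []string
	scanner := bufio.NewScanner(strings.NewReader(stream))
	scanner.Split(bufio.ScanLines)
	for scanner.Scan() {
		line := scanner.Text()
		// skip lines that are too short or lack the "data:" prefix
		if len(line) < 5 || !strings.HasPrefix(line, "data:") {
			continue
		}
		payloads = append(payloads, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
	}
	return payloads
}

func main() {}
```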
@@ -9,12 +9,14 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 )

 type Adaptor struct {
 }

+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
@@ -56,15 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return CovertGemini2OpenAI(*request), nil
 }

+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		err, responseText = geminiChatStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage = geminiChatStreamHandler(c, resp, info)
 	} else {
 		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
 type GeminiChatResponse struct {
 	Candidates     []GeminiChatCandidate    `json:"candidates"`
 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+	UsageMetadata  GeminiUsageMetadata      `json:"usageMetadata"`
+}
+
+type GeminiUsageMetadata struct {
+	PromptTokenCount     int `json:"promptTokenCount"`
+	CandidatesTokenCount int `json:"candidatesTokenCount"`
+	TotalTokenCount      int `json:"totalTokenCount"`
 }
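The new usageMetadata fields map one-to-one onto Gemini's JSON keys, which is what lets both handlers read token counts straight off the response. A minimal decoding sketch (the sample JSON is illustrative, not a real API response):

```go
package main

import "encoding/json"

// GeminiUsageMetadata mirrors the struct added in the diff above.
type GeminiUsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

// decodeUsage pulls the usageMetadata object out of a response body.
func decodeUsage(raw []byte) (GeminiUsageMetadata, error) {
	var resp struct {
		UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
	}
	err := json.Unmarshal(raw, &resp)
	return resp.UsageMetadata, err
}

func main() {}
```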
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	return &response
 }

-func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseText := ""
+	responseJson := ""
+	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createAt := common.GetTimestamp()
+	var usage = &dto.Usage{}
 	dataChan := make(chan string, 5)
 	stopChan := make(chan bool, 2)
 	scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
+			responseJson += data
 			data = strings.TrimSpace(data)
 			if !strings.HasPrefix(data, "\"text\": \"") {
 				continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			var choice dto.ChatCompletionsStreamResponseChoice
 			choice.Delta.SetContentString(dummy.Content)
 			response := dto.ChatCompletionsStreamResponse{
-				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+				Id:      id,
 				Object:  "chat.completion.chunk",
-				Created: common.GetTimestamp(),
-				Model:   "gemini-pro",
+				Created: createAt,
+				Model:   info.UpstreamModelName,
 				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			return true
 		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 	})
-	err := resp.Body.Close()
+	var geminiChatResponses []GeminiChatResponse
+	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		log.Printf("cannot get gemini usage: %s", err.Error())
+		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		for _, response := range geminiChatResponses {
+			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+		}
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
-	return nil, responseText
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
+	}
+	return nil, usage
 }

 func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
 	usage := dto.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
64
relay/channel/jina/adaptor.go
Normal file
@@ -0,0 +1,64 @@
+package jina
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.RelayMode == constant.RelayModeRerank {
+		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+	} else if info.RelayMode == constant.RelayModeEmbeddings {
+		return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
+	}
+	return "", errors.New("invalid relay mode")
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == constant.RelayModeRerank {
+		err, usage = jinaRerankHandler(c, resp)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
8
relay/channel/jina/constant.go
Normal file
@@ -0,0 +1,8 @@
+package jina
+
+var ModelList = []string{
+	"jina-clip-v1",
+	"jina-reranker-v2-base-multilingual",
+}
+
+var ChannelName = "jina"
35
relay/channel/jina/relay-jina.go
Normal file
@@ -0,0 +1,35 @@
+package jina
+
+import (
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/service"
+)
+
+func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var jinaResp dto.RerankResponse
+	err = json.Unmarshal(responseBody, &jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	jsonResponse, err := json.Marshal(jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &jinaResp.Usage
+}
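jinaRerankHandler above decodes the upstream body, re-encodes it for the client, and surfaces the embedded usage. A standalone sketch of that decode-and-forward shape (the type and field names here are simplified stand-ins for this sketch, not the repo's dto types):

```go
package main

import "encoding/json"

// Usage is a minimal stand-in for the relay's token-accounting struct.
type Usage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

type rerankResponse struct {
	Results []json.RawMessage `json:"results"`
	Usage   Usage             `json:"usage"`
}

// passthrough decodes an upstream rerank body, re-encodes it for the
// client, and returns the extracted usage, as jinaRerankHandler does.
func passthrough(body []byte) ([]byte, *Usage, error) {
	var resp rerankResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, nil, err
	}
	out, err := json.Marshal(resp)
	if err != nil {
		return nil, nil, err
	}
	return out, &resp.Usage, nil
}

func main() {}
```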
@@ -16,6 +16,9 @@ import (
 type Adaptor struct {
 }

+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
@@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }

+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -52,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
@@ -22,6 +22,13 @@ type Adaptor struct {
 	ChannelType int
 }

+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
 }
@@ -82,9 +89,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
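Several adaptors above now share the same fallback: prefer the usage reported by the upstream stream, and only re-tokenize the accumulated text when it is missing or zero. A sketch of that guard (the naive length/4 estimate stands in for service.ResponseText2Usage's real tokenizer):

```go
package main

// Usage is a minimal stand-in for dto.Usage in this sketch.
type Usage struct{ PromptTokens, CompletionTokens, TotalTokens int }

// ensureUsage trusts upstream-reported usage when it is populated,
// otherwise estimates completion tokens from the streamed text.
func ensureUsage(upstream *Usage, promptTokens int, responseText string) *Usage {
	if upstream != nil && upstream.TotalTokens != 0 && upstream.PromptTokens+upstream.CompletionTokens != 0 {
		return upstream
	}
	// crude per-4-bytes token estimate, a placeholder for a real tokenizer
	completion := len(responseText) / 4
	return &Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completion,
		TotalTokens:      promptTokens + completion,
	}
}

func main() {}
```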
@@ -18,9 +18,10 @@ import (
 	"time"
 )

-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	var usage dto.Usage
 	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,24 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				streamItems = append(streamItems, data)
 			}
 		}
+		// count tokens
 		streamResp := "[" + strings.Join(streamItems, ",") + "]"
 		switch info.RelayMode {
 		case relayconstant.RelayModeChatCompletions:
 			var streamResponses []dto.ChatCompletionsStreamResponseSimple
 			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 			if err != nil {
+				// bulk unmarshal failed; fall back to parsing items one by one
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				for _, item := range streamItems {
 					var streamResponse dto.ChatCompletionsStreamResponseSimple
 					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
 					if err == nil {
+						if streamResponse.Usage != nil {
+							if streamResponse.Usage.TotalTokens != 0 {
+								usage = *streamResponse.Usage
+							}
+						}
 						for _, choice := range streamResponse.Choices {
 							responseTextBuilder.WriteString(choice.Delta.GetContentString())
 							if choice.Delta.ToolCalls != nil {
@@ -89,6 +97,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				}
 			} else {
 				for _, streamResponse := range streamResponses {
+					if streamResponse.Usage != nil {
+						if streamResponse.Usage.TotalTokens != 0 {
+							usage = *streamResponse.Usage
+						}
+					}
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
 						if choice.Delta.ToolCalls != nil {
@@ -107,6 +120,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			var streamResponses []dto.CompletionsStreamResponse
 			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 			if err != nil {
+				// bulk unmarshal failed; fall back to parsing items one by one
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				for _, item := range streamItems {
 					var streamResponse dto.CompletionsStreamResponse
@@ -133,13 +147,19 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	}()
 	service.SetEventStreamHeaders(c)
 	isFirst := true
+	ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
+	defer ticker.Stop()
 	c.Stream(func(w io.Writer) bool {
 		select {
+		case <-ticker.C:
+			common.LogError(c, "reading data from upstream timeout")
+			return false
 		case data := <-dataChan:
 			if isFirst {
 				isFirst = false
 				info.FirstResponseTime = time.Now()
 			}
+			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
 			if strings.HasPrefix(data, "data: [DONE]") {
 				data = data[:12]
 			}
@@ -153,10 +173,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String(), toolCount
+	return nil, &usage, responseTextBuilder.String(), toolCount
 }

 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -15,6 +15,11 @@ import (
 type Adaptor struct {
 }

+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
@@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return request, nil
 }

+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }

+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
@@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Perplexity(*request), nil
 }

+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -46,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -6,28 +6,44 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"strconv"
 	"strings"
 )

 type Adaptor struct {
 	Sign      string
+	AppID     int64
+	Action    string
+	Version   string
+	Timestamp int64
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
 }

 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+	a.Action = "ChatCompletions"
+	a.Version = "2023-09-01"
+	a.Timestamp = common.GetTimestamp()
 }

 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
+	return fmt.Sprintf("%s/", info.BaseUrl), nil
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("Authorization", a.Sign)
-	req.Header.Set("X-TC-Action", info.UpstreamModelName)
+	req.Header.Set("X-TC-Action", a.Action)
+	req.Header.Set("X-TC-Version", a.Version)
+	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
 	return nil
 }
@ -38,17 +54,20 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
apiKey := c.Request.Header.Get("Authorization")
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||||
|
a.AppID = appId
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tencentRequest := requestOpenAI2Tencent(*request)
|
tencentRequest := requestOpenAI2Tencent(a, *request)
|
||||||
tencentRequest.AppId = appId
|
|
||||||
tencentRequest.SecretId = secretId
|
|
||||||
// we have to calculate the sign here
|
// we have to calculate the sign here
|
||||||
a.Sign = getTencentSign(*tencentRequest, secretKey)
|
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
|
||||||
return tencentRequest, nil
|
return tencentRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
package tencent
|
package tencent
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"ChatPro",
|
"hunyuan-lite",
|
||||||
"ChatStd",
|
"hunyuan-standard",
|
||||||
"hunyuan",
|
"hunyuan-standard-256K",
|
||||||
|
"hunyuan-pro",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "tencent"
|
var ChannelName = "tencent"
|
||||||
|
@ -1,62 +1,75 @@
|
|||||||
package tencent
|
package tencent
|
||||||
|
|
||||||
import "one-api/dto"
|
|
||||||
|
|
||||||
type TencentMessage struct {
|
type TencentMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"Role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"Content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TencentChatRequest struct {
|
type TencentChatRequest struct {
|
||||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
|
||||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
|
||||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
//
|
||||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
// 注意:
|
||||||
Timestamp int64 `json:"timestamp"`
|
// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
|
||||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
Model *string `json:"Model"`
|
||||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
// 聊天上下文信息。
|
||||||
Expired int64 `json:"expired"`
|
// 说明:
|
||||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
// 1. 长度最多为 40,按对话时间从旧到新在数组中排列。
|
||||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
// 2. Message.Role 可选值:system、user、assistant。
|
||||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
// 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。
|
||||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
|
||||||
Temperature float64 `json:"temperature"`
|
Messages []*TencentMessage `json:"Messages"`
|
||||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
// 流式调用开关。
|
||||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
// 说明:
|
||||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
// 1. 未传值时默认为非流式调用(false)。
|
||||||
TopP float64 `json:"top_p"`
|
// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
|
||||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
// 3. 非流式调用时:
|
||||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
// 调用方式与普通 HTTP 请求无异。
|
||||||
Stream int `json:"stream"`
|
// 接口响应耗时较长,**如需更低时延建议设置为 true**。
|
||||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
|
||||||
// 输入 content 总数最大支持 3000 token。
|
//
|
||||||
Messages []TencentMessage `json:"messages"`
|
// 注意:
|
||||||
Model string `json:"model"` // 模型名称
|
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
|
||||||
|
Stream *bool `json:"Stream,omitempty"`
|
||||||
|
// 说明:
|
||||||
|
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
|
||||||
|
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
|
||||||
|
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||||
|
TopP *float64 `json:"TopP,omitempty"`
|
||||||
|
// 说明:
|
||||||
|
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
|
||||||
|
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
|
||||||
|
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||||
|
Temperature *float64 `json:"Temperature,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TencentError struct {
|
type TencentError struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"Code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"Message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TencentUsage struct {
|
type TencentUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
PromptTokens int `json:"PromptTokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
CompletionTokens int `json:"CompletionTokens"`
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"TotalTokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TencentResponseChoices struct {
|
type TencentResponseChoices struct {
|
||||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||||
}
|
}
|
||||||
|
|
||||||
type TencentChatResponse struct {
|
type TencentChatResponse struct {
|
||||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
|
||||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
|
||||||
Id string `json:"id,omitempty"` // 会话 id
|
Id string `json:"Id,omitempty"` // 会话 id
|
||||||
Usage dto.Usage `json:"usage,omitempty"` // token 数量
|
Usage TencentUsage `json:"Usage,omitempty"` // token 数量
|
||||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||||
Note string `json:"note,omitempty"` // 注释
|
Note string `json:"Note,omitempty"` // 注释
|
||||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentChatResponseSB struct {
|
||||||
|
Response TencentChatResponse `json:"Response,omitempty"`
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,8 @@ package tencent
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha1"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -15,54 +15,46 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://cloud.tencent.com/document/product/1729/97732
|
// https://cloud.tencent.com/document/product/1729/97732
|
||||||
|
|
||||||
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
messages := make([]*TencentMessage, 0, len(request.Messages))
|
||||||
for i := 0; i < len(request.Messages); i++ {
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
message := request.Messages[i]
|
message := request.Messages[i]
|
||||||
if message.Role == "system" {
|
messages = append(messages, &TencentMessage{
|
||||||
messages = append(messages, TencentMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, TencentMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
messages = append(messages, TencentMessage{
|
|
||||||
Content: message.StringContent(),
|
Content: message.StringContent(),
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
stream := 0
|
var req = TencentChatRequest{
|
||||||
if request.Stream {
|
Stream: &request.Stream,
|
||||||
stream = 1
|
Messages: messages,
|
||||||
|
Model: &request.Model,
|
||||||
}
|
}
|
||||||
return &TencentChatRequest{
|
if request.TopP != 0 {
|
||||||
Timestamp: common.GetTimestamp(),
|
req.TopP = &request.TopP
|
||||||
Expired: common.GetTimestamp() + 24*60*60,
|
|
||||||
QueryID: common.GetUUID(),
|
|
||||||
Temperature: request.Temperature,
|
|
||||||
TopP: request.TopP,
|
|
||||||
Stream: stream,
|
|
||||||
Messages: messages,
|
|
||||||
Model: request.Model,
|
|
||||||
}
|
}
|
||||||
|
if request.Temperature != 0 {
|
||||||
|
req.Temperature = &request.Temperature
|
||||||
|
}
|
||||||
|
return &req
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
|
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
|
Id: response.Id,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Usage: response.Usage,
|
Usage: dto.Usage{
|
||||||
|
PromptTokens: response.Usage.PromptTokens,
|
||||||
|
CompletionTokens: response.Usage.CompletionTokens,
|
||||||
|
TotalTokens: response.Usage.TotalTokens,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if len(response.Choices) > 0 {
|
if len(response.Choices) > 0 {
|
||||||
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
||||||
@@ -99,69 +91,51 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
 func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseText string
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if len(data) < 5 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:5] != "data:" {
-				continue
-			}
-			data = data[5:]
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
+	scanner.Split(bufio.ScanLines)
 	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			var TencentResponse TencentChatResponse
-			err := json.Unmarshal([]byte(data), &TencentResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			response := streamResponseTencent2OpenAI(&TencentResponse)
-			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.GetContentString()
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data:")
+		var tencentResponse TencentChatResponse
+		err := json.Unmarshal([]byte(data), &tencentResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
+		}
+		response := streamResponseTencent2OpenAI(&tencentResponse)
+		if len(response.Choices) != 0 {
+			responseText += response.Choices[0].Delta.GetContentString()
+		}
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		common.SysError("error reading stream: " + err.Error())
+	}
+	service.Done(c)
 	err := resp.Body.Close()
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
 func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	var TencentResponse TencentChatResponse
+	var tencentSb TencentChatResponseSB
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -170,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &TencentResponse)
+	err = json.Unmarshal(responseBody, &tencentSb)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	if TencentResponse.Error.Code != 0 {
+	if tencentSb.Response.Error.Code != 0 {
 		return &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
-				Message: TencentResponse.Error.Message,
-				Code:    TencentResponse.Error.Code,
+				Message: tencentSb.Response.Error.Message,
+				Code:    tencentSb.Response.Error.Code,
 			},
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	fullTextResponse := responseTencent2OpenAI(&TencentResponse)
+	fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -206,29 +180,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
 	return
 }
 
-func getTencentSign(req TencentChatRequest, secretKey string) string {
-	params := make([]string, 0)
-	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
-	params = append(params, "secret_id="+req.SecretId)
-	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
-	params = append(params, "query_id="+req.QueryID)
-	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
-	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
-	params = append(params, "stream="+strconv.Itoa(req.Stream))
-	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
-
-	var messageStr string
-	for _, msg := range req.Messages {
-		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
-	}
-	messageStr = strings.TrimSuffix(messageStr, ",")
-	params = append(params, "messages=["+messageStr+"]")
-
-	sort.Sort(sort.StringSlice(params))
-	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
-	mac := hmac.New(sha1.New, []byte(secretKey))
-	signURL := url
-	mac.Write([]byte(signURL))
-	sign := mac.Sum([]byte(nil))
-	return base64.StdEncoding.EncodeToString(sign)
+func sha256hex(s string) string {
+	b := sha256.Sum256([]byte(s))
+	return hex.EncodeToString(b[:])
+}
+
+func hmacSha256(s, key string) string {
+	hashed := hmac.New(sha256.New, []byte(key))
+	hashed.Write([]byte(s))
+	return string(hashed.Sum(nil))
+}
+
+func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
+	// build canonical request string
+	host := "hunyuan.tencentcloudapi.com"
+	httpRequestMethod := "POST"
+	canonicalURI := "/"
+	canonicalQueryString := ""
+	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
+		"application/json", host, strings.ToLower(adaptor.Action))
+	signedHeaders := "content-type;host;x-tc-action"
+	payload, _ := json.Marshal(req)
+	hashedRequestPayload := sha256hex(string(payload))
+	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+		httpRequestMethod,
+		canonicalURI,
+		canonicalQueryString,
+		canonicalHeaders,
+		signedHeaders,
+		hashedRequestPayload)
+
+	// build string to sign
+	algorithm := "TC3-HMAC-SHA256"
+	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
+	timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
+	t := time.Unix(timestamp, 0).UTC()
+	// must be the format 2006-01-02, ref to package time for more info
+	date := t.Format("2006-01-02")
+	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
+	hashedCanonicalRequest := sha256hex(canonicalRequest)
+	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
+		algorithm,
+		requestTimestamp,
+		credentialScope,
+		hashedCanonicalRequest)
+
+	// sign string
+	secretDate := hmacSha256(date, "TC3"+secKey)
+	secretService := hmacSha256("hunyuan", secretDate)
+	secretKey := hmacSha256("tc3_request", secretService)
+	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
+
+	// build authorization
+	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+		algorithm,
+		secId,
+		credentialScope,
+		signedHeaders,
+		signature)
+	return authorization
 }
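The new `getTencentSign` implements Tencent Cloud's TC3-HMAC-SHA256 scheme: hash the canonical request, build a string to sign, then derive the signing key through the chain `"TC3"+secret → date → service → "tc3_request"`. A reduced, runnable sketch of the hash helpers and that derivation chain (`deriveSignature` and all inputs below are made-up for illustration, not real credentials):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// sha256hex returns the lowercase hex digest of s.
func sha256hex(s string) string {
	b := sha256.Sum256([]byte(s))
	return hex.EncodeToString(b[:])
}

// hmacSha256 returns the raw HMAC-SHA256 bytes of s under key, as a string.
func hmacSha256(s, key string) string {
	h := hmac.New(sha256.New, []byte(key))
	h.Write([]byte(s))
	return string(h.Sum(nil))
}

// deriveSignature runs the TC3 key-derivation chain for the "hunyuan" service
// and signs stringToSign, returning a 64-char hex signature.
func deriveSignature(stringToSign, date, secretKey string) string {
	secretDate := hmacSha256(date, "TC3"+secretKey)
	secretService := hmacSha256("hunyuan", secretDate)
	secretSigning := hmacSha256("tc3_request", secretService)
	return hex.EncodeToString([]byte(hmacSha256(stringToSign, secretSigning)))
}

func main() {
	stringToSign := "TC3-HMAC-SHA256\n1700000000\n2023-11-14/hunyuan/tc3_request\n" + sha256hex("{}")
	sig := deriveSignature(stringToSign, "2023-11-14", "dummy-secret")
	fmt.Println(len(sig)) // 64: the signature is always 32 bytes, hex-encoded
}
```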
@@ -16,6 +16,11 @@ type Adaptor struct {
 	request *dto.GeneralOpenAIRequest
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return request, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	// xunfei's request is not http request, so we don't need to do anything here
 	dummyResp := &http.Response{}
@@ -14,6 +14,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Zhipu(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Zhipu(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -48,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
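The `DoResponse` change above now prefers the usage reported by the upstream stream and only falls back to re-counting tokens from the accumulated text when that usage is missing or zero. A hedged sketch of that decision (the `usage` type, `finalUsage` helper, and the fixed surcharge of 7 tokens per tool call are stand-ins for the project's dto.Usage and service.ResponseText2Usage, which this sketch does not reproduce):

```go
package main

import "fmt"

// usage mirrors the token-count triple carried by the relay.
type usage struct {
	PromptTokens, CompletionTokens, TotalTokens int
}

// finalUsage keeps a complete upstream-reported usage as-is; otherwise it
// rebuilds one from the known prompt size and an estimated completion count,
// adding a rough per-tool-call surcharge.
func finalUsage(upstream *usage, promptTokens, estimatedCompletion, toolCount int) *usage {
	if upstream == nil || upstream.TotalTokens == 0 || upstream.PromptTokens+upstream.CompletionTokens == 0 {
		u := &usage{PromptTokens: promptTokens, CompletionTokens: estimatedCompletion}
		u.CompletionTokens += toolCount * 7 // rough surcharge per tool call
		u.TotalTokens = u.PromptTokens + u.CompletionTokens
		return u
	}
	return upstream
}

func main() {
	fmt.Println(finalUsage(nil, 10, 5, 1).CompletionTokens) // fallback path: 5 + 7 = 12
	reported := &usage{PromptTokens: 3, CompletionTokens: 4, TotalTokens: 7}
	fmt.Println(finalUsage(reported, 10, 5, 1).TotalTokens) // upstream kept: 7
}
```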
@ -9,24 +9,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RelayInfo struct {
|
type RelayInfo struct {
|
||||||
ChannelType int
|
ChannelType int
|
||||||
ChannelId int
|
ChannelId int
|
||||||
TokenId int
|
TokenId int
|
||||||
UserId int
|
UserId int
|
||||||
Group string
|
Group string
|
||||||
TokenUnlimited bool
|
TokenUnlimited bool
|
||||||
StartTime time.Time
|
StartTime time.Time
|
||||||
FirstResponseTime time.Time
|
FirstResponseTime time.Time
|
||||||
ApiType int
|
ApiType int
|
||||||
IsStream bool
|
IsStream bool
|
||||||
RelayMode int
|
RelayMode int
|
||||||
UpstreamModelName string
|
UpstreamModelName string
|
||||||
RequestURLPath string
|
RequestURLPath string
|
||||||
ApiVersion string
|
ApiVersion string
|
||||||
PromptTokens int
|
PromptTokens int
|
||||||
ApiKey string
|
ApiKey string
|
||||||
Organization string
|
Organization string
|
||||||
BaseUrl string
|
BaseUrl string
|
||||||
|
SupportStreamOptions bool
|
||||||
|
ShouldIncludeUsage bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
@ -65,6 +67,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == common.ChannelTypeAzure {
|
||||||
info.ApiVersion = GetAPIVersion(c)
|
info.ApiVersion = GetAPIVersion(c)
|
||||||
}
|
}
|
||||||
|
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||||
|
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
|
||||||
|
info.SupportStreamOptions = true
|
||||||
|
}
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
|
|||||||
var textResponse dto.TextResponseWithError
|
var textResponse dto.TextResponseWithError
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
OpenAIErrorWithStatusCode.Error = textResponse.Error
|
OpenAIErrorWithStatusCode.Error = textResponse.Error
|
||||||
|
@ -20,6 +20,8 @@ const (
|
|||||||
APITypePerplexity
|
APITypePerplexity
|
||||||
APITypeAws
|
APITypeAws
|
||||||
APITypeCohere
|
APITypeCohere
|
||||||
|
APITypeDify
|
||||||
|
APITypeJina
|
||||||
|
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
@ -57,6 +59,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = APITypeAws
|
apiType = APITypeAws
|
||||||
case common.ChannelTypeCohere:
|
case common.ChannelTypeCohere:
|
||||||
apiType = APITypeCohere
|
apiType = APITypeCohere
|
||||||
|
case common.ChannelTypeDify:
|
||||||
|
apiType = APITypeDify
|
||||||
|
case common.ChannelTypeJina:
|
||||||
|
apiType = APITypeJina
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return APITypeOpenAI, false
|
return APITypeOpenAI, false
|
||||||
|
@ -32,6 +32,7 @@ const (
|
|||||||
RelayModeSunoFetch
|
RelayModeSunoFetch
|
||||||
RelayModeSunoFetchByID
|
RelayModeSunoFetchByID
|
||||||
RelayModeSunoSubmit
|
RelayModeSunoSubmit
|
||||||
|
RelayModeRerank
|
||||||
)
|
)
|
||||||
|
|
||||||
func Path2RelayMode(path string) int {
|
func Path2RelayMode(path string) int {
|
||||||
@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
|
|||||||
relayMode = RelayModeAudioTranscription
|
relayMode = RelayModeAudioTranscription
|
||||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||||
relayMode = RelayModeAudioTranslation
|
relayMode = RelayModeAudioTranslation
|
||||||
|
} else if strings.HasPrefix(path, "/v1/rerank") {
|
||||||
|
relayMode = RelayModeRerank
|
||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
@@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		originTask := model.GetByMJId(userId, mjId)
 		if originTask == nil {
 			return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
-		} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
-			return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
 		} else { // when the original task's Status is SUCCESS, actions such as UPSCALE and VARIATION are allowed; the original request address must be used so they are handled correctly
+			if constant.MjActionCheckSuccessEnabled {
+				if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+					return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
+				}
+			}
 			channel, err := model.GetChannelById(originTask.ChannelId, true)
 			if err != nil {
 				return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
@@ -500,7 +503,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 	}
 	if quota != 0 {
 		tokenName := c.GetString("token_name")
-		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
 		other := make(map[string]interface{})
 		other["model_price"] = modelPrice
 		other["group_ratio"] = groupRatio
@@ -544,7 +547,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		if err != nil {
 			common.SysError("get_channel_null: " + err.Error())
 		}
-		if channel.AutoBan != nil && *channel.AutoBan == 1 {
+		if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
			model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
 		}
 	}
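The hunk above turns the former hard "task must be SUCCESS" check into one gated behind the new `MjActionCheckSuccessEnabled` switch. A minimal sketch of that gate as a pure function (the `task` type and `actionAllowed` name are illustrative stand-ins, not code from the repo):

```go
package main

import "fmt"

// task is a stand-in for the fields consulted by the patched check.
type task struct{ Status string }

// actionAllowed mirrors the patched logic: when the check switch is on,
// non-modal actions on a task that has not reached SUCCESS are rejected;
// with the switch off, the old hard check is skipped entirely.
func actionAllowed(t task, isModal, checkSuccessEnabled bool) bool {
	if checkSuccessEnabled {
		if t.Status != "SUCCESS" && !isModal {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(actionAllowed(task{Status: "IN_PROGRESS"}, false, true))  // false
	fmt.Println(actionAllowed(task{Status: "IN_PROGRESS"}, false, false)) // true
}
```

Modal submissions (`RelayModeMidjourneyModal`) bypass the check in both versions, so only the non-modal path changes behavior.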
@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {

 	// map model name
 	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
+	//isModelMapped := false
 	if modelMapping != "" && modelMapping != "{}" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		if modelMap[textRequest.Model] != "" {
 			textRequest.Model = modelMap[textRequest.Model]
 			// set upstream model name
-			isModelMapped = true
+			//isModelMapped = true
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
@@ -130,33 +130,38 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return openaiErr
 	}

+	// If StreamOptions is not supported, or the request is not streaming, set StreamOptions to nil
+	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
+		textRequest.StreamOptions = nil
+	} else {
+		// If StreamOptions is supported but not set in the request, apply the configured default
+		if constant.ForceStreamOption {
+			textRequest.StreamOptions = &dto.StreamOptions{
+				IncludeUsage: true,
+			}
+		}
+	}
+
+	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
+		relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
+	}
+
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
-	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
-		if isModelMapped {
-			jsonStr, err := json.Marshal(textRequest)
-			if err != nil {
-				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
-			}
-			requestBody = bytes.NewBuffer(jsonStr)
-		} else {
-			requestBody = c.Request.Body
-		}
-	} else {
-		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
-		}
-		jsonData, err := json.Marshal(convertedRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonData)
-	}
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody = bytes.NewBuffer(jsonData)

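The StreamOptions handling added above can be isolated into a small normalization step. A sketch under stated assumptions (the `textRequest` and `StreamOptions` types below are local stand-ins for the `dto` types; the two booleans stand for `relayInfo.SupportStreamOptions` and `constant.ForceStreamOption`):

```go
package main

import "fmt"

// StreamOptions mirrors dto.StreamOptions as used in the diff.
type StreamOptions struct {
	IncludeUsage bool
}

// textRequest carries only the fields the normalization consults.
type textRequest struct {
	Stream        bool
	StreamOptions *StreamOptions
}

// normalizeStreamOptions applies the rules added in TextHelper: drop
// StreamOptions when the channel does not support it or the request is
// not streaming; otherwise optionally force IncludeUsage via config.
func normalizeStreamOptions(req *textRequest, supportStreamOptions, forceStreamOption bool) {
	if !supportStreamOptions || !req.Stream {
		req.StreamOptions = nil
		return
	}
	if forceStreamOption {
		req.StreamOptions = &StreamOptions{IncludeUsage: true}
	}
}

func main() {
	r := &textRequest{Stream: true}
	normalizeStreamOptions(r, true, true)
	fmt.Println(r.StreamOptions != nil && r.StreamOptions.IncludeUsage) // true

	r2 := &textRequest{Stream: true, StreamOptions: &StreamOptions{}}
	normalizeStreamOptions(r2, false, true)
	fmt.Println(r2.StreamOptions == nil) // true
}
```

After normalization, `relayInfo.ShouldIncludeUsage` can simply be read off `StreamOptions.IncludeUsage`, which is exactly what the patched TextHelper does.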
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
@@ -182,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 	return nil
 }

@@ -281,7 +286,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 	}
 }

-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 	modelPrice float64, usePrice bool) {

@@ -290,7 +295,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	completionTokens := usage.CompletionTokens

 	tokenName := ctx.GetString("token_name")
-	completionRatio := common.GetCompletionRatio(textRequest.Model)
+	completionRatio := common.GetCompletionRatio(modelName)

 	quota := 0
 	if !usePrice {
@@ -316,7 +321,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
+		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
 		//if sensitiveResp != nil {
 		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
@@ -336,16 +342,17 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}

-	logModel := textRequest.Model
+	logModel := modelName
 	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 		logModel = "gpt-4-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
+		logContent += fmt.Sprintf(",模型 %s", modelName)
 	} else if strings.HasPrefix(logModel, "g-") {
 		logModel = "g-*"
-		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
+		logContent += fmt.Sprintf(",模型 %s", modelName)
 	}
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
+		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)

 	//if quota != 0 {
 	//
@@ -8,7 +8,9 @@ import (
 	"one-api/relay/channel/baidu"
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/cohere"
+	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
+	"one-api/relay/channel/jina"
 	"one-api/relay/channel/ollama"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
@@ -53,6 +55,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &aws.Adaptor{}
 	case constant.APITypeCohere:
 		return &cohere.Adaptor{}
+	case constant.APITypeDify:
+		return &dify.Adaptor{}
+	case constant.APITypeJina:
+		return &jina.Adaptor{}
 	}
 	return nil
 }
104	relay/relay_rerank.go	Normal file
@@ -0,0 +1,104 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
+	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	for _, document := range rerankRequest.Documents {
+		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
+		if err == nil {
+			token += tkm
+		}
+	}
+	return token
+}
+
+func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	var rerankRequest *dto.RerankRequest
+	err := common.UnmarshalBodyReusable(c, &rerankRequest)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+	}
+	if rerankRequest.Query == "" {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
+	}
+	if len(rerankRequest.Documents) == 0 {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
+	}
+	relayInfo.UpstreamModelName = rerankRequest.Model
+	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+
+	promptToken := getRerankPromptToken(*rerankRequest)
+	if !success {
+		preConsumedTokens := promptToken
+		modelRatio = common.GetModelRatio(rerankRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	relayInfo.PromptTokens = promptToken
+
+	// pre-consume quota
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.InitRerank(relayInfo, *rerankRequest)
+
+	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody := bytes.NewBuffer(jsonData)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+	if resp != nil {
+		if resp.StatusCode != http.StatusOK {
+			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		// reset status code
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	return nil
+}
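In the new file above, `getRerankPromptToken` charges the query plus every candidate document as prompt tokens before pre-consuming quota. A self-contained sketch of that accumulation (`countTokens` is a stand-in for `service.CountTokenInput`; it approximates by word count rather than running a real tokenizer):

```go
package main

import "fmt"

// countTokens is a stand-in for service.CountTokenInput; real counting
// is tokenizer-based, this just counts whitespace-separated words.
func countTokens(text string) int {
	n := 0
	inWord := false
	for _, r := range text {
		if r == ' ' || r == '\n' || r == '\t' {
			inWord = false
		} else if !inWord {
			inWord = true
			n++
		}
	}
	return n
}

// rerankPromptTokens mirrors getRerankPromptToken: the prompt cost is
// the query plus every candidate document.
func rerankPromptTokens(query string, documents []string) int {
	token := countTokens(query)
	for _, d := range documents {
		token += countTokens(d)
	}
	return token
}

func main() {
	docs := []string{"first candidate passage", "second passage"}
	fmt.Println(rerankPromptTokens("what is rerank", docs)) // 8
}
```

Because every document counts toward the prompt, pre-consumed quota grows linearly with the candidate list, which is why the helper validates that `Documents` is non-empty before estimating.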
@@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 		relayV1Router.POST("/moderations", controller.Relay)
+		relayV1Router.POST("/rerank", controller.Relay)
 	}

 	relayMjRouter := router.Group("/mj")
@@ -1,6 +1,12 @@
 package service

-import "github.com/gin-gonic/gin"
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"strings"
+)

 func SetEventStreamHeaders(c *gin.Context) {
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
@@ -9,3 +15,23 @@ func SetEventStreamHeaders(c *gin.Context) {
 	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 	c.Writer.Header().Set("X-Accel-Buffering", "no")
 }
+
+func StringData(c *gin.Context, str string) {
+	str = strings.TrimPrefix(str, "data: ")
+	str = strings.TrimSuffix(str, "\r")
+	c.Render(-1, common.CustomEvent{Data: "data: " + str})
+	c.Writer.Flush()
+}
+
+func ObjectData(c *gin.Context, object interface{}) error {
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	StringData(c, string(jsonData))
+	return nil
+}
+
+func Done(c *gin.Context) {
+	StringData(c, "[DONE]")
+}
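The new `StringData` helper normalizes each payload before emitting it as a server-sent event: it strips any pre-existing `data: ` prefix and a trailing `\r` so the line is re-emitted exactly once in SSE form. The string handling can be isolated from the gin rendering like this (the `normalizeEventLine` name is illustrative; the gin/flush side is omitted):

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeEventLine mirrors what StringData does to the payload before
// rendering: strip an existing "data: " prefix and a trailing "\r", then
// re-attach a single well-formed SSE "data: " prefix.
func normalizeEventLine(str string) string {
	str = strings.TrimPrefix(str, "data: ")
	str = strings.TrimSuffix(str, "\r")
	return "data: " + str
}

func main() {
	fmt.Println(normalizeEventLine("data: {\"id\":1}\r")) // data: {"id":1}
	fmt.Println(normalizeEventLine("[DONE]"))             // data: [DONE]
}
```

This makes the helper safe to call with either raw JSON (as `ObjectData` does after marshalling) or with lines copied verbatim from an upstream SSE stream that already carry the prefix.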
@@ -24,3 +24,15 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err
 }
+
+func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
+	return &dto.ChatCompletionsStreamResponse{
+		Id:                id,
+		Object:            "chat.completion.chunk",
+		Created:           createAt,
+		Model:             model,
+		SystemFingerprint: nil,
+		Choices:           make([]dto.ChatCompletionsStreamResponseChoice, 0),
+		Usage:             &usage,
+	}
+}
@@ -554,7 +554,7 @@ const ChannelsTable = () => {
       );
       const { success, message, data } = res.data;
       if (success) {
-        setChannels(data);
+        setChannelFormat(data);
         setActivePage(1);
       } else {
         showError(message);
@@ -42,6 +42,7 @@ const OperationSetting = () => {
     MjAccountFilterEnabled: false,
     MjModeClearEnabled: false,
     MjForwardUrlEnabled: false,
+    MjActionCheckSuccessEnabled: false,
     DrawingEnabled: false,
     DataExportEnabled: false,
     DataExportDefaultTime: 'hour',
@@ -104,6 +104,8 @@ export const CHANNEL_OPTIONS = [
   { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
   { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
   { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
+  { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
+  { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
   {
     key: 22,
@@ -160,8 +160,8 @@ export function renderModelPrice(
   let inputRatioPrice = modelRatio * 2.0;
   let completionRatioPrice = modelRatio * 2.0 * completionRatio;
   let price =
-    (inputTokens / 1000000) * inputRatioPrice +
-    (completionTokens / 1000000) * completionRatioPrice;
+    (inputTokens / 1000000) * inputRatioPrice * groupRatio +
+    (completionTokens / 1000000) * completionRatioPrice * groupRatio;
   return (
     <>
       <article>
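The fix above applies `groupRatio` to both the input and the completion components of the displayed price; before, neither term was scaled by the group. The corrected arithmetic, restated as a small Go sketch (the function name and the sample ratios are illustrative, not real pricing data):

```go
package main

import "fmt"

// modelPrice mirrors the corrected renderModelPrice arithmetic: the
// group ratio now multiplies both the input and completion components.
func modelPrice(inputTokens, completionTokens int, modelRatio, completionRatio, groupRatio float64) float64 {
	inputRatioPrice := modelRatio * 2.0
	completionRatioPrice := modelRatio * 2.0 * completionRatio
	return (float64(inputTokens)/1000000)*inputRatioPrice*groupRatio +
		(float64(completionTokens)/1000000)*completionRatioPrice*groupRatio
}

func main() {
	// 1M input + 1M completion tokens, modelRatio 1, completionRatio 3,
	// groupRatio 0.5: (1*2*0.5) + (1*2*3*0.5) = 1 + 3 = 4
	fmt.Println(modelPrice(1000000, 1000000, 1, 3, 0.5)) // 4
}
```

With a groupRatio of 1 the result is unchanged from the old formula, so only discounted or surcharged groups see a different displayed price.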
@@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
     MjAccountFilterEnabled: false,
     MjForwardUrlEnabled: false,
     MjModeClearEnabled: false,
+    MjActionCheckSuccessEnabled: false,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -156,6 +157,21 @@ export default function SettingsDrawing(props) {
           }
         />
       </Col>
+      <Col span={8}>
+        <Form.Switch
+          field={'MjActionCheckSuccessEnabled'}
+          label={<>检测必须等待绘图成功才能进行放大等操作</>}
+          size='large'
+          checkedText='|'
+          uncheckedText='〇'
+          onChange={(value) =>
+            setInputs({
+              ...inputs,
+              MjActionCheckSuccessEnabled: value,
+            })
+          }
+        />
+      </Col>
     </Row>
     <Row>
       <Button size='large' onClick={onSubmit}>
|
Loading…
Reference in New Issue
Block a user