Mirror of https://github.com/linux-do/new-api.git (synced 2025-11-18 11:33:42 +08:00)

Compare commits: v0.2.8-alp ... v0.2.7.2-a (49 commits)
Commits:
220ab412e2, 7029065892, 0f687aab9a, 5e936b3923, d55cb35c1c, 5be4cbcaaf, e67aa370bc, 7b36a2b885,
c88f3741e6, 4e7e206290, 579fc8129e, f55f63f412, 0526c85732, b75134ece4, a075598757, a984daa503,
90abe7f27d, bb313eb26f, 02545e4856, 49cec50908, 4f6710e50c, 03b130f2b5, 45b9de9df9, e062cf32e3,
52debe7572, df6502733c, 9896ba0a64, e8b93ed6ec, b0e234e8f5, 20d71711d3, 4246c4cdc1, 1e536ee7d9,
8a730cfe12, 3ed4f2f0a9, bec18ed82d, bd9bf4b732, 1735e093db, 8af4e28f75, afe02c6aa5, e0ed59bfe3,
bd7222118a, cf3d894195, 7011083201, 752048dfb4, eb382d28ab, a9e1078bca, 6c5b3b51b0, d306aea9e5,
d4578e28b3
.github/workflows/docker-image-arm64.yml (vendored, 1 change)
```diff
@@ -4,6 +4,7 @@ on:
   push:
     tags:
       - '*'
+      - '!*-alpha*'
   workflow_dispatch:
     inputs:
       name:
```
Midjourney.md

```diff
@@ -2,6 +2,21 @@
 
 **Overview**: Midjourney Proxy API documentation
 
+## Endpoint list
+The supported endpoints are:
++ [x] /mj/submit/imagine
++ [x] /mj/submit/change
++ [x] /mj/submit/blend
++ [x] /mj/submit/describe
++ [x] /mj/image/{id} (fetch images through this endpoint; **the server address must be set in the system settings!**)
++ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is forwarded through One API)
++ [x] /task/list-by-condition
++ [x] /mj/submit/action (only supported by midjourney-proxy-plus, same below)
++ [x] /mj/submit/modal
++ [x] /mj/submit/shorten
++ [x] /mj/task/{id}/image-seed
++ [x] /mj/insight-face/swap (InsightFace)
+
 ## Model list
 
 ### Supported by midjourney-proxy
```
README.md (28 changes)
```diff
@@ -16,19 +16,7 @@
 The main changes in this fork are:
 
 1. A brand-new UI (some pages are still being updated)
-2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md); the supported endpoints are:
-+ [x] /mj/submit/imagine
-+ [x] /mj/submit/change
-+ [x] /mj/submit/blend
-+ [x] /mj/submit/describe
-+ [x] /mj/image/{id} (fetch images through this endpoint; **the server address must be set in the system settings!**)
-+ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is forwarded through One API)
-+ [x] /task/list-by-condition
-+ [x] /mj/submit/action (only supported by midjourney-proxy-plus, same below)
-+ [x] /mj/submit/modal
-+ [x] /mj/submit/shorten
-+ [x] /mj/task/{id}/image-seed
-+ [x] /mj/insight-face/swap (InsightFace)
+2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md)
 3. Online top-up support, configurable in the system settings; currently supported payment gateways:
 + [x] 易支付
 4. Query used quota by key:
@@ -45,22 +33,21 @@
 2. Send the /setdomain command to [@Botfather](https://t.me/botfather)
 3. Select your bot, then enter http(s)://your-site-address/login
 4. The Telegram Bot name is the bot username with the @ removed
-13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md); the supported endpoints are:
-+ [x] /suno/submit/music
-+ [x] /suno/submit/lyrics
-+ [x] /suno/fetch
-+ [x] /suno/fetch/:id
+13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md)
+14. Rerank model support, currently compatible with Cohere and Jina only; can be connected to Dify, [integration docs](Rerank.md)
 
 ## Model support
 This version additionally supports the following models:
 1. Third-party model **gps** (gpt-4-gizmo-*)
 2. Zhipu glm-4v, glm-4v image understanding
-3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
+3. Anthropic Claude 3
 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file); the key can be anything when adding the channel, the default request address is [http://localhost:11434](http://localhost:11434), change it in the channel if needed
 5. The [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md)
 6. [零一万物](https://platform.lingyiwanwu.com/)
 7. Custom channels, with support for entering a full call address
 8. The [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md)
+9. Rerank models, currently [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [integration docs](Rerank.md)
+10. Dify
 
 You can add the custom model gpt-4-gizmo-* in a channel; it is not an official OpenAI model but a third-party one, and it cannot be called with an official key.
 
@@ -85,7 +72,8 @@
 
 ## Extra configuration beyond the original One API
 - `STREAMING_TIMEOUT`: timeout for a single streamed reply, 30 seconds by default
+- `DIFY_DEBUG`: whether the Dify channel sends workflow and node information to the client, defaults to `true`, allowed values are `true` and `false`
+- `FORCE_STREAM_OPTION`: override the client's stream_options parameter and request stream-mode usage from the upstream; currently only the `OpenAI` channel type is supported
 ## Deployment
 ### Deployment requirements
 - Local database (default): SQLite (Docker deployments use SQLite by default; the `/data` directory must be mounted to the host)
```
Rerank.md (new file, 62 lines)
```diff
@@ -0,0 +1,62 @@
+# Rerank API documentation
+
+**Overview**: Rerank API documentation
+
+## Integrating with Dify
+Select Jina as the model provider and fill in the model information as required to connect to Dify.
+
+## Request
+
+POST /v1/rerank
+
+Request:
+
+```json
+{
+  "model": "rerank-multilingual-v3.0",
+  "query": "What is the capital of the United States?",
+  "top_n": 3,
+  "documents": [
+    "Carson City is the capital city of the American state of Nevada.",
+    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
+    "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
+  ]
+}
+```
+
+Response:
+
+```json
+{
+  "results": [
+    {
+      "document": {
+        "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
+      },
+      "index": 2,
+      "relevance_score": 0.9999702
+    },
+    {
+      "document": {
+        "text": "Carson City is the capital city of the American state of Nevada."
+      },
+      "index": 0,
+      "relevance_score": 0.67800725
+    },
+    {
+      "document": {
+        "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
+      },
+      "index": 3,
+      "relevance_score": 0.02800752
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 158,
+    "completion_tokens": 0,
+    "total_tokens": 158
+  }
+}
+```
```
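For reference, a minimal Go client for the endpoint documented above (not part of the diff; the base URL and token are placeholders for your own deployment):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	base := "http://localhost:3000" // placeholder: your New API address
	token := "sk-xxx"               // placeholder: one of your keys

	body, _ := json.Marshal(map[string]any{
		"model": "rerank-multilingual-v3.0",
		"query": "What is the capital of the United States?",
		"top_n": 3,
		"documents": []string{
			"Carson City is the capital city of the American state of Nevada.",
			"Washington, D.C. is the capital of the United States.",
		},
	})
	req, _ := http.NewRequest("POST", base+"/v1/rerank", bytes.NewReader(body))
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode only the fields we care about from the documented response shape.
	var out struct {
		Results []struct {
			Index          int     `json:"index"`
			RelevanceScore float64 `json:"relevance_score"`
		} `json:"results"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	for _, r := range out.Results {
		fmt.Printf("doc %d: %.4f\n", r.Index, r.RelevanceScore)
	}
}
```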
Suno.md (7 changes)
```diff
@@ -2,6 +2,13 @@
 
 **Overview**: Suno API documentation
 
+## Endpoint list
+The supported endpoints are:
++ [x] /suno/submit/music
++ [x] /suno/submit/lyrics
++ [x] /suno/fetch
++ [x] /suno/fetch/:id
+
 ## Model list
 
 ### Supported by the Suno API
```
```diff
@@ -210,36 +210,39 @@ const (
 	ChannelTypeCohere  = 34
 	ChannelTypeMiniMax = 35
 	ChannelTypeSunoAPI = 36
+	ChannelTypeDify    = 37
+	ChannelTypeJina    = 38
+	ChannelCloudflare  = 39
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
 
 var ChannelBaseURLs = []string{
 	"",                               // 0
 	"https://api.openai.com",         // 1
 	"https://oa.api2d.net",           // 2
 	"",                               // 3
 	"http://localhost:11434",         // 4
 	"https://api.openai-sb.com",      // 5
 	"https://api.openaimax.com",      // 6
 	"https://api.ohmygpt.com",        // 7
 	"",                               // 8
 	"https://api.caipacity.com",      // 9
 	"https://api.aiproxy.io",         // 10
 	"",                               // 11
 	"https://api.api2gpt.com",        // 12
 	"https://api.aigc2d.com",         // 13
 	"https://api.anthropic.com",      // 14
 	"https://aip.baidubce.com",       // 15
 	"https://open.bigmodel.cn",       // 16
 	"https://dashscope.aliyuncs.com", // 17
 	"",                               // 18
 	"https://ai.360.cn",              // 19
 	"https://openrouter.ai/api",      // 20
 	"https://api.aiproxy.io",         // 21
 	"https://fastgpt.run/api/openapi", // 22
-	"https://hunyuan.cloud.tencent.com",   //23
+	"https://hunyuan.tencentcloudapi.com", //23
 	"https://generativelanguage.googleapis.com", //24
 	"https://api.moonshot.cn",  //25
 	"https://open.bigmodel.cn", //26
@@ -253,4 +256,7 @@ var ChannelBaseURLs = []string{
 	"https://api.cohere.ai",    //34
 	"https://api.minimax.chat", //35
 	"",                         //36
+	"",                         //37
+	"https://api.jina.ai",      //38
+	"https://api.cloudflare.com", //39
 }
```
```diff
@@ -24,3 +24,15 @@ func GetEnvOrDefaultString(env string, defaultValue string) string {
 	}
 	return os.Getenv(env)
 }
+
+func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	b, err := strconv.ParseBool(os.Getenv(env))
+	if err != nil {
+		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
+		return defaultValue
+	}
+	return b
+}
```
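Worth noting: `strconv.ParseBool` accepts only the Go boolean literals (`1`, `t`, `T`, `TRUE`, `true`, `True`, `0`, `f`, `F`, `FALSE`, `false`, `False`), so a value like `yes` makes the new helper fall back to the default. A standalone check (not part of the diff):

```go
package main

import (
	"fmt"
	"strconv"
)

func main() {
	for _, v := range []string{"true", "1", "T", "yes"} {
		b, err := strconv.ParseBool(v)
		fmt.Println(v, b, err) // "yes" fails, so GetEnvOrDefaultBool would return the default
	}
}
```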
```diff
@@ -78,19 +78,23 @@ var defaultModelRatio = map[string]float64{
 	"claude-3-haiku-20240307":    0.125, // $0.25 / 1M tokens
 	"claude-3-sonnet-20240229":   1.5,   // $3 / 1M tokens
 	"claude-3-5-sonnet-20240620": 1.5,
 	"claude-3-opus-20240229":     7.5, // $15 / 1M tokens
-	"ERNIE-Bot":          0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
-	"ERNIE-Bot-turbo":    0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
-	"ERNIE-Bot-4":        8.572,  // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K
-	"ERNIE-4.0-8K":       8.572,  // ¥0.12 / 1k tokens
-	"ERNIE-3.5-8K":       0.8572, // ¥0.012 / 1k tokens
-	"ERNIE-Speed-8K":     0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Speed-128K":   0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Lite-8K":      0.2143, // ¥0.003 / 1k tokens
-	"ERNIE-Tiny-8K":      0.0715, // ¥0.001 / 1k tokens
-	"ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens
-	"Embedding-V1":       0.1429, // ¥0.002 / 1k tokens
+	"ERNIE-4.0-8K":       0.120 * RMB,
+	"ERNIE-3.5-8K":       0.012 * RMB,
+	"ERNIE-3.5-8K-0205":  0.024 * RMB,
+	"ERNIE-3.5-8K-1222":  0.012 * RMB,
+	"ERNIE-Bot-8K":       0.024 * RMB,
+	"ERNIE-3.5-4K-0205":  0.012 * RMB,
+	"ERNIE-Speed-8K":     0.004 * RMB,
+	"ERNIE-Speed-128K":   0.004 * RMB,
+	"ERNIE-Lite-8K-0922": 0.008 * RMB,
+	"ERNIE-Lite-8K-0308": 0.003 * RMB,
+	"ERNIE-Tiny-8K":      0.001 * RMB,
+	"BLOOMZ-7B":          0.004 * RMB,
+	"Embedding-V1":       0.002 * RMB,
+	"bge-large-zh":       0.002 * RMB,
+	"bge-large-en":       0.002 * RMB,
+	"tao-8k":             0.002 * RMB,
 	"PaLM-2":            1,
 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
```
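The rewritten entries express Baidu's list prices in ¥ per 1k tokens through an `RMB` conversion constant instead of pre-computed ratios. Assuming the constants from upstream one-api (`USD = 500`, i.e. ratio 1 equals $0.002 per 1k tokens, and `RMB = USD / 7`), the new form reproduces the old hardcoded numbers; a quick check:

```go
package main

import "fmt"

// Assumed from upstream one-api; ratio 1 == $0.002 per 1k tokens.
const (
	USD2RMB = 7.0
	USD     = 500.0 // $1 per 1k tokens expressed as a ratio
	RMB     = USD / USD2RMB
)

func main() {
	// ERNIE-3.5-8K at ¥0.012 / 1k tokens:
	fmt.Printf("%.4f\n", 0.012*RMB) // 0.8571, matching the old hardcoded 0.8572
}
```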
```diff
@@ -5,3 +5,7 @@ import (
 )
 
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
+var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
+
+// ForceStreamOption overrides the request parameter and forces usage to be returned
+var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
```
```diff
@@ -4,6 +4,7 @@ var MjNotifyEnabled = false
 var MjAccountFilterEnabled = false
 var MjModeClearEnabled = false
 var MjForwardUrlEnabled = true
+var MjActionCheckSuccessEnabled = true
 
 const (
 	MjErrorUnknown = 5
```
```diff
@@ -6,11 +6,13 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/middleware"
 	"one-api/model"
 	"one-api/relay"
 	relaycommon "one-api/relay/common"
```
```diff
@@ -23,7 +25,8 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
+	tik := time.Now()
 	if channel.Type == common.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
 	}
```
```diff
@@ -38,34 +41,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
 		Body:   nil,
 		Header: make(http.Header),
 	}
-	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
-	c.Request.Header.Set("Content-Type", "application/json")
-	c.Set("channel", channel.Type)
-	c.Set("base_url", channel.GetBaseURL())
-	switch channel.Type {
-	case common.ChannelTypeAzure:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeXunfei:
-		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
-	case common.ChannelTypeGemini:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeAli:
-		c.Set("plugin", channel.Other)
-	}
-
-	meta := relaycommon.GenRelayInfo(c)
-	apiType, _ := constant.ChannelType2APIType(channel.Type)
-	adaptor := relay.GetAdaptor(apiType)
-	if adaptor == nil {
-		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
-	}
 	if testModel == "" {
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
 		} else {
-			testModel = adaptor.GetModelList()[0]
+			if len(channel.GetModels()) > 0 {
+				testModel = channel.GetModels()[0]
+			} else {
+				testModel = "gpt-3.5-turbo"
+			}
 		}
 	} else {
 		modelMapping := *channel.ModelMapping
```
```diff
@@ -73,8 +58,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
-			return err, &openaiErr
+			return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[testModel] != "" {
 			testModel = modelMap[testModel]
```
```diff
@@ -82,6 +66,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
 		}
 	}
 
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+
+	middleware.SetupContextForSelectedChannel(c, channel, testModel)
+
+	meta := relaycommon.GenRelayInfo(c)
+	apiType, _ := constant.ChannelType2APIType(channel.Type)
+	adaptor := relay.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+	}
+
 	request := buildTestRequest()
 	request.Model = testModel
 	meta.UpstreamModelName = testModel
```
```diff
@@ -89,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
 
 	adaptor.Init(meta, *request)
 
-	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+	convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
 	if err != nil {
 		return err, nil
 	}
```
```diff
@@ -105,21 +103,39 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
 	}
 	if resp != nil && resp.StatusCode != http.StatusOK {
 		err := relaycommon.RelayErrorHandler(resp)
-		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
 	}
 	usage, respErr := adaptor.DoResponse(c, resp, meta)
 	if respErr != nil {
-		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+		return fmt.Errorf("%s", respErr.Error.Message), respErr
 	}
 	if usage == nil {
 		return errors.New("usage is nil"), nil
 	}
 	result := w.Result()
-	// print result.Body
 	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 		return err, nil
 	}
+	modelPrice, usePrice := common.GetModelPrice(testModel, false)
+	modelRatio := common.GetModelRatio(testModel)
+	completionRatio := common.GetCompletionRatio(testModel)
+	ratio := modelRatio
+	quota := 0
+	if !usePrice {
+		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
+		quota = int(math.Round(float64(quota) * ratio))
+		if ratio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(modelPrice * common.QuotaPerUnit)
+	}
+	tok := time.Now()
+	milliseconds := tok.Sub(tik).Milliseconds()
+	consumedTime := float64(milliseconds) / 1000.0
+	other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
+	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }
```
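For ratio-priced models the charged quota is prompt tokens plus completion tokens weighted by the completion ratio, then scaled by the model ratio and clamped to at least 1; a standalone sketch of the same arithmetic:

```go
package main

import (
	"fmt"
	"math"
)

// Mirrors the quota computation added above; names are illustrative.
func computeQuota(promptTokens, completionTokens int, modelRatio, completionRatio float64) int {
	quota := promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
	quota = int(math.Round(float64(quota) * modelRatio))
	if modelRatio != 0 && quota <= 0 {
		quota = 1 // always charge at least one unit for a non-free model
	}
	return quota
}

func main() {
	fmt.Println(computeQuota(10, 20, 0.25, 3.0)) // (10 + 20*3) * 0.25 = 17.5, rounded to 18
}
```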
```diff
@@ -140,7 +156,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
 }
 
 func TestChannel(c *gin.Context) {
-	id, err := strconv.Atoi(c.Param("id"))
+	channelId, err := strconv.Atoi(c.Param("id"))
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
```
```diff
@@ -148,7 +164,7 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	channel, err := model.GetChannelById(id, true)
+	channel, err := model.GetChannelById(channelId, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
```
```diff
@@ -205,7 +221,7 @@ func testAllChannels(notify bool) error {
 	for _, channel := range channels {
 		isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 		tik := time.Now()
-		err, openaiErr := testChannel(channel, "")
+		err, openaiWithStatusErr := testChannel(channel, "")
 		tok := time.Now()
 		milliseconds := tok.Sub(tik).Milliseconds()
 
```
```diff
@@ -214,27 +230,29 @@ func testAllChannels(notify bool) error {
 			err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 			ban = true
 		}
-		if openaiErr != nil {
-			err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
-			ban = true
+		// request error disables the channel
+		if openaiWithStatusErr != nil {
+			oaiErr := openaiWithStatusErr.Error
+			err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
+			ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 		}
 
 		// parse *int to bool
 		if channel.AutoBan != nil && *channel.AutoBan == 0 {
 			ban = false
 		}
-		if openaiErr != nil {
-			openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
-				StatusCode: -1,
-				Error:      *openaiErr,
-				LocalError: false,
-			}
-			if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
-				service.DisableChannel(channel.Id, channel.Name, err.Error())
-			}
-			if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
-				service.EnableChannel(channel.Id, channel.Name)
-			}
+		// disable channel
+		if ban && isChannelEnabled {
+			service.DisableChannel(channel.Id, channel.Name, err.Error())
 		}
+
+		// enable channel
+		if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
+			service.EnableChannel(channel.Id, channel.Name)
+		}
+
 		channel.UpdateResponseTime(milliseconds)
 		time.Sleep(common.RequestInterval)
 	}
```
```diff
@@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		fallthrough
 	case relayconstant.RelayModeAudioTranscription:
 		err = relay.AudioHelper(c, relayMode)
+	case relayconstant.RelayModeRerank:
+		err = relay.RerankHelper(c, relayMode)
 	default:
 		err = relay.TextHelper(c)
 	}
```
```diff
@@ -40,12 +42,13 @@ func Relay(c *gin.Context) {
 	retryTimes := common.RetryTimes
 	requestId := c.GetString(common.RequestIdKey)
 	channelId := c.GetInt("channel_id")
+	channelType := c.GetInt("channel_type")
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
 	openaiErr := relayHandler(c, relayMode)
 	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	if openaiErr != nil {
-		go processChannelError(c, channelId, openaiErr)
+		go processChannelError(c, channelId, channelType, openaiErr)
 	} else {
 		retryTimes = 0
 	}
```
```diff
@@ -66,7 +69,7 @@ func Relay(c *gin.Context) {
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		openaiErr = relayHandler(c, relayMode)
 		if openaiErr != nil {
-			go processChannelError(c, channelId, openaiErr)
+			go processChannelError(c, channelId, channel.Type, openaiErr)
 		}
 	}
 	useChannel := c.GetStringSlice("use_channel")
```
```diff
@@ -125,10 +128,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode) bool {
 	return true
 }
 
-func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
+func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
 	autoBan := c.GetBool("auto_ban")
 	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
-	if service.ShouldDisableChannel(err) && autoBan {
+	if service.ShouldDisableChannel(channelType, err) && autoBan {
 		channelName := c.GetString("channel_name")
 		service.DisableChannel(channelId, channelName, err.Error.Message)
 	}
```
dto/rerank.go (new file, 19 lines)
```diff
@@ -0,0 +1,19 @@
+package dto
+
+type RerankRequest struct {
+	Documents []any  `json:"documents"`
+	Query     string `json:"query"`
+	Model     string `json:"model"`
+	TopN      int    `json:"top_n"`
+}
+
+type RerankResponseDocument struct {
+	Document       any     `json:"document"`
+	Index          int     `json:"index"`
+	RelevanceScore float64 `json:"relevance_score"`
+}
+
+type RerankResponse struct {
+	Results []RerankResponseDocument `json:"results"`
+	Usage   Usage                    `json:"usage"`
+}
```
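Because `Documents` is typed `[]any`, the request accepts both plain strings and structured documents; a quick standalone check that the documented payload decodes into the new type (local copy of the struct, `Usage` omitted):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the new DTO for the sketch.
type RerankRequest struct {
	Documents []any  `json:"documents"`
	Query     string `json:"query"`
	Model     string `json:"model"`
	TopN      int    `json:"top_n"`
}

func main() {
	payload := `{"model":"rerank-multilingual-v3.0","query":"capital?","top_n":3,
		"documents":["Carson City ...",{"text":"Washington, D.C. ..."}]}`
	var req RerankRequest
	if err := json.Unmarshal([]byte(payload), &req); err != nil {
		panic(err)
	}
	fmt.Println(req.Model, req.TopN, len(req.Documents)) // rerank-multilingual-v3.0 3 2
}
```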
```diff
@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
 	Messages      []Message      `json:"messages,omitempty"`
 	Prompt        any            `json:"prompt,omitempty"`
 	Stream        bool           `json:"stream,omitempty"`
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 	MaxTokens     uint           `json:"max_tokens,omitempty"`
 	Temperature   float64        `json:"temperature,omitempty"`
 	TopP          float64        `json:"top_p,omitempty"`
```
```diff
@@ -43,8 +44,12 @@ type OpenAIFunction struct {
 	Parameters any `json:"parameters,omitempty"`
 }
 
-func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
-	return int64(r.MaxTokens)
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) GetMaxTokens() int {
+	return int(r.MaxTokens)
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
```
```diff
@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
-func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
-	return c.Content == nil && len(c.ToolCalls) == 0
-}
-
 func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
 	c.Content = &s
 }
```
```diff
@@ -102,10 +98,23 @@ type ChatCompletionsStreamResponse struct {
 	Model             string                                `json:"model"`
 	SystemFingerprint *string                               `json:"system_fingerprint"`
 	Choices           []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage             *Usage                                `json:"usage"`
+}
+
+func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
+	if c.SystemFingerprint == nil {
+		return ""
+	}
+	return *c.SystemFingerprint
+}
+
+func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
+	c.SystemFingerprint = &s
 }
 
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *Usage                                `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {
```
```diff
@@ -178,6 +178,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
 	c.Set("channel", channel.Type)
 	c.Set("channel_id", channel.Id)
 	c.Set("channel_name", channel.Name)
+	c.Set("channel_type", channel.Type)
 	ban := true
 	// parse *int to bool
 	if channel.AutoBan != nil && *channel.AutoBan == 0 {
```
```diff
@@ -197,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeXunfei:
 		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
 	case common.ChannelTypeGemini:
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeAli:
 		c.Set("plugin", channel.Other)
+	case common.ChannelCloudflare:
+		c.Set("api_version", channel.Other)
 	}
 }
```
```diff
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"gorm.io/gorm"
 	"one-api/common"
+	"strings"
 )
 
 type Channel struct {
@@ -33,6 +34,13 @@ type Channel struct {
 	OtherInfo string `json:"other_info"`
 }
 
+func (channel *Channel) GetModels() []string {
+	if channel.Models == "" {
+		return []string{}
+	}
+	return strings.Split(strings.Trim(channel.Models, ","), ",")
+}
+
 func (channel *Channel) GetOtherInfo() map[string]interface{} {
 	otherInfo := make(map[string]interface{})
 	if channel.OtherInfo != "" {
```
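`GetModels` trims stray commas before splitting, so a stored list such as `,gpt-3.5-turbo,gpt-4,` still parses cleanly, and an empty column yields an empty slice rather than `[""]`; a standalone check:

```go
package main

import (
	"fmt"
	"strings"
)

// Same logic as the new Channel.GetModels, without the model struct.
func getModels(models string) []string {
	if models == "" {
		return []string{}
	}
	return strings.Split(strings.Trim(models, ","), ",")
}

func main() {
	fmt.Println(getModels(",gpt-3.5-turbo,gpt-4,")) // [gpt-3.5-turbo gpt-4]
	fmt.Println(len(getModels("")))                 // 0
}
```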
```diff
@@ -99,6 +99,7 @@ func InitOptionMap() {
 	common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
 	common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
 	common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
+	common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
 	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
 	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
 	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -210,6 +211,8 @@ func updateOptionMap(key string, value string) (err error) {
 		constant.MjModeClearEnabled = boolValue
 	case "MjForwardUrlEnabled":
 		constant.MjForwardUrlEnabled = boolValue
+	case "MjActionCheckSuccessEnabled":
+		constant.MjActionCheckSuccessEnabled = boolValue
 	case "CheckSensitiveEnabled":
 		constant.CheckSensitiveEnabled = boolValue
 	case "CheckSensitiveOnPromptEnabled":
```
```diff
@@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
 	if err != nil {
 		return 0, errors.New("兑换失败," + err.Error())
 	}
-	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
+	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
 	return redemption.Quota, nil
 }
 
```
```diff
@@ -250,11 +250,9 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
 	if userQuota < quota {
 		return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
 	}
-	if !token.UnlimitedQuota {
-		err = DecreaseTokenQuota(tokenId, quota)
-		if err != nil {
-			return 0, err
-		}
+	err = DecreaseTokenQuota(tokenId, quota)
+	if err != nil {
+		return 0, err
 	}
 	err = DecreaseUserQuota(token.UserId, quota)
 	return userQuota - quota, err
```
```diff
@@ -272,15 +270,13 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
 		return err
 	}
 
-	if !token.UnlimitedQuota {
-		if quota > 0 {
-			err = DecreaseTokenQuota(tokenId, quota)
-		} else {
-			err = IncreaseTokenQuota(tokenId, -quota)
-		}
-		if err != nil {
-			return err
-		}
+	if quota > 0 {
+		err = DecreaseTokenQuota(tokenId, quota)
+	} else {
+		err = IncreaseTokenQuota(tokenId, -quota)
+	}
+	if err != nil {
+		return err
 	}
 
 	if sendEmail {
```
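Taken together, these two functions implement an estimate-then-settle billing flow: `PreConsumeTokenQuota` deducts an upfront estimate, and `PostConsumeTokenQuota` applies a correction once real usage is known. The increase/decrease branches suggest its `quota` argument is the signed difference between actual usage and the pre-consumed amount, so a negative value becomes a refund. A sketch of that settlement arithmetic with hypothetical numbers:

```go
package main

import "fmt"

func main() {
	preConsumed := 1000           // estimated quota deducted before the call
	actual := 640                 // quota computed from real usage afterwards
	delta := actual - preConsumed // the signed correction to settle
	if delta > 0 {
		fmt.Println("charge extra:", delta)
	} else {
		fmt.Println("refund:", -delta) // refund: 360
	}
}
```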
```diff
@@ -298,7 +298,8 @@ func (user *User) ValidateAndFill() (err error) {
 	if user.Username == "" || password == "" {
 		return errors.New("用户名或密码为空")
 	}
-	DB.Where(User{Username: user.Username}).First(user)
+	// find by username or email
+	DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
 	okay := common.ValidatePasswordAndHash(password, user.Password)
 	if !okay || user.Status != common.UserStatusEnabled {
 		return errors.New("用户名或密码错误,或用户已被封禁")
```
```diff
@@ -11,9 +11,11 @@ import (
 type Adaptor interface {
 	// Init IsStream bool
 	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
+	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
-	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
```
```diff
@@ -15,6 +15,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }
@@ -39,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
 		return baiduEmbeddingRequest, nil
@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
```
```diff
@@ -20,6 +20,11 @@ type Adaptor struct {
 	RequestMode int
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
@@ -36,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -53,13 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	return claudeReq, err
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return nil, nil
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = awsStreamHandler(c, info, a.RequestMode)
+		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 		err, usage = awsHandler(c, info, a.RequestMode)
 	}
```
```diff
@@ -13,6 +13,7 @@ import (
 	relaymodel "one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"time"
 
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	return nil, &usage
 }
 
-func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
+func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	awsCli, err := newAwsClient(c, info)
 	if err != nil {
 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 
@@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 			return false
 		}
 	})
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
 	return nil, &usage
 }
```
```diff
@@ -2,6 +2,7 @@ package baidu
 
 import (
 	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
```
```diff
@@ -15,44 +16,74 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	var fullRequestURL string
-	switch info.UpstreamModelName {
-	case "ERNIE-Bot-4":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
-	case "ERNIE-Bot-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
-	case "ERNIE-Bot":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
-	case "ERNIE-Speed":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
-	case "ERNIE-Bot-turbo":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
-	case "BLOOMZ-7B":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
-	case "ERNIE-4.0-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
-	case "ERNIE-3.5-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
-	case "ERNIE-Speed-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
-	case "ERNIE-Character-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
-	case "ERNIE-Functions-8K":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
-	case "ERNIE-Lite-8K-0922":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
-	case "Yi-34B-Chat":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
-	case "Embedding-V1":
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
-	default:
-		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
-	}
+	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+	suffix := "chat/"
+	if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
+		suffix = "embeddings/"
+	}
+	if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
+		suffix = "embeddings/"
+	}
+	if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
+		suffix = "embeddings/"
+	}
+	switch info.UpstreamModelName {
+	case "ERNIE-4.0":
+		suffix += "completions_pro"
+	case "ERNIE-Bot-4":
+		suffix += "completions_pro"
+	case "ERNIE-Bot":
+		suffix += "completions"
+	case "ERNIE-Bot-turbo":
+		suffix += "eb-instant"
+	case "ERNIE-Speed":
+		suffix += "ernie_speed"
+	case "ERNIE-4.0-8K":
+		suffix += "completions_pro"
+	case "ERNIE-3.5-8K":
+		suffix += "completions"
+	case "ERNIE-3.5-8K-0205":
+		suffix += "ernie-3.5-8k-0205"
+	case "ERNIE-3.5-8K-1222":
+		suffix += "ernie-3.5-8k-1222"
+	case "ERNIE-Bot-8K":
+		suffix += "ernie_bot_8k"
+	case "ERNIE-3.5-4K-0205":
+		suffix += "ernie-3.5-4k-0205"
+	case "ERNIE-Speed-8K":
+		suffix += "ernie_speed"
+	case "ERNIE-Speed-128K":
+		suffix += "ernie-speed-128k"
+	case "ERNIE-Lite-8K-0922":
+		suffix += "eb-instant"
+	case "ERNIE-Lite-8K-0308":
+		suffix += "ernie-lite-8k"
+	case "ERNIE-Tiny-8K":
+		suffix += "ernie-tiny-8k"
+	case "BLOOMZ-7B":
+		suffix += "bloomz_7b1"
+	case "Embedding-V1":
+		suffix += "embedding-v1"
+	case "bge-large-zh":
+		suffix += "bge_large_zh"
+	case "bge-large-en":
+		suffix += "bge_large_en"
+	case "tao-8k":
+		suffix += "tao_8k"
+	default:
+		suffix += strings.ToLower(info.UpstreamModelName)
+	}
+	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
 	var accessToken string
 	var err error
 	if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
```
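Under the new scheme every endpoint becomes `{BaseUrl}/rpc/2.0/ai_custom/v1/wenxinworkshop/{suffix}`, with the suffix chosen by model family; a condensed standalone version of the mapping (only a few models shown):

```go
package main

import (
	"fmt"
	"strings"
)

// Condensed from the suffix logic above.
func requestURL(baseURL, model string) string {
	suffix := "chat/"
	if strings.HasPrefix(model, "Embedding") {
		suffix = "embeddings/"
	}
	switch model {
	case "ERNIE-3.5-8K":
		suffix += "completions"
	case "Embedding-V1":
		suffix += "embedding-v1"
	default:
		suffix += strings.ToLower(model)
	}
	return fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", baseURL, suffix)
}

func main() {
	fmt.Println(requestURL("https://aip.baidubce.com", "ERNIE-3.5-8K"))
	// https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions
}
```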
@@ -68,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
switch relayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
|
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
|
||||||
return baiduEmbeddingRequest, nil
|
return baiduEmbeddingRequest, nil
|
||||||
@@ -82,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
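The suffix selection above can be read in isolation. Below is a minimal standalone sketch, not part of the diff: the initial `suffix := "chat/"` and the `Embedding` prefix check sit above the visible hunk, so they are assumed here.

```go
package main

import (
	"fmt"
	"strings"
)

// baiduEndpoint mirrors the suffix logic in the hunk above: embedding-style
// models are routed under "embeddings/", chat models under "chat/", and the
// model name is mapped to Baidu's endpoint slug.
func baiduEndpoint(baseUrl, model string) string {
	suffix := "chat/" // assumed default from the part of the function not shown
	if strings.HasPrefix(model, "bge-large") || strings.HasPrefix(model, "tao-8k") {
		suffix = "embeddings/"
	}
	switch model {
	case "ERNIE-4.0-8K":
		suffix += "completions_pro"
	case "bge-large-zh":
		suffix += "bge_large_zh"
	default:
		suffix += strings.ToLower(model)
	}
	return fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", baseUrl, suffix)
}

func main() {
	for _, m := range []string{"ERNIE-4.0-8K", "bge-large-zh", "ERNIE-Speed-128K"} {
		fmt.Println(baiduEndpoint("https://aip.baidubce.com", m))
	}
}
```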
@@ -1,20 +1,22 @@
 package baidu
 
 var ModelList = []string{
-	"ERNIE-3.5-8K",
 	"ERNIE-4.0-8K",
+	"ERNIE-3.5-8K",
+	"ERNIE-3.5-8K-0205",
+	"ERNIE-3.5-8K-1222",
+	"ERNIE-Bot-8K",
+	"ERNIE-3.5-4K-0205",
 	"ERNIE-Speed-8K",
 	"ERNIE-Speed-128K",
-	"ERNIE-Lite-8K",
+	"ERNIE-Lite-8K-0922",
+	"ERNIE-Lite-8K-0308",
 	"ERNIE-Tiny-8K",
-	"ERNIE-Character-8K",
-	"ERNIE-Functions-8K",
-	//"ERNIE-Bot-4",
-	//"ERNIE-Bot-8K",
-	//"ERNIE-Bot",
-	//"ERNIE-Speed",
-	//"ERNIE-Bot-turbo",
+	"BLOOMZ-7B",
 	"Embedding-V1",
+	"bge-large-zh",
+	"bge-large-en",
+	"tao-8k",
 }
 
 var ChannelName = "baidu"
@@ -11,9 +11,16 @@ type BaiduMessage struct {
 }
 
 type BaiduChatRequest struct {
 	Messages []BaiduMessage `json:"messages"`
-	Stream   bool           `json:"stream"`
-	UserId   string         `json:"user_id,omitempty"`
+	Temperature     float64 `json:"temperature,omitempty"`
+	TopP            float64 `json:"top_p,omitempty"`
+	PenaltyScore    float64 `json:"penalty_score,omitempty"`
+	Stream          bool    `json:"stream,omitempty"`
+	System          string  `json:"system,omitempty"`
+	DisableSearch   bool    `json:"disable_search,omitempty"`
+	EnableCitation  bool    `json:"enable_citation,omitempty"`
+	MaxOutputTokens *int    `json:"max_output_tokens,omitempty"`
+	UserId          string  `json:"user_id,omitempty"`
 }
 
 type Error struct {
@@ -22,17 +22,33 @@ import (
 var baiduTokenStore sync.Map
 
 func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
-	messages := make([]BaiduMessage, 0, len(request.Messages))
+	baiduRequest := BaiduChatRequest{
+		Temperature:    request.Temperature,
+		TopP:           request.TopP,
+		PenaltyScore:   request.FrequencyPenalty,
+		Stream:         request.Stream,
+		DisableSearch:  false,
+		EnableCitation: false,
+		UserId:         request.User,
+	}
+	if request.MaxTokens != 0 {
+		maxTokens := int(request.MaxTokens)
+		if request.MaxTokens == 1 {
+			maxTokens = 2
+		}
+		baiduRequest.MaxOutputTokens = &maxTokens
+	}
 	for _, message := range request.Messages {
-		messages = append(messages, BaiduMessage{
-			Role:    message.Role,
-			Content: message.StringContent(),
-		})
-	}
-	return &BaiduChatRequest{
-		Messages: messages,
-		Stream:   request.Stream,
+		if message.Role == "system" {
+			baiduRequest.System = message.StringContent()
+		} else {
+			baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
+				Role:    message.Role,
+				Content: message.StringContent(),
+			})
+		}
 	}
+	return &baiduRequest
 }
 
 func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
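For context, here is a minimal sketch of the conversion the hunk above introduces, using stand-in types rather than the real `dto.GeneralOpenAIRequest`/`BaiduChatRequest`; the bump from `max_tokens = 1` to 2 presumably works around an upstream minimum.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in types; the real code uses dto.GeneralOpenAIRequest and the
// BaiduChatRequest struct from the dto.go hunk above.
type msg struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type baiduReq struct {
	Messages        []msg  `json:"messages"`
	System          string `json:"system,omitempty"`
	MaxOutputTokens *int   `json:"max_output_tokens,omitempty"`
}

// convert mirrors requestOpenAI2Baidu: the system message is lifted into the
// top-level "system" field, and a max_tokens of 1 is bumped to 2.
func convert(messages []msg, maxTokens int) baiduReq {
	r := baiduReq{}
	if maxTokens != 0 {
		if maxTokens == 1 {
			maxTokens = 2
		}
		r.MaxOutputTokens = &maxTokens
	}
	for _, m := range messages {
		if m.Role == "system" {
			r.System = m.Content
		} else {
			r.Messages = append(r.Messages, m)
		}
	}
	return r
}

func main() {
	out, _ := json.Marshal(convert([]msg{
		{Role: "system", Content: "You are terse."},
		{Role: "user", Content: "Hi"},
	}, 1))
	fmt.Println(string(out))
}
```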
@@ -21,6 +21,11 @@ type Adaptor struct {
 	RequestMode int
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
@@ -48,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -72,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 	if claudeRequest.MaxTokens == 0 {
 		claudeRequest.MaxTokens = 4096
 	}
+	if textRequest.Stop != nil {
+		// stop maybe string/array string, convert to array string
+		switch textRequest.Stop.(type) {
+		case string:
+			claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
+		case []interface{}:
+			stopSequences := make([]string, 0)
+			for _, stop := range textRequest.Stop.([]interface{}) {
+				stopSequences = append(stopSequences, stop.(string))
+			}
+			claudeRequest.StopSequences = stopSequences
+		}
+	}
 	formatMessages := make([]dto.Message, 0)
 	var lastMessage *dto.Message
 	for i, message := range textRequest.Messages {
@@ -330,22 +343,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			response.Created = createdTime
 			response.Model = info.UpstreamModelName
-
-			jsonStr, err := json.Marshal(response)
+			err = service.ObjectData(c, response)
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
+				common.SysError(err.Error())
 			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 			return true
 		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 	})
-	err := resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
 	if requestMode == RequestModeCompletion {
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
@@ -356,6 +362,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
 		}
 	}
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
 	return nil, usage
 }
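The stop-sequence normalization added above is self-contained enough to sketch on its own. This is an illustration, not the committed code; the `ok` check is a safety addition, where the diff type-asserts directly.

```go
package main

import "fmt"

// normalizeStop mirrors the handling added above: the OpenAI "stop" field may
// arrive as a single string or as an array, and Claude expects []string.
func normalizeStop(stop interface{}) []string {
	switch v := stop.(type) {
	case string:
		return []string{v}
	case []interface{}:
		out := make([]string, 0, len(v))
		for _, s := range v {
			if str, ok := s.(string); ok { // diff asserts directly; ok-check added here
				out = append(out, str)
			}
		}
		return out
	}
	return nil
}

func main() {
	fmt.Println(normalizeStop("END"))                    // [END]
	fmt.Println(normalizeStop([]interface{}{"a", "b"})) // [a b]
}
```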
76  relay/channel/cloudflare/adaptor.go  Normal file
@@ -0,0 +1,76 @@
+package cloudflare
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
+	default:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
+	}
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch info.RelayMode {
+	case constant.RelayModeCompletions:
+		return convertCf2CompletionsRequest(*request), nil
+	default:
+		return request, nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = cfStreamHandler(c, resp, info)
+	} else {
+		err, usage = cfHandler(c, resp, info)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
38  relay/channel/cloudflare/constant.go  Normal file
@@ -0,0 +1,38 @@
+package cloudflare
+
+var ModelList = []string{
+	"@cf/meta/llama-2-7b-chat-fp16",
+	"@cf/meta/llama-2-7b-chat-int8",
+	"@cf/mistral/mistral-7b-instruct-v0.1",
+	"@hf/thebloke/deepseek-coder-6.7b-base-awq",
+	"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
+	"@cf/deepseek-ai/deepseek-math-7b-base",
+	"@cf/deepseek-ai/deepseek-math-7b-instruct",
+	"@cf/thebloke/discolm-german-7b-v1-awq",
+	"@cf/tiiuae/falcon-7b-instruct",
+	"@cf/google/gemma-2b-it-lora",
+	"@hf/google/gemma-7b-it",
+	"@cf/google/gemma-7b-it-lora",
+	"@hf/nousresearch/hermes-2-pro-mistral-7b",
+	"@hf/thebloke/llama-2-13b-chat-awq",
+	"@cf/meta-llama/llama-2-7b-chat-hf-lora",
+	"@cf/meta/llama-3-8b-instruct",
+	"@hf/thebloke/llamaguard-7b-awq",
+	"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
+	"@hf/mistralai/mistral-7b-instruct-v0.2",
+	"@cf/mistral/mistral-7b-instruct-v0.2-lora",
+	"@hf/thebloke/neural-chat-7b-v3-1-awq",
+	"@cf/openchat/openchat-3.5-0106",
+	"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
+	"@cf/microsoft/phi-2",
+	"@cf/qwen/qwen1.5-0.5b-chat",
+	"@cf/qwen/qwen1.5-1.8b-chat",
+	"@cf/qwen/qwen1.5-14b-chat-awq",
+	"@cf/qwen/qwen1.5-7b-chat-awq",
+	"@cf/defog/sqlcoder-7b-2",
+	"@hf/nexusflow/starling-lm-7b-beta",
+	"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
+	"@hf/thebloke/zephyr-7b-beta-awq",
+}
+
+var ChannelName = "cloudflare"
13  relay/channel/cloudflare/model.go  Normal file
@@ -0,0 +1,13 @@
+package cloudflare
+
+import "one-api/dto"
+
+type CfRequest struct {
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Lora        string        `json:"lora,omitempty"`
+	MaxTokens   int           `json:"max_tokens,omitempty"`
+	Prompt      string        `json:"prompt,omitempty"`
+	Raw         bool          `json:"raw,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+}
121  relay/channel/cloudflare/relay_cloudflare.go  Normal file
@@ -0,0 +1,121 @@
+package cloudflare
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+	"time"
+)
+
+func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
+	p, _ := textRequest.Prompt.(string)
+	return &CfRequest{
+		Prompt:      p,
+		MaxTokens:   textRequest.GetMaxTokens(),
+		Stream:      textRequest.Stream,
+		Temperature: textRequest.Temperature,
+	}
+}
+
+func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+	id := service.GetResponseID(c)
+	var responseText string
+	isFirst := true
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < len("data: ") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data: ")
+		data = strings.TrimSuffix(data, "\r")
+
+		if data == "[DONE]" {
+			break
+		}
+
+		var response dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &response)
+		if err != nil {
+			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+			continue
+		}
+		for _, choice := range response.Choices {
+			choice.Delta.Role = "assistant"
+			responseText += choice.Delta.GetContentString()
+		}
+		response.Id = id
+		response.Model = info.UpstreamModelName
+		err = service.ObjectData(c, response)
+		if isFirst {
+			isFirst = false
+			info.FirstResponseTime = time.Now()
+		}
+		if err != nil {
+			common.LogError(c, "error_rendering_stream_response: "+err.Error())
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		common.LogError(c, "error_scanning_stream_response: "+err.Error())
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+		}
+	}
+	service.Done(c)
+
+	err := resp.Body.Close()
+	if err != nil {
+		common.LogError(c, "close_response_body_failed: "+err.Error())
+	}
+
+	return nil, usage
+}
+
+func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var response dto.TextResponse
+	err = json.Unmarshal(responseBody, &response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	response.Model = info.UpstreamModelName
+	var responseText string
+	for _, choice := range response.Choices {
+		responseText += choice.Message.StringContent()
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	response.Usage = *usage
+	response.Id = service.GetResponseID(c)
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, usage
+}
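A hedged sketch of the request body the new `convertCf2CompletionsRequest` produces; the prompt and parameter values here are made up for illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// CfRequest copied from the model.go hunk above, with []any standing in for
// []dto.Message so the sketch is self-contained.
type CfRequest struct {
	Messages    []any   `json:"messages,omitempty"`
	Lora        string  `json:"lora,omitempty"`
	MaxTokens   int     `json:"max_tokens,omitempty"`
	Prompt      string  `json:"prompt,omitempty"`
	Raw         bool    `json:"raw,omitempty"`
	Stream      bool    `json:"stream,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
}

func main() {
	// Mirrors convertCf2CompletionsRequest: only prompt-style fields are set
	// for a legacy completions call.
	body, _ := json.Marshal(CfRequest{
		Prompt:      "Write a haiku about Workers AI.",
		MaxTokens:   64,
		Stream:      true,
		Temperature: 0.7,
	})
	fmt.Println(string(body))
	// Per GetRequestURL above, this is POSTed to
	// {base}/client/v4/accounts/{account}/ai/run/{model},
	// with the account id carried in info.ApiVersion.
}
```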
@@ -8,16 +8,24 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 )
 
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+	if info.RelayMode == constant.RelayModeRerank {
+		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+	} else {
+		return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -26,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	return requestOpenAI2Cohere(*request), nil
 }
 
@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return requestConvertRerank2Cohere(request), nil
+}
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = cohereStreamHandler(c, resp, info)
+	if info.RelayMode == constant.RelayModeRerank {
+		err, usage = cohereRerankHandler(c, resp, info)
 	} else {
-		err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		if info.IsStream {
+			err, usage = cohereStreamHandler(c, resp, info)
+		} else {
+			err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		}
 	}
 	return
 }
@@ -2,6 +2,7 @@ package cohere
 
 var ModelList = []string{
 	"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
+	"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
 }
 
 var ChannelName = "cohere"
@@ -1,11 +1,13 @@
 package cohere
 
+import "one-api/dto"
+
 type CohereRequest struct {
 	Model       string        `json:"model"`
 	ChatHistory []ChatHistory `json:"chat_history"`
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
-	MaxTokens   int64         `json:"max_tokens"`
+	MaxTokens   int           `json:"max_tokens"`
 }
 
 type ChatHistory struct {
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
 	Meta CohereMeta `json:"meta"`
 }
 
+type CohereRerankRequest struct {
+	Documents       []any  `json:"documents"`
+	Query           string `json:"query"`
+	Model           string `json:"model"`
+	TopN            int    `json:"top_n"`
+	ReturnDocuments bool   `json:"return_documents"`
+}
+
+type CohereRerankResponseResult struct {
+	Results []dto.RerankResponseDocument `json:"results"`
+	Meta    CohereMeta                   `json:"meta"`
+}
+
 type CohereMeta struct {
 	//Tokens      CohereTokens `json:"tokens"`
 	BilledUnits CohereBilledUnits `json:"billed_units"`
@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 	return &cohereReq
 }
 
+func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
+	if rerankRequest.TopN == 0 {
+		rerankRequest.TopN = 1
+	}
+	cohereReq := CohereRerankRequest{
+		Query:           rerankRequest.Query,
+		Documents:       rerankRequest.Documents,
+		Model:           rerankRequest.Model,
+		TopN:            rerankRequest.TopN,
+		ReturnDocuments: true,
+	}
+	return &cohereReq
+}
+
 func stopReasonCohere2OpenAI(reason string) string {
 	switch reason {
 	case "COMPLETE":
@@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var cohereResp CohereRerankResponseResult
+	err = json.Unmarshal(responseBody, &cohereResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	usage := dto.Usage{}
+	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens = 0
+		usage.TotalTokens = info.PromptTokens
+	} else {
+		usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+		usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+		usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+	}
+
+	var rerankResp dto.RerankResponse
+	rerankResp.Results = cohereResp.Results
+	rerankResp.Usage = usage
+
+	jsonResponse, err := json.Marshal(rerankResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
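The billed-units fallback in `cohereRerankHandler` above, restated as a tiny standalone function (stand-in struct, illustrative values):

```go
package main

import "fmt"

type billedUnits struct{ InputTokens, OutputTokens int }

// usageFromMeta mirrors the handler above: if Cohere reports no billed input
// tokens, fall back to the locally counted prompt tokens.
func usageFromMeta(b billedUnits, localPromptTokens int) (prompt, completion, total int) {
	if b.InputTokens == 0 {
		return localPromptTokens, 0, localPromptTokens
	}
	return b.InputTokens, b.OutputTokens, b.InputTokens + b.OutputTokens
}

func main() {
	fmt.Println(usageFromMeta(billedUnits{}, 42))                                 // 42 0 42
	fmt.Println(usageFromMeta(billedUnits{InputTokens: 10, OutputTokens: 3}, 42)) // 10 3 13
}
```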
65  relay/channel/dify/adaptor.go  Normal file
@@ -0,0 +1,65 @@
+package dify
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return requestOpenAI2Dify(*request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = difyStreamHandler(c, resp, info)
+	} else {
+		err, usage = difyHandler(c, resp, info)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
5  relay/channel/dify/constants.go  Normal file
@@ -0,0 +1,5 @@
+package dify
+
+var ModelList []string
+
+var ChannelName = "dify"
35  relay/channel/dify/dto.go  Normal file
@@ -0,0 +1,35 @@
+package dify
+
+import "one-api/dto"
+
+type DifyChatRequest struct {
+	Inputs           map[string]interface{} `json:"inputs"`
+	Query            string                 `json:"query"`
+	ResponseMode     string                 `json:"response_mode"`
+	User             string                 `json:"user"`
+	AutoGenerateName bool                   `json:"auto_generate_name"`
+}
+
+type DifyMetaData struct {
+	Usage dto.Usage `json:"usage"`
+}
+
+type DifyData struct {
+	WorkflowId string `json:"workflow_id"`
+	NodeId     string `json:"node_id"`
+}
+
+type DifyChatCompletionResponse struct {
+	ConversationId string       `json:"conversation_id"`
+	Answer         string       `json:"answer"`
+	CreateAt       int64        `json:"create_at"`
+	MetaData       DifyMetaData `json:"metadata"`
+}
+
+type DifyChunkChatCompletionResponse struct {
+	Event          string       `json:"event"`
+	ConversationId string       `json:"conversation_id"`
+	Answer         string       `json:"answer"`
+	Data           DifyData     `json:"data"`
+	MetaData       DifyMetaData `json:"metadata"`
+}
156  relay/channel/dify/relay-dify.go  Normal file
@@ -0,0 +1,156 @@
+package dify
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
+	content := ""
+	for _, message := range request.Messages {
+		if message.Role == "system" {
+			content += "SYSTEM: \n" + message.StringContent() + "\n"
+		} else if message.Role == "assistant" {
+			content += "ASSISTANT: \n" + message.StringContent() + "\n"
+		} else {
+			content += "USER: \n" + message.StringContent() + "\n"
+		}
+	}
+	mode := "blocking"
+	if request.Stream {
+		mode = "streaming"
+	}
+	user := request.User
+	if user == "" {
+		user = "api-user"
+	}
+	return &DifyChatRequest{
+		Inputs:           make(map[string]interface{}),
+		Query:            content,
+		ResponseMode:     mode,
+		User:             user,
+		AutoGenerateName: false,
+	}
+}
+
+func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
+	response := dto.ChatCompletionsStreamResponse{
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "dify",
+	}
+	var choice dto.ChatCompletionsStreamResponseChoice
+	if constant.DifyDebug && difyResponse.Event == "workflow_started" {
+		choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
+	} else if constant.DifyDebug && difyResponse.Event == "node_started" {
+		choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
+	} else if difyResponse.Event == "message" {
+		choice.Delta.SetContentString(difyResponse.Answer)
+	}
+	response.Choices = append(response.Choices, choice)
+	return &response
+}
+
+func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var responseText string
+	usage := &dto.Usage{}
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data:")
+		var difyResponse DifyChunkChatCompletionResponse
+		err := json.Unmarshal([]byte(data), &difyResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
+		}
+		var openaiResponse dto.ChatCompletionsStreamResponse
+		if difyResponse.Event == "message_end" {
+			usage = &difyResponse.MetaData.Usage
+			break
+		} else if difyResponse.Event == "error" {
+			break
+		} else {
+			openaiResponse = *streamResponseDify2OpenAI(difyResponse)
+			if len(openaiResponse.Choices) != 0 {
+				responseText += openaiResponse.Choices[0].Delta.GetContentString()
+			}
+		}
+		err = service.ObjectData(c, openaiResponse)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		common.SysError("error reading stream: " + err.Error())
+	}
+	service.Done(c)
+	err := resp.Body.Close()
+	if err != nil {
+		//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		common.SysError("close_response_body_failed: " + err.Error())
+	}
+	if usage.TotalTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+	return nil, usage
+}
+
+func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var difyResponse DifyChatCompletionResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &difyResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	fullTextResponse := dto.OpenAITextResponse{
+		Id:      difyResponse.ConversationId,
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Usage:   difyResponse.MetaData.Usage,
+	}
+	content, _ := json.Marshal(difyResponse.Answer)
+	choice := dto.OpenAITextResponseChoice{
+		Index: 0,
+		Message: dto.Message{
+			Role:    "assistant",
+			Content: content,
+		},
+		FinishReason: "stop",
+	}
+	fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &difyResponse.MetaData.Usage
+}
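A minimal sketch of the message flattening `requestOpenAI2Dify` performs above (stand-in message type; the role-prefixed output format is taken from the diff):

```go
package main

import "fmt"

type chatMsg struct{ Role, Content string }

// flatten mirrors requestOpenAI2Dify: Dify's chat-messages endpoint takes a
// single query string, so the OpenAI message list is serialized into
// role-prefixed lines.
func flatten(messages []chatMsg) string {
	out := ""
	for _, m := range messages {
		switch m.Role {
		case "system":
			out += "SYSTEM: \n" + m.Content + "\n"
		case "assistant":
			out += "ASSISTANT: \n" + m.Content + "\n"
		default:
			out += "USER: \n" + m.Content + "\n"
		}
	}
	return out
}

func main() {
	fmt.Print(flatten([]chatMsg{
		{Role: "system", Content: "Be brief."},
		{Role: "user", Content: "What is Dify?"},
	}))
}
```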
@@ -9,12 +9,14 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 )
 
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -49,22 +51,24 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
 	return CovertGemini2OpenAI(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		err, responseText = geminiChatStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage = geminiChatStreamHandler(c, resp, info)
 	} else {
 		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
 type GeminiChatResponse struct {
 	Candidates     []GeminiChatCandidate    `json:"candidates"`
 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+	UsageMetadata  GeminiUsageMetadata      `json:"usageMetadata"`
+}
+
+type GeminiUsageMetadata struct {
+	PromptTokenCount     int `json:"promptTokenCount"`
+	CandidatesTokenCount int `json:"candidatesTokenCount"`
+	TotalTokenCount      int `json:"totalTokenCount"`
 }
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	return &response
 }
 
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseText := ""
+	responseJson := ""
+	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createAt := common.GetTimestamp()
+	var usage = &dto.Usage{}
 	dataChan := make(chan string, 5)
 	stopChan := make(chan bool, 2)
 	scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
+			responseJson += data
 			data = strings.TrimSpace(data)
 			if !strings.HasPrefix(data, "\"text\": \"") {
 				continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			var choice dto.ChatCompletionsStreamResponseChoice
 			choice.Delta.SetContentString(dummy.Content)
 			response := dto.ChatCompletionsStreamResponse{
-				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+				Id:      id,
 				Object:  "chat.completion.chunk",
-				Created: common.GetTimestamp(),
-				Model:   "gemini-pro",
+				Created: createAt,
+				Model:   info.UpstreamModelName,
 				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			return true
 		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 	})
-	err := resp.Body.Close()
+	var geminiChatResponses []GeminiChatResponse
+	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		log.Printf("cannot get gemini usage: %s", err.Error())
+		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		for _, response := range geminiChatResponses {
+			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+		}
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
-	return nil, responseText
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
+	}
+	return nil, usage
 }
 
 func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
 	usage := dto.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
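For reference, a standalone sketch of the `usageMetadata` parsing the new handlers rely on; the JSON sample is illustrative, not a captured response.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// GeminiUsageMetadata copied from the dto.go hunk above.
type GeminiUsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

func main() {
	// A trimmed-down response body of the shape the handler above consumes.
	raw := `{"usageMetadata":{"promptTokenCount":12,"candidatesTokenCount":34,"totalTokenCount":46}}`
	var resp struct {
		UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
	}
	if err := json.Unmarshal([]byte(raw), &resp); err != nil {
		panic(err)
	}
	fmt.Printf("prompt=%d completion=%d total=%d\n",
		resp.UsageMetadata.PromptTokenCount,
		resp.UsageMetadata.CandidatesTokenCount,
		resp.UsageMetadata.TotalTokenCount)
}
```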
64  relay/channel/jina/adaptor.go  Normal file
@@ -0,0 +1,64 @@
+package jina
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.RelayMode == constant.RelayModeRerank {
+		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+	} else if info.RelayMode == constant.RelayModeEmbeddings {
+		return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
+	}
+	return "", errors.New("invalid relay mode")
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == constant.RelayModeRerank {
+		err, usage = jinaRerankHandler(c, resp)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
8  relay/channel/jina/constant.go  Normal file
@@ -0,0 +1,8 @@
+package jina
+
+var ModelList = []string{
+	"jina-clip-v1",
+	"jina-reranker-v2-base-multilingual",
+}
+
+var ChannelName = "jina"
35  relay/channel/jina/relay-jina.go  Normal file
@@ -0,0 +1,35 @@
+package jina
+
+import (
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/service"
+)
+
+func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var jinaResp dto.RerankResponse
+	err = json.Unmarshal(responseBody, &jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	jsonResponse, err := json.Marshal(jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &jinaResp.Usage
+}
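A rough sketch of the passthrough `jinaRerankHandler` performs; the stand-in fields for `dto.RerankResponse` here are assumptions based on the usage fields seen elsewhere in this diff, not the repository's actual DTO.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for dto.RerankResponse / dto.Usage (assumed JSON field names).
type usage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
type rerankResponse struct {
	Results []json.RawMessage `json:"results"`
	Usage   usage             `json:"usage"`
}

func main() {
	// The handler above is a passthrough: the upstream body is decoded only
	// so the usage can be billed, then re-encoded for the client.
	body := `{"results":[{"index":0,"relevance_score":0.91}],"usage":{"prompt_tokens":7,"total_tokens":7}}`
	var resp rerankResponse
	if err := json.Unmarshal([]byte(body), &resp); err != nil {
		panic(err)
	}
	out, _ := json.Marshal(resp)
	fmt.Println(string(out), "billed:", resp.Usage.TotalTokens)
}
```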
@@ -16,6 +16,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -33,11 +36,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
 		return requestOpenAI2Embeddings(*request), nil
 	default:
@@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -52,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
@@ -14,7 +14,6 @@ import (
 	"one-api/relay/channel/minimax"
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 	"strings"
 )
 
@@ -22,6 +21,13 @@ type Adaptor struct {
 	ChannelType int
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
 }
@@ -67,10 +73,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if info.ChannelType != common.ChannelTypeOpenAI {
+		request.StreamOptions = nil
+	}
 	return request, nil
 }
 
@@ -80,11 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		var toolCount int
-		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -14,37 +14,34 @@ import (
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"strings"
-	"sync"
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
-	//checkSensitive := constant.ShouldCheckCompletionSensitive()
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
+	hasStreamUsage := false
+	responseId := ""
+	var createAt int64 = 0
+	var systemFingerprint string
+
 	var responseTextBuilder strings.Builder
+	var usage = &dto.Usage{}
 	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string, 5)
+	scanner.Split(bufio.ScanLines)
+	var streamItems []string // store stream items
+
+	service.SetEventStreamHeaders(c)
+
+	ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
+	defer ticker.Stop()
 	stopChan := make(chan bool, 2)
 	defer close(stopChan)
-	defer close(dataChan)
-	var wg sync.WaitGroup
 	go func() {
-		wg.Add(1)
-		defer wg.Done()
-		var streamItems []string // store stream items
 		for scanner.Scan() {
+			info.SetFirstResponseTime()
+			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
 			data := scanner.Text()
 			if len(data) < 6 { // ignore blank line or wrong format
 				continue
@@ -52,43 +49,43 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			if data[:6] != "data: " && data[:6] != "[DONE]" {
 				continue
 			}
-			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-				// send data timeout, stop the stream
-				common.LogError(c, "send data timeout, stop the stream")
-				break
-			}
 			data = data[6:]
 			if !strings.HasPrefix(data, "[DONE]") {
+				service.StringData(c, data)
 				streamItems = append(streamItems, data)
 			}
 		}
-		streamResp := "[" + strings.Join(streamItems, ",") + "]"
-		switch info.RelayMode {
-		case relayconstant.RelayModeChatCompletions:
-			var streamResponses []dto.ChatCompletionsStreamResponseSimple
-			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				for _, item := range streamItems {
-					var streamResponse dto.ChatCompletionsStreamResponseSimple
-					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
-					if err == nil {
-						for _, choice := range streamResponse.Choices {
-							responseTextBuilder.WriteString(choice.Delta.GetContentString())
-							if choice.Delta.ToolCalls != nil {
-								if len(choice.Delta.ToolCalls) > toolCount {
-									toolCount = len(choice.Delta.ToolCalls)
-								}
-								for _, tool := range choice.Delta.ToolCalls {
-									responseTextBuilder.WriteString(tool.Function.Name)
-									responseTextBuilder.WriteString(tool.Function.Arguments)
-								}
-							}
-						}
-					}
+		stopChan <- true
+	}()
+
+	select {
+	case <-ticker.C:
+		// timeout handling
+		common.LogError(c, "streaming timeout")
+	case <-stopChan:
+		// finished normally
+	}
+
+	// count tokens
+	streamResp := "[" + strings.Join(streamItems, ",") + "]"
+	switch info.RelayMode {
+	case relayconstant.RelayModeChatCompletions:
+		var streamResponses []dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+		if err != nil {
+			// bulk parse failed, parse item by item
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			for _, item := range streamItems {
+				var streamResponse dto.ChatCompletionsStreamResponse
+				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+				if err == nil {
+					responseId = streamResponse.Id
+					createAt = streamResponse.Created
+					systemFingerprint = streamResponse.GetSystemFingerprint()
+					if service.ValidUsage(streamResponse.Usage) {
+						usage = streamResponse.Usage
+						hasStreamUsage = true
+					}
 				}
-			} else {
-				for _, streamResponse := range streamResponses {
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
 						if choice.Delta.ToolCalls != nil {
@@ -103,60 +100,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 						}
 					}
 				}
-		case relayconstant.RelayModeCompletions:
-			var streamResponses []dto.CompletionsStreamResponse
-			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				for _, item := range streamItems {
-					var streamResponse dto.CompletionsStreamResponse
-					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
-					if err == nil {
+		} else {
+			for _, streamResponse := range streamResponses {
+				responseId = streamResponse.Id
+				createAt = streamResponse.Created
+				systemFingerprint = streamResponse.GetSystemFingerprint()
+				if service.ValidUsage(streamResponse.Usage) {
+					usage = streamResponse.Usage
+					hasStreamUsage = true
+				}
 				for _, choice := range streamResponse.Choices {
-					responseTextBuilder.WriteString(choice.Text)
+					responseTextBuilder.WriteString(choice.Delta.GetContentString())
+					if choice.Delta.ToolCalls != nil {
+						if len(choice.Delta.ToolCalls) > toolCount {
+							toolCount = len(choice.Delta.ToolCalls)
+						}
+						for _, tool := range choice.Delta.ToolCalls {
+							responseTextBuilder.WriteString(tool.Function.Name)
+							responseTextBuilder.WriteString(tool.Function.Arguments)
+						}
 					}
 				}
 			}
-			} else {
-				for _, streamResponse := range streamResponses {
+		}
+	case relayconstant.RelayModeCompletions:
+		var streamResponses []dto.CompletionsStreamResponse
+		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+		if err != nil {
+			// bulk parse failed, parse item by item
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			for _, item := range streamItems {
+				var streamResponse dto.CompletionsStreamResponse
+				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+				if err == nil {
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Text)
 					}
 				}
 			}
-		}
-		if len(dataChan) > 0 {
-			// wait data out
-			time.Sleep(2 * time.Second)
+		} else {
+			for _, streamResponse := range streamResponses {
+				for _, choice := range streamResponse.Choices {
+					responseTextBuilder.WriteString(choice.Text)
+				}
+			}
 		}
-		common.SafeSendBool(stopChan, true)
-	}()
-	service.SetEventStreamHeaders(c)
-	isFirst := true
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			if strings.HasPrefix(data, "data: [DONE]") {
-				data = data[:12]
-			}
-			// some implementations may add \r at the end of data
-			data = strings.TrimSuffix(data, "\r")
-			c.Render(-1, common.CustomEvent{Data: data})
-			return true
-		case <-stopChan:
-			return false
-		}
-	})
+	}
 
+	if !hasStreamUsage {
+		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
+	}
+
+	if info.ShouldIncludeUsage && !hasStreamUsage {
+		response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
+		response.SetSystemFingerprint(systemFingerprint)
+		service.ObjectData(c, response)
+	}
+
+	service.Done(c)
+
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
 	}
-	wg.Wait()
-	return nil, responseTextBuilder.String(), toolCount
+	return nil, usage, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
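The rewritten stream handler above replaces the old channel-plus-`c.Stream` relay with direct writes plus an idle-timeout ticker: the ticker is reset on every received line, and a final `select` decides between "stream finished" and "no data for too long". A minimal runnable sketch of that pattern, with `streamingTimeout` and the simulated `lines` channel as illustrative stand-ins for `constant.StreamingTimeout` and the SSE scanner:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	const streamingTimeout = 2 * time.Second

	lines := make(chan string)
	stop := make(chan bool, 1)

	ticker := time.NewTicker(streamingTimeout)
	defer ticker.Stop()

	go func() {
		// Simulated upstream: two quick lines, then a normal stop.
		lines <- `data: {"choices":[]}`
		lines <- "data: [DONE]"
		stop <- true
	}()

	for {
		select {
		case line := <-lines:
			ticker.Reset(streamingTimeout) // fresh data: push the deadline back
			fmt.Println("got:", line)
		case <-ticker.C:
			fmt.Println("streaming timeout")
			return
		case <-stop:
			fmt.Println("stream finished")
			return
		}
	}
}
```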
@@ -15,6 +15,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -28,13 +33,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
 	return request, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -29,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Perplexity(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -46,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -6,49 +6,68 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"strconv"
 	"strings"
 )
 
 type Adaptor struct {
 	Sign string
+	AppID     int64
+	Action    string
+	Version   string
+	Timestamp int64
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+	a.Action = "ChatCompletions"
+	a.Version = "2023-09-01"
+	a.Timestamp = common.GetTimestamp()
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
+	return fmt.Sprintf("%s/", info.BaseUrl), nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	req.Header.Set("Authorization", a.Sign)
-	req.Header.Set("X-TC-Action", info.UpstreamModelName)
+	req.Header.Set("X-TC-Action", a.Action)
+	req.Header.Set("X-TC-Version", a.Version)
+	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
 	apiKey := c.Request.Header.Get("Authorization")
 	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 	appId, secretId, secretKey, err := parseTencentConfig(apiKey)
+	a.AppID = appId
 	if err != nil {
 		return nil, err
 	}
-	tencentRequest := requestOpenAI2Tencent(*request)
-	tencentRequest.AppId = appId
-	tencentRequest.SecretId = secretId
+	tencentRequest := requestOpenAI2Tencent(a, *request)
 	// we have to calculate the sign here
-	a.Sign = getTencentSign(*tencentRequest, secretKey)
+	a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
 	return tencentRequest, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -1,9 +1,10 @@
 package tencent
 
 var ModelList = []string{
-	"ChatPro",
-	"ChatStd",
-	"hunyuan",
+	"hunyuan-lite",
+	"hunyuan-standard",
+	"hunyuan-standard-256K",
+	"hunyuan-pro",
 }
 
 var ChannelName = "tencent"
@@ -1,62 +1,75 @@
 package tencent
 
-import "one-api/dto"
-
 type TencentMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+	Role    string `json:"Role"`
+	Content string `json:"Content"`
 }
 
 type TencentChatRequest struct {
-	AppId    int64  `json:"app_id"`    // Tencent Cloud account APPID
-	SecretId string `json:"secret_id"` // SecretId from the console
-	// Timestamp is the current UNIX timestamp in seconds, recording when the API request was made.
-	// e.g. 1529223702; if it differs too much from the current time, a signature-expired error is raised
-	Timestamp int64 `json:"timestamp"`
-	// Expired is the signature expiry, a UNIX-epoch value in seconds;
-	// Expired must be greater than Timestamp and Expired-Timestamp less than 90 days
-	Expired int64  `json:"expired"`
-	QueryID string `json:"query_id"` // request ID, used for troubleshooting
-	// Temperature: higher values make the output more random, lower values more focused and deterministic
-	// default 1.0, range [0.0,2.0]; not recommended unless necessary, unreasonable values hurt results
-	// set either this or top_p, do not change both
-	Temperature float64 `json:"temperature"`
-	// TopP affects output diversity; the larger the value, the more diverse the generated text
-	// default 1.0, range [0.0, 1.0]; not recommended unless necessary, unreasonable values hurt results
-	// set either this or temperature, do not change both
-	TopP float64 `json:"top_p"`
-	// Stream 0: synchronous, 1: streaming (default, protocol: SSE)
-	// synchronous requests time out after 60s; prefer streaming for long content
-	Stream int `json:"stream"`
-	// Messages is the conversation, at most 40 entries, ordered oldest to newest
-	// total input content supports at most 3000 tokens.
-	Messages []TencentMessage `json:"messages"`
-	Model    string           `json:"model"` // model name
+	// Model name; valid values are hunyuan-lite, hunyuan-standard, hunyuan-standard-256K and hunyuan-pro.
+	// See the product overview (https://cloud.tencent.com/document/product/1729/104753) for details on each model.
+	//
+	// Note:
+	// pricing differs per model; see the buying guide (https://cloud.tencent.com/document/product/1729/97731).
+	Model *string `json:"Model"`
+	// Chat context.
+	// Notes:
+	// 1. At most 40 messages, ordered oldest to newest.
+	// 2. Message.Role is one of system, user or assistant;
+	//    system is optional and, if present, must come first; user and assistant must alternate (one question, one answer), starting and ending with user, and Content must not be empty. Example role order: [system(optional) user assistant user assistant user ...].
+	// 3. The total Content length in Messages must not exceed the model's input limit (see the product overview); if it does, the earliest content is truncated and only the tail is kept.
+	Messages []*TencentMessage `json:"Messages"`
+	// Streaming switch.
+	// Notes:
+	// 1. Defaults to non-streaming (false) when unset.
+	// 2. When streaming, results are returned incrementally over SSE (read Choices[n].Delta and concatenate the increments for the full result).
+	// 3. When not streaming:
+	//    the call behaves like an ordinary HTTP request;
+	//    the response takes noticeably longer, so set this to true if you need lower latency;
+	//    only one final result is returned (read Choices[n].Message).
+	//
+	// Note:
+	// when calling through the SDK, streaming and non-streaming use different ways of reading the return value; see the comments or examples in the SDK (under examples/hunyuan/v20230901/ in each language's SDK repo).
+	Stream *bool `json:"Stream,omitempty"`
+	// Notes:
+	// 1. Affects the diversity of the output; larger values give more diverse output.
+	// 2. Range [0.0, 1.0]; the model's recommended value is used when unset.
+	// 3. Not recommended unless necessary; unreasonable values hurt results.
+	TopP *float64 `json:"TopP,omitempty"`
+	// Notes:
+	// 1. Higher values make the output more random, lower values more focused and deterministic.
+	// 2. Range [0.0, 2.0]; the model's recommended value is used when unset.
+	// 3. Not recommended unless necessary; unreasonable values hurt results.
+	Temperature *float64 `json:"Temperature,omitempty"`
 }
 
 type TencentError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
+	Code    int    `json:"Code"`
+	Message string `json:"Message"`
 }
 
 type TencentUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
+	PromptTokens     int `json:"PromptTokens"`
+	CompletionTokens int `json:"CompletionTokens"`
+	TotalTokens      int `json:"TotalTokens"`
 }
 
 type TencentResponseChoices struct {
-	FinishReason string         `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final chunk
-	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in streaming mode; output supports at most 1024 tokens
-	Delta        TencentMessage `json:"delta,omitempty"`         // content in streaming mode, null in synchronous mode; output supports at most 1024 tokens
+	FinishReason string         `json:"FinishReason,omitempty"` // streaming end flag; "stop" marks the final chunk
+	Messages     TencentMessage `json:"Message,omitempty"`      // content in synchronous mode, null in streaming mode; output supports at most 1024 tokens
+	Delta        TencentMessage `json:"Delta,omitempty"`        // content in streaming mode, null in synchronous mode; output supports at most 1024 tokens
 }
 
 type TencentChatResponse struct {
-	Choices []TencentResponseChoices `json:"choices,omitempty"` // result
-	Created string                   `json:"created,omitempty"` // unix timestamp as a string
-	Id      string                   `json:"id,omitempty"`      // session id
-	Usage   dto.Usage                `json:"usage,omitempty"`   // token counts
-	Error   TencentError             `json:"error,omitempty"`   // error info; note: this field may be null
-	Note    string                   `json:"note,omitempty"`    // note
-	ReqID   string                   `json:"req_id,omitempty"`  // unique request id, returned with every request; quote it when reporting issues
+	Choices []TencentResponseChoices `json:"Choices,omitempty"` // result
+	Created int64                    `json:"Created,omitempty"` // unix timestamp
+	Id      string                   `json:"Id,omitempty"`      // session id
+	Usage   TencentUsage             `json:"Usage,omitempty"`   // token counts
+	Error   TencentError             `json:"Error,omitempty"`   // error info; note: this field may be null
+	Note    string                   `json:"Note,omitempty"`    // note
+	ReqID   string                   `json:"Req_id,omitempty"`  // unique request id, returned with every request; quote it when reporting issues
+}
+
+type TencentChatResponseSB struct {
+	Response TencentChatResponse `json:"Response,omitempty"`
 }
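The new request struct switches `Stream`, `TopP` and `Temperature` to pointer fields with `omitempty`, so an unset parameter is omitted from the JSON entirely instead of being sent as a zero value that the upstream API would treat as a deliberate setting. A small self-contained illustration of the difference (type and field names are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type withValue struct {
	TopP float64 `json:"TopP,omitempty"` // zero value 0.0 is dropped, but so is a deliberate 0.0
}

type withPointer struct {
	TopP *float64 `json:"TopP,omitempty"` // nil means "unset"; a pointer to 0.0 is still serialized
}

func main() {
	v, _ := json.Marshal(withValue{}) // {}
	zero := 0.0
	p1, _ := json.Marshal(withPointer{})            // {}
	p2, _ := json.Marshal(withPointer{TopP: &zero}) // {"TopP":0}
	fmt.Println(string(v), string(p1), string(p2))
}
```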
@@ -3,8 +3,8 @@ package tencent
 import (
 	"bufio"
 	"crypto/hmac"
-	"crypto/sha1"
-	"encoding/base64"
+	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -15,54 +15,46 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
-	"sort"
 	"strconv"
 	"strings"
+	"time"
 )
 
 // https://cloud.tencent.com/document/product/1729/97732
 
-func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
-	messages := make([]TencentMessage, 0, len(request.Messages))
+func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
+	messages := make([]*TencentMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
-		if message.Role == "system" {
-			messages = append(messages, TencentMessage{
-				Role:    "user",
-				Content: message.StringContent(),
-			})
-			messages = append(messages, TencentMessage{
-				Role:    "assistant",
-				Content: "Okay",
-			})
-			continue
-		}
-		messages = append(messages, TencentMessage{
+		messages = append(messages, &TencentMessage{
 			Content: message.StringContent(),
 			Role:    message.Role,
 		})
 	}
-	stream := 0
-	if request.Stream {
-		stream = 1
+	var req = TencentChatRequest{
+		Stream:   &request.Stream,
+		Messages: messages,
+		Model:    &request.Model,
 	}
-	return &TencentChatRequest{
-		Timestamp:   common.GetTimestamp(),
-		Expired:     common.GetTimestamp() + 24*60*60,
-		QueryID:     common.GetUUID(),
-		Temperature: request.Temperature,
-		TopP:        request.TopP,
-		Stream:      stream,
-		Messages:    messages,
-		Model:       request.Model,
+	if request.TopP != 0 {
+		req.TopP = &request.TopP
 	}
+	if request.Temperature != 0 {
+		req.Temperature = &request.Temperature
+	}
+	return &req
 }
 
 func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
+		Id:      response.Id,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Usage:   response.Usage,
+		Usage: dto.Usage{
+			PromptTokens:     response.Usage.PromptTokens,
+			CompletionTokens: response.Usage.CompletionTokens,
+			TotalTokens:      response.Usage.TotalTokens,
+		},
 	}
 	if len(response.Choices) > 0 {
 		content, _ := json.Marshal(response.Choices[0].Messages.Content)
@@ -99,69 +91,51 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
 func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseText string
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if len(data) < 5 { // ignore blank line or wrong format
-				continue
-			}
-			if data[:5] != "data:" {
-				continue
-			}
-			data = data[5:]
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
+	scanner.Split(bufio.ScanLines)
 	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			var TencentResponse TencentChatResponse
-			err := json.Unmarshal([]byte(data), &TencentResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			response := streamResponseTencent2OpenAI(&TencentResponse)
-			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.GetContentString()
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data:")
+
+		var tencentResponse TencentChatResponse
+		err := json.Unmarshal([]byte(data), &tencentResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
+		}
+
+		response := streamResponseTencent2OpenAI(&tencentResponse)
+		if len(response.Choices) != 0 {
+			responseText += response.Choices[0].Delta.GetContentString()
+		}
+
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.SysError(err.Error())
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		common.SysError("error reading stream: " + err.Error())
+	}
+
+	service.Done(c)
+
 	err := resp.Body.Close()
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 
 	return nil, responseText
 }
 
 func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	var TencentResponse TencentChatResponse
+	var tencentSb TencentChatResponseSB
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -170,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &TencentResponse)
+	err = json.Unmarshal(responseBody, &tencentSb)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	if TencentResponse.Error.Code != 0 {
+	if tencentSb.Response.Error.Code != 0 {
 		return &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
-				Message: TencentResponse.Error.Message,
-				Code:    TencentResponse.Error.Code,
+				Message: tencentSb.Response.Error.Message,
+				Code:    tencentSb.Response.Error.Code,
 			},
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	fullTextResponse := responseTencent2OpenAI(&TencentResponse)
+	fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -206,29 +180,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
 	return
 }
 
-func getTencentSign(req TencentChatRequest, secretKey string) string {
-	params := make([]string, 0)
-	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
-	params = append(params, "secret_id="+req.SecretId)
-	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
-	params = append(params, "query_id="+req.QueryID)
-	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
-	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
-	params = append(params, "stream="+strconv.Itoa(req.Stream))
-	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
+func sha256hex(s string) string {
+	b := sha256.Sum256([]byte(s))
+	return hex.EncodeToString(b[:])
+}
 
-	var messageStr string
-	for _, msg := range req.Messages {
-		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
-	}
-	messageStr = strings.TrimSuffix(messageStr, ",")
-	params = append(params, "messages=["+messageStr+"]")
-
-	sort.Sort(sort.StringSlice(params))
-	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
-	mac := hmac.New(sha1.New, []byte(secretKey))
-	signURL := url
-	mac.Write([]byte(signURL))
-	sign := mac.Sum([]byte(nil))
-	return base64.StdEncoding.EncodeToString(sign)
+func hmacSha256(s, key string) string {
+	hashed := hmac.New(sha256.New, []byte(key))
+	hashed.Write([]byte(s))
+	return string(hashed.Sum(nil))
+}
+
+func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
+	// build canonical request string
+	host := "hunyuan.tencentcloudapi.com"
+	httpRequestMethod := "POST"
+	canonicalURI := "/"
+	canonicalQueryString := ""
+	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
+		"application/json", host, strings.ToLower(adaptor.Action))
+	signedHeaders := "content-type;host;x-tc-action"
+	payload, _ := json.Marshal(req)
+	hashedRequestPayload := sha256hex(string(payload))
+	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+		httpRequestMethod,
+		canonicalURI,
+		canonicalQueryString,
+		canonicalHeaders,
+		signedHeaders,
+		hashedRequestPayload)
+
+	// build string to sign
+	algorithm := "TC3-HMAC-SHA256"
+	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
+	timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
+	t := time.Unix(timestamp, 0).UTC()
+	// must be the format 2006-01-02, ref to package time for more info
+	date := t.Format("2006-01-02")
+	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
+	hashedCanonicalRequest := sha256hex(canonicalRequest)
+	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
+		algorithm,
+		requestTimestamp,
+		credentialScope,
+		hashedCanonicalRequest)
+
+	// sign string
+	secretDate := hmacSha256(date, "TC3"+secKey)
+	secretService := hmacSha256("hunyuan", secretDate)
+	secretKey := hmacSha256("tc3_request", secretService)
+	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
+
+	// build authorization
+	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+		algorithm,
+		secId,
+		credentialScope,
+		signedHeaders,
+		signature)
+	return authorization
 }
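A subtlety in the TC3-HMAC-SHA256 rewrite above: `hmacSha256` returns the raw MAC bytes packed into a Go string, and each step of the derivation chain (date, service, "tc3_request") keys the next HMAC with those raw bytes; only the final signature is hex-encoded. A compressed, self-contained sketch of the same chain, with dummy secret, date, and string-to-sign values:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// hmacRaw mirrors hmacSha256 in the diff: the raw digest travels through the
// chain as a byte string, with no intermediate hex or base64 encoding.
func hmacRaw(msg, key string) string {
	h := hmac.New(sha256.New, []byte(key))
	h.Write([]byte(msg))
	return string(h.Sum(nil))
}

func main() {
	secretKey := "dummy-secret"           // illustrative only
	date := "2024-06-01"                  // UTC date derived from X-TC-Timestamp
	string2sign := "TC3-HMAC-SHA256\n..." // abbreviated; see getTencentSign above

	secretDate := hmacRaw(date, "TC3"+secretKey)
	secretService := hmacRaw("hunyuan", secretDate)
	signingKey := hmacRaw("tc3_request", secretService)
	signature := hex.EncodeToString([]byte(hmacRaw(string2sign, signingKey)))
	fmt.Println(signature)
}
```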
@@ -16,6 +16,11 @@ type Adaptor struct {
 	request *dto.GeneralOpenAIRequest
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -28,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return request, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	// xunfei's request is not http request, so we don't need to do anything here
 	dummyResp := &http.Response{}
@@ -14,6 +14,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -32,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Zhipu(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -30,7 +35,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Zhipu(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
@@ -48,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -9,24 +9,27 @@ import (
 )
 
 type RelayInfo struct {
 	ChannelType       int
 	ChannelId         int
 	TokenId           int
 	UserId            int
 	Group             string
 	TokenUnlimited    bool
 	StartTime         time.Time
 	FirstResponseTime time.Time
-	ApiType           int
-	IsStream          bool
-	RelayMode         int
-	UpstreamModelName string
-	RequestURLPath    string
-	ApiVersion        string
-	PromptTokens      int
-	ApiKey            string
-	Organization      string
-	BaseUrl           string
+	setFirstResponse  bool
+	ApiType           int
+	IsStream          bool
+	RelayMode         int
+	UpstreamModelName string
+	RequestURLPath    string
+	ApiVersion        string
+	PromptTokens      int
+	ApiKey            string
+	Organization      string
+	BaseUrl           string
+	SupportStreamOptions bool
+	ShouldIncludeUsage   bool
 }
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -65,6 +68,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
+	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
+		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
+		info.ChannelType == common.ChannelCloudflare {
+		info.SupportStreamOptions = true
+	}
 	return info
 }
 
@@ -76,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
 	info.IsStream = isStream
 }
 
+func (info *RelayInfo) SetFirstResponseTime() {
+	if !info.setFirstResponse {
+		info.FirstResponseTime = time.Now()
+		info.setFirstResponse = true
+	}
+}
+
 type TaskRelayInfo struct {
 	ChannelType int
 	ChannelId   int
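`SetFirstResponseTime` is called on every scanned stream line (see the stream handler earlier), so the unexported guard flag makes it record only the first call; that is how time-to-first-token can be measured without sprinkling `isFirst` bookkeeping through each handler. A minimal sketch of the same idempotent-setter idea (the `relayInfo` type here is illustrative):

```go
package main

import (
	"fmt"
	"time"
)

type relayInfo struct {
	FirstResponseTime time.Time
	setFirstResponse  bool
}

// SetFirstResponseTime records the timestamp only once, no matter how many
// stream chunks call it afterwards.
func (info *relayInfo) SetFirstResponseTime() {
	if !info.setFirstResponse {
		info.FirstResponseTime = time.Now()
		info.setFirstResponse = true
	}
}

func main() {
	info := &relayInfo{}
	for i := 0; i < 3; i++ { // simulate three stream chunks
		info.SetFirstResponseTime()
		time.Sleep(10 * time.Millisecond)
	}
	fmt.Println("first response at:", info.FirstResponseTime)
}
```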
@@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
 	var textResponse dto.TextResponseWithError
 	err = json.Unmarshal(responseBody, &textResponse)
 	if err != nil {
+		OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
 		return
 	}
 	OpenAIErrorWithStatusCode.Error = textResponse.Error
@@ -20,6 +20,9 @@ const (
 	APITypePerplexity
 	APITypeAws
 	APITypeCohere
+	APITypeDify
+	APITypeJina
+	APITypeCloudflare
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -57,6 +60,12 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeAws
 	case common.ChannelTypeCohere:
 		apiType = APITypeCohere
+	case common.ChannelTypeDify:
+		apiType = APITypeDify
+	case common.ChannelTypeJina:
+		apiType = APITypeJina
+	case common.ChannelCloudflare:
+		apiType = APITypeCloudflare
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false
@@ -32,6 +32,7 @@ const (
 	RelayModeSunoFetch
 	RelayModeSunoFetchByID
 	RelayModeSunoSubmit
+	RelayModeRerank
 )
 
 func Path2RelayMode(path string) int {
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeAudioTranscription
 	} else if strings.HasPrefix(path, "/v1/audio/translations") {
 		relayMode = RelayModeAudioTranslation
+	} else if strings.HasPrefix(path, "/v1/rerank") {
+		relayMode = RelayModeRerank
 	}
 	return relayMode
 }
@@ -73,14 +73,14 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota > 100*preConsumedQuota {
 		// in this case, we do not pre-consume quota
@@ -90,7 +90,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 
@@ -147,7 +147,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N

 	if userQuota-quota < 0 {
-		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}

 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
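The audio and image hunks above all make the same substitution: quota and validation failures that originate inside the gateway now go through `OpenAIErrorWrapperLocal` instead of `OpenAIErrorWrapper`. Combined with the `err.LocalError` check in `ShouldDisableChannel` later in this diff, this keeps locally generated errors from counting against upstream channel health. A sketch of the assumed difference between the two wrappers:

```go
// Assumed shape, inferred from how the wrappers are used in this diff:
// the Local variant flags the error as generated by this gateway so that
// channel auto-disable logic (which checks err.LocalError) can skip it.
func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
	e := OpenAIErrorWrapper(err, code, statusCode)
	e.LocalError = true // never attribute this error to the upstream channel
	return e
}
```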
@@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	originTask := model.GetByMJId(userId, mjId)
 	if originTask == nil {
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
-	} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
-		return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
 	} else { // once the original task's Status is SUCCESS, actions such as UPSCALE and VARIATION are allowed; the original request address must be used for them to be handled correctly
+		if constant.MjActionCheckSuccessEnabled {
+			if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
+			}
+		}
 		channel, err := model.GetChannelById(originTask.ChannelId, true)
 		if err != nil {
 			return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
@@ -500,7 +503,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	}
 	if quota != 0 {
 		tokenName := c.GetString("token_name")
-		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
 		other := make(map[string]interface{})
 		other["model_price"] = modelPrice
 		other["group_ratio"] = groupRatio
@@ -544,7 +547,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	if err != nil {
 		common.SysError("get_channel_null: " + err.Error())
 	}
-	if channel.AutoBan != nil && *channel.AutoBan == 1 {
+	if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
 		model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
 	}
 }
@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {

 	// map model name
 	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
+	//isModelMapped := false
 	if modelMapping != "" && modelMapping != "{}" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		if modelMap[textRequest.Model] != "" {
 			textRequest.Model = modelMap[textRequest.Model]
 			// set upstream model name
-			isModelMapped = true
+			//isModelMapped = true
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
@@ -130,33 +130,38 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return openaiErr
 	}

+	// if StreamOptions is not supported, set StreamOptions to nil
+	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
+		textRequest.StreamOptions = nil
+	} else {
+		// if StreamOptions is supported and the request does not set it, apply the configured default
+		if constant.ForceStreamOption {
+			textRequest.StreamOptions = &dto.StreamOptions{
+				IncludeUsage: true,
+			}
+		}
+	}
+
+	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
+		relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
+	}
+
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
-	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
-		if isModelMapped {
-			jsonStr, err := json.Marshal(textRequest)
-			if err != nil {
-				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
-			}
-			requestBody = bytes.NewBuffer(jsonStr)
-		} else {
-			requestBody = c.Request.Body
-		}
-	} else {
-		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
-		}
-		jsonData, err := json.Marshal(convertedRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonData)
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody = bytes.NewBuffer(jsonData)

 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
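The new block normalizes `stream_options` before dispatch: channels that cannot handle the field (or non-streaming requests) get it stripped, while `ForceStreamOption` forces usage reporting on. The decision can be restated as a standalone function (a sketch; field names follow this diff):

```go
// normalizeStreamOptions restates the branch above as a pure function.
func normalizeStreamOptions(supportsStreamOptions, stream, forceStreamOption bool, opts *dto.StreamOptions) *dto.StreamOptions {
	if !supportsStreamOptions || !stream {
		// the upstream cannot accept stream_options, or there is no stream
		return nil
	}
	if forceStreamOption {
		// always ask the upstream to report usage in the final chunk
		return &dto.StreamOptions{IncludeUsage: true}
	}
	return opts // keep whatever the client sent
}
```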
@@ -182,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 	return nil
 }
@@ -272,7 +277,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
 	}
 }

-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 	modelPrice float64, usePrice bool) {
@@ -281,7 +286,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 	completionTokens := usage.CompletionTokens

 	tokenName := ctx.GetString("token_name")
-	completionRatio := common.GetCompletionRatio(textRequest.Model)
+	completionRatio := common.GetCompletionRatio(modelName)

 	quota := 0
 	if !usePrice {
@@ -307,7 +312,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
+		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
 		//if sensitiveResp != nil {
 		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
@@ -327,13 +333,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}

-	logModel := textRequest.Model
+	logModel := modelName
 	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 		logModel = "gpt-4-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
+		logContent += fmt.Sprintf(",模型 %s", modelName)
 	}
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
+		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)

 	//if quota != 0 {
 	//
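Narrowing `postConsumeQuota` from a full `dto.GeneralOpenAIRequest` to a plain model name is what lets the new rerank path reuse the accounting code. Both call sites in this diff now pass only the model string:

```go
// text relay
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
// rerank relay (new file below)
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
```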
@@ -7,8 +7,11 @@ import (
 	"one-api/relay/channel/aws"
 	"one-api/relay/channel/baidu"
 	"one-api/relay/channel/claude"
+	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
+	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
+	"one-api/relay/channel/jina"
 	"one-api/relay/channel/ollama"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
@@ -53,6 +56,12 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &aws.Adaptor{}
 	case constant.APITypeCohere:
 		return &cohere.Adaptor{}
+	case constant.APITypeDify:
+		return &dify.Adaptor{}
+	case constant.APITypeJina:
+		return &jina.Adaptor{}
+	case constant.APITypeCloudflare:
+		return &cloudflare.Adaptor{}
 	}
 	return nil
 }
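Registering Dify, Jina, and Cloudflare only takes an import and a `case` because every channel hides behind the shared adaptor interface. A self-contained toy version of the factory pattern (names here are illustrative, not the repo's real types):

```go
package main

import "fmt"

// Adaptor mirrors the role of channel.Adaptor: one interface, many channels.
type Adaptor interface{ Name() string }

type difyAdaptor struct{}
type jinaAdaptor struct{}

func (difyAdaptor) Name() string { return "dify" }
func (jinaAdaptor) Name() string { return "jina" }

const (
	apiTypeDify = iota
	apiTypeJina
)

// getAdaptor returns nil for unknown types so the caller can reject the
// request, exactly as GetAdaptor does above.
func getAdaptor(apiType int) Adaptor {
	switch apiType {
	case apiTypeDify:
		return difyAdaptor{}
	case apiTypeJina:
		return jinaAdaptor{}
	}
	return nil
}

func main() {
	if a := getAdaptor(apiTypeJina); a != nil {
		fmt.Println(a.Name()) // jina
	}
}
```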
relay/relay_rerank.go (new file, 104 lines)
@@ -0,0 +1,104 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
+	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	for _, document := range rerankRequest.Documents {
+		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
+		if err == nil {
+			token += tkm
+		}
+	}
+	return token
+}
+
+func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	var rerankRequest *dto.RerankRequest
+	err := common.UnmarshalBodyReusable(c, &rerankRequest)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+	}
+	if rerankRequest.Query == "" {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
+	}
+	if len(rerankRequest.Documents) == 0 {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
+	}
+	relayInfo.UpstreamModelName = rerankRequest.Model
+	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+
+	promptToken := getRerankPromptToken(*rerankRequest)
+	if !success {
+		preConsumedTokens := promptToken
+		modelRatio = common.GetModelRatio(rerankRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	relayInfo.PromptTokens = promptToken
+
+	// pre-consume quota
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.InitRerank(relayInfo, *rerankRequest)
+
+	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody := bytes.NewBuffer(jsonData)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+	if resp != nil {
+		if resp.StatusCode != http.StatusOK {
+			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		// reset status code
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	return nil
+}
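With the route registered below, the helper accepts a Jina/Cohere-style payload whose fields mirror `dto.RerankRequest` (`model`, `query`, `documents`). A hypothetical client call against a local deployment (token, port, and model name are placeholders; lowercase JSON tags are assumed):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := []byte(`{
		"model": "rerank-model-placeholder",
		"query": "which document mentions quota?",
		"documents": ["pricing and quota rules", "deployment guide"]
	}`)
	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/rerank", bytes.NewBuffer(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer sk-your-token")
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	data, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(data))
}
```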
@@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
 	relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 	relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 	relayV1Router.POST("/moderations", controller.Relay)
+	relayV1Router.POST("/rerank", controller.Relay)
 }

 relayMjRouter := router.Group("/mj")
@@ -24,7 +24,7 @@ func EnableChannel(channelId int, channelName string) {
 	notifyRootUser(subject, content)
 }

-func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	if !common.AutomaticDisableChannelEnabled {
 		return false
 	}
@@ -34,9 +34,15 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	if err.LocalError {
 		return false
 	}
-	if err.StatusCode == http.StatusUnauthorized || err.StatusCode == http.StatusForbidden {
+	if err.StatusCode == http.StatusUnauthorized {
 		return true
 	}
+	if err.StatusCode == http.StatusForbidden {
+		switch channelType {
+		case common.ChannelTypeGemini:
+			return true
+		}
+	}
 	switch err.Error.Code {
 	case "invalid_api_key":
 		return true
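The net effect: a 401 still auto-disables any channel, but a 403 now only disables channel types known to use it for dead credentials, currently just Gemini; other upstreams return 403 for transient policy refusals that should not kill the channel. Restated standalone (the Gemini constant value is a stand-in for the real one in `one-api/common`):

```go
const channelTypeGemini = 24 // stand-in value; the real code uses common.ChannelTypeGemini

func shouldDisableOnStatus(statusCode, channelType int) bool {
	if statusCode == 401 {
		return true // unauthorized: credentials are bad for every channel type
	}
	if statusCode == 403 {
		// forbidden: only conclusive for upstreams that use 403 for blocked keys
		return channelType == channelTypeGemini
	}
	return false
}
```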
@@ -68,14 +74,14 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	return false
 }

-func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
 	if !common.AutomaticEnableChannelEnabled {
 		return false
 	}
 	if err != nil {
 		return false
 	}
-	if openAIErr != nil {
+	if openaiWithStatusErr != nil {
 		return false
 	}
 	if status != common.ChannelStatusAutoDisabled {
service/relay.go (new file, 42 lines)
@@ -0,0 +1,42 @@
+package service
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"strings"
+)
+
+func SetEventStreamHeaders(c *gin.Context) {
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
+
+func StringData(c *gin.Context, str string) {
+	str = strings.TrimPrefix(str, "data: ")
+	str = strings.TrimSuffix(str, "\r")
+	c.Render(-1, common.CustomEvent{Data: "data: " + str})
+	c.Writer.Flush()
+}
+
+func ObjectData(c *gin.Context, object interface{}) error {
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	StringData(c, string(jsonData))
+	return nil
+}
+
+func Done(c *gin.Context) {
+	StringData(c, "[DONE]")
+}
+
+func GetResponseID(c *gin.Context) string {
+	logID := c.GetString("X-Oneapi-Request-Id")
+	return fmt.Sprintf("chatcmpl-%s", logID)
+}
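This new file centralizes the SSE plumbing that channel implementations previously duplicated; the single-function file deleted below is folded into it. A hedged sketch of a streaming handler built on these helpers, where the `chunks` slice stands in for whatever stream-response objects a channel produces:

```go
// streamDemo is illustrative only: set SSE headers once, emit each chunk as
// a "data: ..." event, then terminate the stream.
func streamDemo(c *gin.Context, chunks []interface{}) {
	service.SetEventStreamHeaders(c) // text/event-stream, proxy buffering off
	for _, chunk := range chunks {
		if err := service.ObjectData(c, chunk); err != nil {
			common.SysError(err.Error()) // marshalling failed; drop this chunk
		}
	}
	service.Done(c) // sends "data: [DONE]"
}
```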
@@ -1,11 +0,0 @@
-package service
-
-import "github.com/gin-gonic/gin"
-
-func SetEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
@@ -24,3 +24,19 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err
 }
+
+func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
+	return &dto.ChatCompletionsStreamResponse{
+		Id:                id,
+		Object:            "chat.completion.chunk",
+		Created:           createAt,
+		Model:             model,
+		SystemFingerprint: nil,
+		Choices:           make([]dto.ChatCompletionsStreamResponseChoice, 0),
+		Usage:             &usage,
+	}
+}
+
+func ValidUsage(usage *dto.Usage) bool {
+	return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
+}
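`GenerateFinalUsageResponse` builds the OpenAI-style terminal chunk (empty `choices`, populated `usage`) that streaming clients expect when they set `stream_options.include_usage`. A sketch of how a channel might emit it, tying this helper to `relayInfo.ShouldIncludeUsage` from the text-relay hunk above (the placement inside the real response loop is assumed):

```go
// At end-of-stream: append the usage-only chunk if the client asked for it.
if relayInfo.ShouldIncludeUsage && service.ValidUsage(usage) {
	final := service.GenerateFinalUsageResponse(
		service.GetResponseID(c), time.Now().Unix(), relayInfo.UpstreamModelName, *usage)
	_ = service.ObjectData(c, final)
}
service.Done(c)
```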
@@ -550,7 +550,7 @@ const ChannelsTable = () => {
       );
       const { success, message, data } = res.data;
       if (success) {
-        setChannels(data);
+        setChannelFormat(data);
         setActivePage(1);
       } else {
         showError(message);
@@ -42,6 +42,7 @@ const OperationSetting = () => {
     MjAccountFilterEnabled: false,
     MjModeClearEnabled: false,
     MjForwardUrlEnabled: false,
+    MjActionCheckSuccessEnabled: false,
     DrawingEnabled: false,
     DataExportEnabled: false,
     DataExportDefaultTime: 'hour',
@@ -99,11 +99,14 @@ export const CHANNEL_OPTIONS = [
     color: 'orange',
     label: 'Google PaLM2',
   },
+  { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
   { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
   { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
   { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
   { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
   { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
+  { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
+  { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
   {
     key: 22,
@@ -153,8 +153,8 @@ export function renderModelPrice(
   let inputRatioPrice = modelRatio * 2.0;
   let completionRatioPrice = modelRatio * 2.0 * completionRatio;
   let price =
-    (inputTokens / 1000000) * inputRatioPrice +
-    (completionTokens / 1000000) * completionRatioPrice;
+    (inputTokens / 1000000) * inputRatioPrice * groupRatio +
+    (completionTokens / 1000000) * completionRatioPrice * groupRatio;
   return (
     <>
       <article>
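The fix multiplies the displayed price by the group ratio, matching how quota is actually charged. Worked example with assumed numbers: `modelRatio = 1` gives a base input price of $2 per 1M tokens; with `completionRatio = 3`, `groupRatio = 0.5`, 100k input tokens and 50k completion tokens, the display becomes (0.1 × 2 × 0.5) + (0.05 × 6 × 0.5) = $0.25, where the old code showed $0.50 regardless of the user's group.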
@@ -605,6 +605,24 @@ const EditChannel = (props) => {
           />
         </>
       )}
+      {inputs.type === 39 && (
+        <>
+          <div style={{ marginTop: 10 }}>
+            <Typography.Text strong>Account ID:</Typography.Text>
+          </div>
+          <Input
+            name='other'
+            placeholder={
+              '请输入Account ID,例如:d6b5da8hk1awo8nap34ube6gh'
+            }
+            onChange={(value) => {
+              handleInputChange('other', value);
+            }}
+            value={inputs.other}
+            autoComplete='new-password'
+          />
+        </>
+      )}
       <div style={{ marginTop: 10 }}>
         <Typography.Text strong>模型:</Typography.Text>
       </div>
@@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
     MjAccountFilterEnabled: false,
     MjForwardUrlEnabled: false,
     MjModeClearEnabled: false,
+    MjActionCheckSuccessEnabled: false,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -156,6 +157,25 @@ export default function SettingsDrawing(props) {
                 }
               />
             </Col>
+            <Col span={8}>
+              <Form.Switch
+                field={'MjActionCheckSuccessEnabled'}
+                label={
+                  <>
+                    检测必须等待绘图成功才能进行放大等操作
+                  </>
+                }
+                size='large'
+                checkedText='|'
+                uncheckedText='〇'
+                onChange={(value) =>
+                  setInputs({
+                    ...inputs,
+                    MjActionCheckSuccessEnabled: value,
+                  })
+                }
+              />
+            </Col>
           </Row>
           <Row>
             <Button size='large' onClick={onSubmit}>