Mirror of https://github.com/linux-do/new-api.git (synced 2025-11-17 19:13:42 +08:00)

Compare commits: v0.2.5.0-a ... v0.2.7.2-a (74 commits)
Commit SHAs in this range (the author and date columns were not captured by the mirror):

e3b83f886f, 4d0d18931d, 11856ab39e, eb9b4b07ad, 963985e76c, a3880d558a, ba27da9e2c, e262a9bd2c, 9bbe8e7d1b, e2b9061650, 220ab412e2, 7029065892, 0f687aab9a, 5e936b3923, d55cb35c1c, 5be4cbcaaf, e67aa370bc, 7b36a2b885, c88f3741e6, 4e7e206290, 579fc8129e, f55f63f412, 0526c85732, b75134ece4, a075598757, a984daa503, 90abe7f27d, bb313eb26f, 02545e4856, 49cec50908, 4f6710e50c, 03b130f2b5, 45b9de9df9, e062cf32e3, 52debe7572, df6502733c, 9896ba0a64, e8b93ed6ec, b0e234e8f5, 20d71711d3, 4246c4cdc1, 1e536ee7d9, 8a730cfe12, 3ed4f2f0a9, bec18ed82d, bd9bf4b732, 1735e093db, 8af4e28f75, afe02c6aa5, e0ed59bfe3, bd7222118a, cf3d894195, 7011083201, 752048dfb4, eb382d28ab, a9e1078bca, 6c5b3b51b0, d306aea9e5, d4578e28b3, 584eefec3e, a7e3168c17, d767ae04ff, 402a415c79, 55c28b2f98, fc6ae6bf34, a9b978528e, d1778bb20a, 37a0930db4, 1117112225, f2654692e8, c834289f2c, bc649ddaa7, c838beba3d, 6b07e6fb97
.github/workflows/docker-image-arm64.yml (vendored, 1 change)
```diff
@@ -4,6 +4,7 @@ on:
   push:
     tags:
       - '*'
+      - '!*-alpha*'
   workflow_dispatch:
     inputs:
       name:
```
Midjourney.md (path inferred; the file header was not captured by the mirror)

```diff
@@ -2,6 +2,21 @@
 
 **Introduction**: Midjourney Proxy API documentation
 
+## Endpoint List
+The following endpoints are supported:
++ [x] /mj/submit/imagine
++ [x] /mj/submit/change
++ [x] /mj/submit/blend
++ [x] /mj/submit/describe
++ [x] /mj/image/{id} (fetch images via this endpoint; **you must set the server address in system settings!!**)
++ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is relayed through One API)
++ [x] /task/list-by-condition
++ [x] /mj/submit/action (only supported by midjourney-proxy-plus; same below)
++ [x] /mj/submit/modal
++ [x] /mj/submit/shorten
++ [x] /mj/task/{id}/image-seed
++ [x] /mj/insight-face/swap (InsightFace)
+
 ## Model List
 
 ### Supported by midjourney-proxy
```
README.md (35 changes)
```diff
@@ -5,8 +5,6 @@
 > This is an open-source project, a fork developed on top of [One API](https://github.com/songquanpeng/one-api); thanks to the original author for the selfless contribution.
 > Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**; illegal use is prohibited.
-
-
 > [!WARNING]
 > This project is for personal learning; stability is not guaranteed and no technical support is provided. Users must comply with OpenAI's terms of use and with laws and regulations, and must not use it for illegal purposes.
 > Per the [Interim Measures for the Administration of Generative AI Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI services to the public in mainland China.
```
```diff
@@ -18,19 +16,7 @@
 The main changes in this fork are:
 
 1. A brand-new UI (some pages are still pending updates)
-2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md); supported endpoints:
-+ [x] /mj/submit/imagine
-+ [x] /mj/submit/change
-+ [x] /mj/submit/blend
-+ [x] /mj/submit/describe
-+ [x] /mj/image/{id} (fetch images via this endpoint; **you must set the server address in system settings!!**)
-+ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is relayed through One API)
-+ [x] /task/list-by-condition
-+ [x] /mj/submit/action (only supported by midjourney-proxy-plus; same below)
-+ [x] /mj/submit/modal
-+ [x] /mj/submit/shorten
-+ [x] /mj/task/{id}/image-seed
-+ [x] /mj/insight-face/swap (InsightFace)
+2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md)
 3. Online top-up support, configurable in system settings; currently supported payment gateway:
 + [x] 易支付 (Epay)
 4. Query usage quota by key:
```
```diff
@@ -47,22 +33,21 @@
 2. Send the /setdomain command to [@Botfather](https://t.me/botfather)
 3. Select your bot, then enter http(s)://your-site-address/login
 4. The Telegram Bot name is the bot username without the leading @
-13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md); supported endpoints:
-+ [x] /suno/submit/music
-+ [x] /suno/submit/lyrics
-+ [x] /suno/fetch
-+ [x] /suno/fetch/:id
+13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md)
+14. Rerank model support, currently compatible only with Cohere and Jina; can be used from Dify, [integration docs](Rerank.md)
 
 ## Model Support
 This version additionally supports the following models:
 1. Third-party model **gps** (gpt-4-gizmo-*)
 2. Zhipu glm-4v, glm-4v image recognition
-3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
+3. Anthropic Claude 3
 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file); when adding the channel the key can be anything, and the default request address is [http://localhost:11434](http://localhost:11434), which can be changed in the channel
 5. The [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration docs](Midjourney.md)
 6. [零一万物](https://platform.lingyiwanwu.com/)
 7. Custom channels, with support for entering a full request URL
 8. The [Suno API](https://github.com/Suno-API/Suno-API), [integration docs](Suno.md)
+9. Rerank models, currently supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [integration docs](Rerank.md)
+10. Dify
 
 You can add the custom model gpt-4-gizmo-* in a channel; it is not an official OpenAI model but a third-party one, and it cannot be called with an official OpenAI key.
```
````diff
@@ -85,8 +70,14 @@
 ```
 This converts 400 errors into 500 errors so that the request can be retried.
 
 ## Extra configuration compared to the original One API
 - `STREAMING_TIMEOUT`: timeout for a single streamed reply, default 30 seconds
+- `DIFY_DEBUG`: whether the Dify channel outputs workflow and node information to the client, default `true`
+- `FORCE_STREAM_OPTION`: whether to override the client's stream_options parameter and request stream-mode usage from the upstream, default `true`
 ## Deployment
+### Requirements
+- Local database (default): SQLite (Docker deployments use SQLite by default; the `/data` directory must be mounted on the host)
+- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6
 ### Deploying with Docker
 ```shell
 # Deployment command using SQLite:
````
Rerank.md (new file, 62 lines):

# Rerank API Documentation

**Introduction**: Rerank API documentation

## Integrating with Dify
Select Jina as the model provider and fill in the model information as required to connect from Dify.

## Request

POST /v1/rerank

Request:

```json
{
    "model": "rerank-multilingual-v3.0",
    "query": "What is the capital of the United States?",
    "top_n": 3,
    "documents": [
        "Carson City is the capital city of the American state of Nevada.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
        "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
        "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
        "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
    ]
}
```

Response:

```json
{
    "results": [
        {
            "document": {
                "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
            },
            "index": 2,
            "relevance_score": 0.9999702
        },
        {
            "document": {
                "text": "Carson City is the capital city of the American state of Nevada."
            },
            "index": 0,
            "relevance_score": 0.67800725
        },
        {
            "document": {
                "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
            },
            "index": 3,
            "relevance_score": 0.02800752
        }
    ],
    "usage": {
        "prompt_tokens": 158,
        "completion_tokens": 0,
        "total_tokens": 158
    }
}
```
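For readers integrating against this endpoint, the sketch below shows one way to issue the request above from Go. It is illustrative only: the base URL `http://localhost:3000` and the bearer token are placeholders, not values taken from this changeset; the JSON shape mirrors the Rerank.md example.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// rerankRequest mirrors the JSON request shown in Rerank.md above.
type rerankRequest struct {
	Model     string   `json:"model"`
	Query     string   `json:"query"`
	TopN      int      `json:"top_n"`
	Documents []string `json:"documents"`
}

func main() {
	payload, _ := json.Marshal(rerankRequest{
		Model: "rerank-multilingual-v3.0",
		Query: "What is the capital of the United States?",
		TopN:  3,
		Documents: []string{
			"Carson City is the capital city of the American state of Nevada.",
			"Washington, D.C. is the capital of the United States.",
		},
	})
	req, err := http.NewRequest("POST", "http://localhost:3000/v1/rerank", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer sk-your-token") // placeholder token
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode loosely; the full response schema is shown above.
	var out map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(out["results"])
}
```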
Suno.md (7 changes)
```diff
@@ -2,6 +2,13 @@
 
 **Introduction**: Suno API documentation
 
+## Endpoint List
+The following endpoints are supported:
++ [x] /suno/submit/music
++ [x] /suno/submit/lyrics
++ [x] /suno/fetch
++ [x] /suno/fetch/:id
+
 ## Model List
 
 ### Supported by Suno API
```
common/constants.go (path inferred; the mirror dropped this file header)

```diff
@@ -103,14 +103,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 var RequestInterval = time.Duration(requestInterval) * time.Second
 
-var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
+var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
 
 var BatchUpdateEnabled = false
-var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
+var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
 
-var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
+var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
 
-var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
+var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
 
 const (
 	RequestIdKey = "X-Oneapi-Request-Id"
@@ -133,10 +133,10 @@ var (
 // All durations are in seconds
 // Shouldn't be larger than RateLimitKeyExpirationDuration
 var (
-	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
+	GlobalApiRateLimitNum            = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 	GlobalApiRateLimitDuration int64 = 3 * 60
 
-	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
+	GlobalWebRateLimitNum            = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 	GlobalWebRateLimitDuration int64 = 3 * 60
 
 	UploadRateLimitNum = 10
@@ -210,36 +210,39 @@ const (
 	ChannelTypeCohere  = 34
 	ChannelTypeMiniMax = 35
 	ChannelTypeSunoAPI = 36
+	ChannelTypeDify    = 37
+	ChannelTypeJina    = 38
+	ChannelCloudflare  = 39
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
 
 var ChannelBaseURLs = []string{
 	"",                                // 0
 	"https://api.openai.com",          // 1
 	"https://oa.api2d.net",            // 2
 	"",                                // 3
 	"http://localhost:11434",          // 4
 	"https://api.openai-sb.com",       // 5
 	"https://api.openaimax.com",       // 6
 	"https://api.ohmygpt.com",         // 7
 	"",                                // 8
 	"https://api.caipacity.com",       // 9
 	"https://api.aiproxy.io",          // 10
 	"",                                // 11
 	"https://api.api2gpt.com",         // 12
 	"https://api.aigc2d.com",          // 13
 	"https://api.anthropic.com",       // 14
 	"https://aip.baidubce.com",        // 15
 	"https://open.bigmodel.cn",        // 16
 	"https://dashscope.aliyuncs.com",  // 17
 	"",                                // 18
 	"https://ai.360.cn",               // 19
 	"https://openrouter.ai/api",       // 20
 	"https://api.aiproxy.io",          // 21
 	"https://fastgpt.run/api/openapi", // 22
-	"https://hunyuan.cloud.tencent.com", //23
+	"https://hunyuan.tencentcloudapi.com", //23
 	"https://generativelanguage.googleapis.com", //24
 	"https://api.moonshot.cn", //25
 	"https://open.bigmodel.cn", //26
@@ -253,4 +256,7 @@ var ChannelBaseURLs = []string{
 	"https://api.cohere.ai",    //34
 	"https://api.minimax.chat", //35
 	"",                         //36
+	"",                         //37
+	"https://api.jina.ai",      //38
+	"https://api.cloudflare.com", //39
 }
```
common/env.go (new file, 38 lines):

```go
package common

import (
	"fmt"
	"os"
	"strconv"
)

func GetEnvOrDefault(env string, defaultValue int) int {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	num, err := strconv.Atoi(os.Getenv(env))
	if err != nil {
		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
		return defaultValue
	}
	return num
}

func GetEnvOrDefaultString(env string, defaultValue string) string {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	return os.Getenv(env)
}

func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	b, err := strconv.ParseBool(os.Getenv(env))
	if err != nil {
		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
		return defaultValue
	}
	return b
}
```
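A small standalone exercise of this helper pattern (not part of the diff). The variable name `MY_TIMEOUT` is made up, and the fallback behavior shown follows directly from the code above: an unparseable value logs an error and returns the default.

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// getEnvOrDefault is a local copy of the helper above so this sketch runs standalone.
func getEnvOrDefault(env string, defaultValue int) int {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	num, err := strconv.Atoi(os.Getenv(env))
	if err != nil {
		fmt.Printf("failed to parse %s: %s, using default value: %d\n", env, err.Error(), defaultValue)
		return defaultValue
	}
	return num
}

func main() {
	os.Setenv("MY_TIMEOUT", "not-a-number")        // hypothetical variable name
	fmt.Println(getEnvOrDefault("MY_TIMEOUT", 30)) // logs a parse error, prints 30
	os.Setenv("MY_TIMEOUT", "45")
	fmt.Println(getEnvOrDefault("MY_TIMEOUT", 30)) // prints 45
}
```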
common (goroutine helpers; exact file path not captured by the mirror)

```diff
@@ -3,6 +3,7 @@ package common
 import (
 	"fmt"
 	"runtime/debug"
+	"time"
 )
 
 func SafeGoroutine(f func()) {
@@ -45,3 +46,21 @@ func SafeSendString(ch chan string, value string) (closed bool) {
 	// If the code reaches here, then the channel was not closed.
 	return false
 }
+
+// SafeSendStringTimeout sends value on ch; returns true if sent, false on timeout or a closed channel.
+func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
+	defer func() {
+		// Recover from panic if one occurred. A panic would mean the channel was closed.
+		if recover() != nil {
+			closed = false
+		}
+	}()
+
+	// This will panic if the channel is closed.
+	select {
+	case ch <- value:
+		return true
+	case <-time.After(time.Duration(timeout) * time.Second):
+		return false
+	}
+}
```
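An illustrative caller, not from the diff, showing the intent of the new helper: a producer waits a bounded time for a slow consumer instead of blocking forever. The wrapper function and the 30-second figure (mirroring the `STREAMING_TIMEOUT` default) are assumptions for the sketch; it compiles within the same `common` package.

```go
package common

// exampleStreamSend is a hypothetical wrapper showing how a streaming relay
// might use SafeSendStringTimeout: block at most 30 seconds, then report
// failure instead of hanging on a stalled client.
func exampleStreamSend(dataChan chan string, line string) bool {
	return SafeSendStringTimeout(dataChan, line, 30)
}
```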
common/group-ratio.go (path inferred)

```diff
@@ -1,6 +1,8 @@
 package common
 
-import "encoding/json"
+import (
+	"encoding/json"
+)
 
 var GroupRatio = map[string]float64{
 	"default": 1,
```
common/model-ratio.go (path inferred)

```diff
@@ -72,24 +72,29 @@ var defaultModelRatio = map[string]float64{
 	"text-search-ada-doc-001": 10,
 	"text-moderation-stable":  0.1,
 	"text-moderation-latest":  0.1,
 	"claude-instant-1":           0.4,   // $0.8 / 1M tokens
 	"claude-2.0":                 4,     // $8 / 1M tokens
 	"claude-2.1":                 4,     // $8 / 1M tokens
 	"claude-3-haiku-20240307":    0.125, // $0.25 / 1M tokens
 	"claude-3-sonnet-20240229":   1.5,   // $3 / 1M tokens
+	"claude-3-5-sonnet-20240620": 1.5,
 	"claude-3-opus-20240229":     7.5,   // $15 / 1M tokens
-	"ERNIE-Bot":          0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
-	"ERNIE-Bot-turbo":    0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
-	"ERNIE-Bot-4":        8.572,  // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K
-	"ERNIE-4.0-8K":       8.572,  // ¥0.12 / 1k tokens
-	"ERNIE-3.5-8K":       0.8572, // ¥0.012 / 1k tokens
-	"ERNIE-Speed-8K":     0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Speed-128K":   0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Lite-8K":      0.2143, // ¥0.003 / 1k tokens
-	"ERNIE-Tiny-8K":      0.0715, // ¥0.001 / 1k tokens
-	"ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens
-	"ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens
-	"Embedding-V1":       0.1429, // ¥0.002 / 1k tokens
+	"ERNIE-4.0-8K":       0.120 * RMB,
+	"ERNIE-3.5-8K":       0.012 * RMB,
+	"ERNIE-3.5-8K-0205":  0.024 * RMB,
+	"ERNIE-3.5-8K-1222":  0.012 * RMB,
+	"ERNIE-Bot-8K":       0.024 * RMB,
+	"ERNIE-3.5-4K-0205":  0.012 * RMB,
+	"ERNIE-Speed-8K":     0.004 * RMB,
+	"ERNIE-Speed-128K":   0.004 * RMB,
+	"ERNIE-Lite-8K-0922": 0.008 * RMB,
+	"ERNIE-Lite-8K-0308": 0.003 * RMB,
+	"ERNIE-Tiny-8K":      0.001 * RMB,
+	"BLOOMZ-7B":          0.004 * RMB,
+	"Embedding-V1":       0.002 * RMB,
+	"bge-large-zh":       0.002 * RMB,
+	"bge-large-en":       0.002 * RMB,
+	"tao-8k":             0.002 * RMB,
 	"PaLM-2":            1,
 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -100,12 +105,13 @@ var defaultModelRatio = map[string]float64{
 	"gemini-1.0-pro-latest":        1,
 	"gemini-1.0-pro-vision-latest": 1,
 	"gemini-ultra":                 1,
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
 	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens
 	"glm-4":         7.143,  // ¥0.1 / 1k tokens
-	"glm-4v":        7.143,  // ¥0.1 / 1k tokens
+	"glm-4v":         0.05 * RMB, // ¥0.05 / 1k tokens
+	"glm-4-alltools": 0.1 * RMB,  // ¥0.1 / 1k tokens
 	"glm-3-turbo":   0.3572,
 	"qwen-turbo":    0.8572, // ¥0.012 / 1k tokens
 	"qwen-plus":     10,     // ¥0.14 / 1k tokens
@@ -114,6 +120,7 @@ var defaultModelRatio = map[string]float64{
 	"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
 	"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
 	"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
+	"SparkDesk-v4.0": 1.2858,
 	"360GPT_S2_V9":   0.8572, // ¥0.012 / 1k tokens
 	"360gpt-turbo":   0.0858, // ¥0.0012 / 1k tokens
 	"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
@@ -152,6 +159,8 @@ var defaultModelRatio = map[string]float64{
 }
 
 var defaultModelPrice = map[string]float64{
+	"suno_music":    0.1,
+	"suno_lyrics":   0.01,
 	"dall-e-3":      0.04,
 	"gpt-4-gizmo-*": 0.1,
 	"mj_imagine":    0.1,
```
common/topup-ratio.go (path inferred)

```diff
@@ -1,6 +1,8 @@
 package common
 
-import "encoding/json"
+import (
+	"encoding/json"
+)
 
 var TopupGroupRatio = map[string]float64{
 	"default": 1,
```
common/utils.go (path inferred)

```diff
@@ -8,7 +8,6 @@ import (
 	"log"
 	"math/rand"
 	"net"
 	"os"
-	"os/exec"
 	"runtime"
 	"strconv"
@@ -191,25 +190,6 @@ func Max(a int, b int) int {
 	}
 }
 
-func GetOrDefault(env string, defaultValue int) int {
-	if env == "" || os.Getenv(env) == "" {
-		return defaultValue
-	}
-	num, err := strconv.Atoi(os.Getenv(env))
-	if err != nil {
-		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
-		return defaultValue
-	}
-	return num
-}
-
-func GetOrDefaultString(env string, defaultValue string) string {
-	if env == "" || os.Getenv(env) == "" {
-		return defaultValue
-	}
-	return os.Getenv(env)
-}
-
 func MessageWithRequestId(message string, id string) string {
 	return fmt.Sprintf("%s (request id: %s)", message, id)
 }
```
constant/env.go (new file, 11 lines):

```go
package constant

import (
	"one-api/common"
)

var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)

// ForceStreamOption overrides the request parameters to force usage info to be returned
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
```
constant/midjourney.go (path inferred)

```diff
@@ -4,6 +4,7 @@ var MjNotifyEnabled = false
 var MjAccountFilterEnabled = false
 var MjModeClearEnabled = false
 var MjForwardUrlEnabled = true
+var MjActionCheckSuccessEnabled = true
 
 const (
 	MjErrorUnknown = 5
```
controller/channel-test.go (path inferred)

```diff
@@ -6,11 +6,13 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/middleware"
 	"one-api/model"
 	"one-api/relay"
 	relaycommon "one-api/relay/common"
@@ -23,7 +25,8 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 	tik := time.Now()
 	if channel.Type == common.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
 	}
@@ -38,34 +41,16 @@ func testChannel(...)
 		Body:   nil,
 		Header: make(http.Header),
 	}
-	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
-	c.Request.Header.Set("Content-Type", "application/json")
-	c.Set("channel", channel.Type)
-	c.Set("base_url", channel.GetBaseURL())
-	switch channel.Type {
-	case common.ChannelTypeAzure:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeXunfei:
-		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
-	case common.ChannelTypeGemini:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeAli:
-		c.Set("plugin", channel.Other)
-	}
-
-	meta := relaycommon.GenRelayInfo(c)
-	apiType, _ := constant.ChannelType2APIType(channel.Type)
-	adaptor := relay.GetAdaptor(apiType)
-	if adaptor == nil {
-		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
-	}
 	if testModel == "" {
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
 		} else {
-			testModel = adaptor.GetModelList()[0]
+			if len(channel.GetModels()) > 0 {
+				testModel = channel.GetModels()[0]
+			} else {
+				testModel = "gpt-3.5-turbo"
+			}
 		}
 	} else {
 		modelMapping := *channel.ModelMapping
@@ -73,8 +58,7 @@ func testChannel(...)
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
-			return err, &openaiErr
+			return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[testModel] != "" {
 			testModel = modelMap[testModel]
@@ -82,6 +66,20 @@ func testChannel(...)
 		}
 	}
 
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+
+	middleware.SetupContextForSelectedChannel(c, channel, testModel)
+
+	meta := relaycommon.GenRelayInfo(c)
+	apiType, _ := constant.ChannelType2APIType(channel.Type)
+	adaptor := relay.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+	}
+
 	request := buildTestRequest()
 	request.Model = testModel
 	meta.UpstreamModelName = testModel
@@ -89,7 +87,7 @@ func testChannel(...)
 
 	adaptor.Init(meta, *request)
 
-	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+	convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
 	if err != nil {
 		return err, nil
 	}
@@ -105,21 +103,39 @@ func testChannel(...)
 	}
 	if resp != nil && resp.StatusCode != http.StatusOK {
 		err := relaycommon.RelayErrorHandler(resp)
-		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
 	}
 	usage, respErr := adaptor.DoResponse(c, resp, meta)
 	if respErr != nil {
-		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+		return fmt.Errorf("%s", respErr.Error.Message), respErr
 	}
 	if usage == nil {
 		return errors.New("usage is nil"), nil
 	}
 	result := w.Result()
 	// print result.Body
 	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 		return err, nil
 	}
+	modelPrice, usePrice := common.GetModelPrice(testModel, false)
+	modelRatio := common.GetModelRatio(testModel)
+	completionRatio := common.GetCompletionRatio(testModel)
+	ratio := modelRatio
+	quota := 0
+	if !usePrice {
+		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
+		quota = int(math.Round(float64(quota) * ratio))
+		if ratio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(modelPrice * common.QuotaPerUnit)
+	}
+	tok := time.Now()
+	milliseconds := tok.Sub(tik).Milliseconds()
+	consumedTime := float64(milliseconds) / 1000.0
+	other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
+	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }
```
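The quota arithmetic introduced above can be checked by hand. Here is a standalone sketch with made-up token counts and ratios (none of these numbers come from the diff):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	promptTokens, completionTokens := 12, 48
	completionRatio, modelRatio := 2.0, 1.5

	// Same arithmetic as testChannel above: weight the completion tokens,
	// add the prompt tokens, then scale the total by the model ratio.
	quota := promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
	quota = int(math.Round(float64(quota) * modelRatio))
	if modelRatio != 0 && quota <= 0 {
		quota = 1 // never bill zero for a non-free model
	}
	fmt.Println(quota) // (12 + 96) * 1.5 = 162
}
```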
```diff
@@ -140,7 +156,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
 }
 
 func TestChannel(c *gin.Context) {
-	id, err := strconv.Atoi(c.Param("id"))
+	channelId, err := strconv.Atoi(c.Param("id"))
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -148,7 +164,7 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	channel, err := model.GetChannelById(id, true)
+	channel, err := model.GetChannelById(channelId, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
```
```diff
@@ -205,7 +221,7 @@ func testAllChannels(notify bool) error {
 	for _, channel := range channels {
 		isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 		tik := time.Now()
-		err, openaiErr := testChannel(channel, "")
+		err, openaiWithStatusErr := testChannel(channel, "")
 		tok := time.Now()
 		milliseconds := tok.Sub(tik).Milliseconds()
 
@@ -214,25 +230,29 @@ func testAllChannels(notify bool) error {
 			err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 			ban = true
 		}
-		if openaiErr != nil {
-			err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
-			ban = true
+
+		// request error disables the channel
+		if openaiWithStatusErr != nil {
+			oaiErr := openaiWithStatusErr.Error
+			err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
+			ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 		}
+
 		// parse *int to bool
 		if channel.AutoBan != nil && *channel.AutoBan == 0 {
 			ban = false
 		}
-		openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
-			StatusCode: -1,
-			Error:      *openaiErr,
-			LocalError: false,
-		}
-		if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
+
+		// disable channel
+		if ban && isChannelEnabled {
 			service.DisableChannel(channel.Id, channel.Name, err.Error())
 		}
-		if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
+
+		// enable channel
+		if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
 			service.EnableChannel(channel.Id, channel.Name)
 		}
 		channel.UpdateResponseTime(milliseconds)
 		time.Sleep(common.RequestInterval)
 	}
```
controller/midjourney.go (path inferred)

```diff
@@ -146,28 +146,26 @@ func UpdateMidjourneyTaskBulk() {
 					buttonStr, _ := json.Marshal(responseItem.Buttons)
 					task.Buttons = string(buttonStr)
 				}
+
+				shouldReturnQuota := false
 				if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
 					common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
 					task.Progress = "100%"
 					err = model.CacheUpdateUserQuota(task.UserId)
 					if err != nil {
 						common.LogError(ctx, "error update user quota cache: "+err.Error())
 					} else {
-						quota := task.Quota
-						if quota != 0 {
-							err = model.IncreaseUserQuota(task.UserId, quota)
-							if err != nil {
-								common.LogError(ctx, "fail to increase user quota: "+err.Error())
-							}
-							logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
-							model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-						}
+						if task.Quota != 0 {
+							shouldReturnQuota = true
+						}
 					}
 				}
 				err = task.Update()
 				if err != nil {
 					common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+				} else {
+					if shouldReturnQuota {
+						err = model.IncreaseUserQuota(task.UserId, task.Quota)
+						if err != nil {
+							common.LogError(ctx, "fail to increase user quota: "+err.Error())
+						}
+						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+					}
 				}
 			}
 		}
```
controller/relay.go (path inferred)

```diff
@@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		fallthrough
 	case relayconstant.RelayModeAudioTranscription:
 		err = relay.AudioHelper(c, relayMode)
+	case relayconstant.RelayModeRerank:
+		err = relay.RerankHelper(c, relayMode)
 	default:
 		err = relay.TextHelper(c)
 	}
@@ -40,12 +42,13 @@ func Relay(c *gin.Context) {
 	retryTimes := common.RetryTimes
 	requestId := c.GetString(common.RequestIdKey)
 	channelId := c.GetInt("channel_id")
+	channelType := c.GetInt("channel_type")
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
 	openaiErr := relayHandler(c, relayMode)
 	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	if openaiErr != nil {
-		go processChannelError(c, channelId, openaiErr)
+		go processChannelError(c, channelId, channelType, openaiErr)
 	} else {
 		retryTimes = 0
 	}
@@ -66,7 +69,7 @@ func Relay(c *gin.Context) {
 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		openaiErr = relayHandler(c, relayMode)
 		if openaiErr != nil {
-			go processChannelError(c, channelId, openaiErr)
+			go processChannelError(c, channelId, channel.Type, openaiErr)
 		}
 	}
 	useChannel := c.GetStringSlice("use_channel")
@@ -125,10 +128,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode) bool {
 	return true
 }
 
-func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
+func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
 	autoBan := c.GetBool("auto_ban")
 	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
-	if service.ShouldDisableChannel(err) && autoBan {
+	if service.ShouldDisableChannel(channelType, err) && autoBan {
 		channelName := c.GetString("channel_name")
 		service.DisableChannel(channelId, channelName, err.Error.Message)
 	}
```
controller/topup.go (path inferred)

```diff
@@ -5,11 +5,10 @@ import (
 	"github.com/Calcium-Ion/go-epay/epay"
 	"github.com/gin-gonic/gin"
 	"github.com/samber/lo"
-	"one-api/constant"
-
 	"log"
 	"net/url"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/model"
 	"one-api/service"
 	"strconv"
```
controller/model.go (path inferred)

```diff
@@ -24,14 +24,3 @@ type OpenAIModels struct {
 	Root   string  `json:"root"`
 	Parent *string `json:"parent"`
 }
-
-type ModelPricing struct {
-	Available       bool     `json:"available"`
-	ModelName       string   `json:"model_name"`
-	QuotaType       int      `json:"quota_type"`
-	ModelRatio      float64  `json:"model_ratio"`
-	ModelPrice      float64  `json:"model_price"`
-	OwnerBy         string   `json:"owner_by"`
-	CompletionRatio float64  `json:"completion_ratio"`
-	EnableGroup     []string `json:"enable_group,omitempty"`
-}
```
dto/rerank.go (new file, 19 lines):

```go
package dto

type RerankRequest struct {
	Documents []any  `json:"documents"`
	Query     string `json:"query"`
	Model     string `json:"model"`
	TopN      int    `json:"top_n"`
}

type RerankResponseDocument struct {
	Document       any     `json:"document"`
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
}

type RerankResponse struct {
	Results []RerankResponseDocument `json:"results"`
	Usage   Usage                    `json:"usage"`
}
```
dto (GeneralOpenAIRequest; path inferred)

```diff
@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
 	Messages      []Message      `json:"messages,omitempty"`
 	Prompt        any            `json:"prompt,omitempty"`
 	Stream        bool           `json:"stream,omitempty"`
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 	MaxTokens     uint           `json:"max_tokens,omitempty"`
 	Temperature   float64        `json:"temperature,omitempty"`
 	TopP          float64        `json:"top_p,omitempty"`
@@ -43,8 +44,12 @@ type OpenAIFunction struct {
 	Parameters any `json:"parameters,omitempty"`
 }
 
-func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
-	return int64(r.MaxTokens)
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) GetMaxTokens() int {
+	return int(r.MaxTokens)
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
```
dto (stream responses; path inferred)

```diff
@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
-func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
-	return c.Content == nil && len(c.ToolCalls) == 0
-}
-
 func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
 	c.Content = &s
 }
@@ -102,10 +98,23 @@ type ChatCompletionsStreamResponse struct {
 	Model             string                                `json:"model"`
 	SystemFingerprint *string                               `json:"system_fingerprint"`
 	Choices           []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage             *Usage                                `json:"usage"`
 }
 
+func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
+	if c.SystemFingerprint == nil {
+		return ""
+	}
+	return *c.SystemFingerprint
+}
+
+func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
+	c.SystemFingerprint = &s
+}
+
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *Usage                                `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {
```
middleware/distributor.go (path inferred)

```diff
@@ -178,6 +178,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
 	c.Set("channel", channel.Type)
 	c.Set("channel_id", channel.Id)
 	c.Set("channel_name", channel.Name)
+	c.Set("channel_type", channel.Type)
 	ban := true
 	// parse *int to bool
 	if channel.AutoBan != nil && *channel.AutoBan == 0 {
@@ -197,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeXunfei:
 		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
 	case common.ChannelTypeGemini:
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeAli:
 		c.Set("plugin", channel.Other)
+	case common.ChannelCloudflare:
+		c.Set("api_version", channel.Other)
 	}
 }
```
model/channel.go (path inferred)

```diff
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"gorm.io/gorm"
 	"one-api/common"
+	"strings"
 )
 
 type Channel struct {
@@ -33,6 +34,13 @@ type Channel struct {
 	OtherInfo string `json:"other_info"`
 }
 
+func (channel *Channel) GetModels() []string {
+	if channel.Models == "" {
+		return []string{}
+	}
+	return strings.Split(strings.Trim(channel.Models, ","), ",")
+}
+
 func (channel *Channel) GetOtherInfo() map[string]interface{} {
 	otherInfo := make(map[string]interface{})
 	if channel.OtherInfo != "" {
```
model/main.go (path inferred)

```diff
@@ -86,9 +86,9 @@ func InitDB() (err error) {
 	if err != nil {
 		return err
 	}
-	sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
-	sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
-	sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
+	sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
+	sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
+	sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
 
 	if !common.IsMasterNode {
 		return nil
```
model/option.go (path inferred)

```diff
@@ -99,6 +99,7 @@ func InitOptionMap() {
 	common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
 	common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
 	common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
+	common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
 	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
 	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
 	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -210,6 +211,8 @@ func updateOptionMap(key string, value string) (err error) {
 		constant.MjModeClearEnabled = boolValue
 	case "MjForwardUrlEnabled":
 		constant.MjForwardUrlEnabled = boolValue
+	case "MjActionCheckSuccessEnabled":
+		constant.MjActionCheckSuccessEnabled = boolValue
 	case "CheckSensitiveEnabled":
 		constant.CheckSensitiveEnabled = boolValue
 	case "CheckSensitiveOnPromptEnabled":
```
model/pricing.go (path inferred)

```diff
@@ -2,18 +2,28 @@ package model
 
 import (
 	"one-api/common"
-	"one-api/dto"
 	"sync"
 	"time"
)
 
+type Pricing struct {
+	Available       bool     `json:"available"`
+	ModelName       string   `json:"model_name"`
+	QuotaType       int      `json:"quota_type"`
+	ModelRatio      float64  `json:"model_ratio"`
+	ModelPrice      float64  `json:"model_price"`
+	OwnerBy         string   `json:"owner_by"`
+	CompletionRatio float64  `json:"completion_ratio"`
+	EnableGroup     []string `json:"enable_group,omitempty"`
+}
+
 var (
-	pricingMap         []dto.ModelPricing
+	pricingMap         []Pricing
 	lastGetPricingTime time.Time
 	updatePricingLock  sync.Mutex
 )
 
-func GetPricing(group string) []dto.ModelPricing {
+func GetPricing(group string) []Pricing {
 	updatePricingLock.Lock()
 	defer updatePricingLock.Unlock()
@@ -21,7 +31,7 @@ func GetPricing(group string) []Pricing {
 		updatePricing()
 	}
 	if group != "" {
-		userPricingMap := make([]dto.ModelPricing, 0)
+		userPricingMap := make([]Pricing, 0)
 		models := GetGroupModels(group)
 		for _, pricing := range pricingMap {
 			if !common.StringsContains(models, pricing.ModelName) {
@@ -42,9 +52,9 @@ func updatePricing() {
 		allModels[model] = i
 	}
 
-	pricingMap = make([]dto.ModelPricing, 0)
+	pricingMap = make([]Pricing, 0)
 	for model, _ := range allModels {
-		pricing := dto.ModelPricing{
+		pricing := Pricing{
 			Available: true,
 			ModelName: model,
 		}
```
model/redemption.go (path inferred)

```diff
@@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
 	if err != nil {
 		return 0, errors.New("兑换失败," + err.Error())
 	}
-	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
+	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
 	return redemption.Quota, nil
 }
```
model/token.go (path inferred)

```diff
@@ -250,11 +250,9 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
 	if userQuota < quota {
 		return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
 	}
-	if !token.UnlimitedQuota {
-		err = DecreaseTokenQuota(tokenId, quota)
-		if err != nil {
-			return 0, err
-		}
+	err = DecreaseTokenQuota(tokenId, quota)
+	if err != nil {
+		return 0, err
 	}
 	err = DecreaseUserQuota(token.UserId, quota)
 	return userQuota - quota, err
@@ -272,15 +270,13 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
 		return err
 	}
 
-	if !token.UnlimitedQuota {
-		if quota > 0 {
-			err = DecreaseTokenQuota(tokenId, quota)
-		} else {
-			err = IncreaseTokenQuota(tokenId, -quota)
-		}
-		if err != nil {
-			return err
-		}
+	if quota > 0 {
+		err = DecreaseTokenQuota(tokenId, quota)
+	} else {
+		err = IncreaseTokenQuota(tokenId, -quota)
+	}
+	if err != nil {
+		return err
 	}
 
 	if sendEmail {
```
model/user.go (path inferred)

```diff
@@ -298,7 +298,8 @@ func (user *User) ValidateAndFill() (err error) {
 	if user.Username == "" || password == "" {
 		return errors.New("用户名或密码为空")
 	}
-	DB.Where(User{Username: user.Username}).First(user)
+	// find by username or email
+	DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
 	okay := common.ValidatePasswordAndHash(password, user.Password)
 	if !okay || user.Status != common.UserStatusEnabled {
 		return errors.New("用户名或密码错误,或用户已被封禁")
```
relay/channel/adaptor.go (path inferred)

```diff
@@ -11,9 +11,11 @@ import (
 type Adaptor interface {
 	// Init IsStream bool
 	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
+	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
-	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
```
relay/channel/ali/adaptor.go (path inferred)

```diff
@@ -15,6 +15,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }
@@ -39,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
 		return baiduEmbeddingRequest, nil
@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(...)
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
```
relay/channel/aws/adaptor.go (path inferred)

```diff
@@ -20,6 +20,11 @@ type Adaptor struct {
 	RequestMode int
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
@@ -36,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -53,13 +58,17 @@ func (a *Adaptor) ConvertRequest(...)
 	return claudeReq, err
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return nil, nil
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = awsStreamHandler(c, info, a.RequestMode)
+		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 		err, usage = awsHandler(c, info, a.RequestMode)
 	}
```
relay/channel/aws/relay-aws.go (path inferred)

```diff
@@ -13,7 +13,9 @@ import (
 	relaymodel "one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
+	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/credentials"
@@ -111,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	return nil, &usage
 }
 
-func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
+func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	awsCli, err := newAwsClient(c, info)
 	if err != nil {
 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -156,16 +158,20 @@ func awsStreamHandler(...)
 	var usage relaymodel.Usage
 	var id string
 	var model string
+	isFirst := true
 	createdTime := common.GetTimestamp()
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 
 		switch v := event.(type) {
 		case *types.ResponseStreamMemberChunk:
+			if isFirst {
+				isFirst = false
+				info.FirstResponseTime = time.Now()
+			}
 			claudeResp := new(claude.ClaudeResponse)
 			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
 			if err != nil {
@@ -208,6 +214,17 @@ func awsStreamHandler(...)
 			return false
 		}
 	})
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
 	return nil, &usage
 }
```
@@ -2,6 +2,7 @@ package baidu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -15,44 +16,74 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
//TODO implement me
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.UpstreamModelName {
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "ERNIE-Bot-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
|
||||
case "ERNIE-Bot":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Speed":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "ERNIE-4.0-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "ERNIE-3.5-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Speed-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||
case "ERNIE-Character-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
|
||||
case "ERNIE-Functions-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
|
||||
case "ERNIE-Lite-8K-0922":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "Yi-34B-Chat":
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
default:
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
suffix = "embeddings/"
}
switch info.UpstreamModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += strings.ToLower(info.UpstreamModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
@@ -68,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
@@ -82,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
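The rewritten Baidu GetRequestURL above replaces per-model hard-coded URLs with a computed path: a "chat/" or "embeddings/" prefix chosen from the model-name prefix, plus a per-model endpoint suffix. A minimal standalone sketch of the same pattern (map contents abbreviated; the helper name is hypothetical, not part of the repo):

```go
package main

import (
	"fmt"
	"strings"
)

// endpointSuffix mirrors a few entries of the switch above.
var endpointSuffix = map[string]string{
	"ERNIE-4.0-8K": "completions_pro",
	"ERNIE-3.5-8K": "completions",
	"Embedding-V1": "embedding-v1",
}

// buildBaiduURL is a hypothetical standalone version of the logic above.
func buildBaiduURL(baseURL, model string) string {
	kind := "chat/"
	if strings.HasPrefix(model, "Embedding") ||
		strings.HasPrefix(model, "bge-large") ||
		strings.HasPrefix(model, "tao-8k") {
		kind = "embeddings/"
	}
	suffix, ok := endpointSuffix[model]
	if !ok {
		suffix = strings.ToLower(model) // the default case above
	}
	return fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s%s", baseURL, kind, suffix)
}

func main() {
	fmt.Println(buildBaiduURL("https://aip.baidubce.com", "ERNIE-4.0-8K"))
	// https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro
}
```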
@@ -1,20 +1,22 @@
package baidu

var ModelList = []string{
"ERNIE-3.5-8K",
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"ERNIE-Character-8K",
"ERNIE-Functions-8K",
//"ERNIE-Bot-4",
//"ERNIE-Bot-8K",
//"ERNIE-Bot",
//"ERNIE-Speed",
//"ERNIE-Bot-turbo",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}

var ChannelName = "baidu"
@@ -11,9 +11,16 @@ type BaiduMessage struct {
}

type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
Messages []BaiduMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}

type Error struct {
@@ -22,17 +22,33 @@ import (
var baiduTokenStore sync.Map

func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
for _, message := range request.Messages {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
if message.Role == "system" {
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &baiduRequest
}

func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
@@ -21,6 +21,11 @@ type Adaptor struct {
RequestMode int
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
@@ -48,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -59,13 +64,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}
@@ -8,6 +8,7 @@ var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
}

var ChannelName = "claude"
@@ -8,9 +8,12 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)

func stopReasonClaude2OpenAI(reason string) string {
@@ -69,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
if textRequest.Stop != nil {
// stop may be a string or an array of strings; normalize to an array
switch textRequest.Stop.(type) {
case string:
claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
case []interface{}:
stopSequences := make([]string, 0)
for _, stop := range textRequest.Stop.([]interface{}) {
stopSequences = append(stopSequences, stop.(string))
}
claudeRequest.StopSequences = stopSequences
}
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {
@@ -246,7 +262,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse
}

func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
@@ -265,8 +281,8 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
go func() {
for scanner.Scan() {
data := scanner.Text()
@@ -274,14 +290,23 @@
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
@@ -302,7 +327,7 @@
if claudeResponse.Type == "message_start" {
// message_start: read usage from the event
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
@@ -316,34 +341,39 @@
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
response.Model = info.UpstreamModelName

jsonStr, err := json.Marshal(response)
err = service.ObjectData(c, response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
common.SysError(err.Error())
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = promptTokens
usage.PromptTokens = info.PromptTokens
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
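The Claude stream handler above replaces the blocking `dataChan <- data` send with `common.SafeSendStringTimeout`, so a stalled client can no longer wedge the producer goroutine forever. A minimal sketch of what such a helper plausibly looks like (the actual implementation lives in one-api/common and may differ):

```go
package main

import (
	"fmt"
	"time"
)

// safeSendStringTimeout mimics the shape of common.SafeSendStringTimeout:
// it reports false instead of blocking forever when the receiver has stalled.
func safeSendStringTimeout(ch chan string, value string, timeoutSeconds int) bool {
	select {
	case ch <- value:
		return true
	case <-time.After(time.Duration(timeoutSeconds) * time.Second):
		return false // caller logs the timeout and stops the stream
	}
}

func main() {
	ch := make(chan string, 1)
	fmt.Println(safeSendStringTimeout(ch, "data: {}", 1)) // true: buffer has room
	fmt.Println(safeSendStringTimeout(ch, "data: {}", 1)) // false after ~1s: nobody is reading
}
```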
76
relay/channel/cloudflare/adaptor.go
Normal file
@@ -0,0 +1,76 @@
package cloudflare

import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil
default:
return request, nil
}
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cfStreamHandler(c, resp, info)
} else {
err, usage = cfHandler(c, resp, info)
}
return
}

func (a *Adaptor) GetModelList() []string {
return ModelList
}

func (a *Adaptor) GetChannelName() string {
return ChannelName
}
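Judging from the URL shapes above, the Cloudflare adaptor reuses the `info.ApiVersion` slot to carry the Cloudflare account ID, and exposes both the OpenAI-compatible `/ai/v1` surface and the native Workers AI `/ai/run/{model}` route. A standalone sketch of the three URL forms (names and the account-ID reading are assumptions, not confirmed by the repo):

```go
package main

import "fmt"

// cfURL sketches the account-scoped endpoint selection in GetRequestURL above;
// accountID plays the role the adaptor stores in info.ApiVersion.
func cfURL(base, accountID, mode, model string) string {
	switch mode {
	case "chat":
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", base, accountID)
	case "embeddings":
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", base, accountID)
	default: // native Workers AI route, keyed by model name
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", base, accountID, model)
	}
}

func main() {
	fmt.Println(cfURL("https://api.cloudflare.com", "abc123", "chat", ""))
	fmt.Println(cfURL("https://api.cloudflare.com", "abc123", "", "@cf/meta/llama-3-8b-instruct"))
}
```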
38
relay/channel/cloudflare/constant.go
Normal file
@@ -0,0 +1,38 @@
package cloudflare

var ModelList = []string{
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}

var ChannelName = "cloudflare"
13
relay/channel/cloudflare/model.go
Normal file
@@ -0,0 +1,13 @@
package cloudflare

import "one-api/dto"

type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
121
relay/channel/cloudflare/relay_cloudflare.go
Normal file
@@ -0,0 +1,121 @@
package cloudflare

import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)

func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
p, _ := textRequest.Prompt.(string)
return &CfRequest{
Prompt: p,
MaxTokens: textRequest.GetMaxTokens(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}

func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)

service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
var responseText string
isFirst := true

for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")

if data == "[DONE]" {
break
}

var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
choice.Delta.Role = "assistant"
responseText += choice.Delta.GetContentString()
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}

if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)

err := resp.Body.Close()
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}

return nil, usage
}

func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}
@@ -8,16 +8,24 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -26,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}

@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return requestConvertRerank2Cohere(request), nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
}
return
}
@@ -2,6 +2,7 @@ package cohere

var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}

var ChannelName = "cohere"
@@ -1,11 +1,13 @@
package cohere

import "one-api/dto"

type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
MaxTokens int `json:"max_tokens"`
}

type ChatHistory struct {
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
Meta CohereMeta `json:"meta"`
}

type CohereRerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents"`
}

type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`

@@ -9,8 +9,10 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)

func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -45,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
return &cohereReq
}

func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
if rerankRequest.TopN == 0 {
rerankRequest.TopN = 1
}
cohereReq := CohereRerankRequest{
Query: rerankRequest.Query,
Documents: rerankRequest.Documents,
Model: rerankRequest.Model,
TopN: rerankRequest.TopN,
ReturnDocuments: true,
}
return &cohereReq
}

func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
@@ -56,7 +72,7 @@
}
}

func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
@@ -84,9 +100,14 @@
stopChan <- true
}()
service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
@@ -98,7 +119,7 @@
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName
openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@@ -137,7 +158,7 @@
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}
@@ -187,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = 0
usage.TotalTokens = info.PromptTokens
} else {
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
}

var rerankResp dto.RerankResponse
rerankResp.Results = cohereResp.Results
rerankResp.Usage = usage

jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
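`requestConvertRerank2Cohere` above maps an OpenAI-style rerank request onto Cohere's /v1/rerank payload, defaulting TopN to 1 and always asking for the documents back. A hedged usage sketch (the dto.RerankRequest field names are taken from their use in this diff; everything else is illustrative):

```go
package cohere

import (
	"encoding/json"
	"fmt"

	"one-api/dto"
)

// ExampleRerankConversion sketches how the converter above is driven.
func ExampleRerankConversion() {
	req := dto.RerankRequest{
		Model:     "rerank-english-v3.0",
		Query:     "What is the capital of France?",
		Documents: []any{"Paris is the capital of France.", "Berlin is in Germany."},
		// TopN deliberately left at 0: the converter bumps it to 1.
	}
	body, _ := json.Marshal(requestConvertRerank2Cohere(req))
	// This JSON is what gets POSTed to {BaseUrl}/v1/rerank.
	fmt.Println(string(body))
}
```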
65
relay/channel/dify/adaptor.go
Normal file
@@ -0,0 +1,65 @@
package dify

import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Dify(*request), nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = difyStreamHandler(c, resp, info)
} else {
err, usage = difyHandler(c, resp, info)
}
return
}

func (a *Adaptor) GetModelList() []string {
return ModelList
}

func (a *Adaptor) GetChannelName() string {
return ChannelName
}
5
relay/channel/dify/constants.go
Normal file
@@ -0,0 +1,5 @@
package dify

var ModelList []string

var ChannelName = "dify"
35
relay/channel/dify/dto.go
Normal file
@@ -0,0 +1,35 @@
package dify

import "one-api/dto"

type DifyChatRequest struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
ResponseMode string `json:"response_mode"`
User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"`
}

type DifyMetaData struct {
Usage dto.Usage `json:"usage"`
}

type DifyData struct {
WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"`
}

type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}

type DifyChunkChatCompletionResponse struct {
Event string `json:"event"`
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
Data DifyData `json:"data"`
MetaData DifyMetaData `json:"metadata"`
}
156
relay/channel/dify/relay-dify.go
Normal file
@@ -0,0 +1,156 @@
package dify

import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)

func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
content := ""
for _, message := range request.Messages {
if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n"
} else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n"
} else {
content += "USER: \n" + message.StringContent() + "\n"
}
}
mode := "blocking"
if request.Stream {
mode = "streaming"
}
user := request.User
if user == "" {
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
}

func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
return &response
}

func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)

service.SetEventStreamHeaders(c)

for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
break
} else if difyResponse.Event == "error" {
break
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
}
}
err = service.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, usage
}

func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)

if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &difyResponse.MetaData.Usage
}
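Dify's chat-messages endpoint takes a single query string rather than an OpenAI message array, so `requestOpenAI2Dify` above flattens the conversation into one role-prefixed block. A self-contained sketch of just that flattening step (simplified types, not the repo's dto ones):

```go
package main

import "fmt"

type msg struct{ Role, Content string }

// flatten mirrors the role-prefixing loop in requestOpenAI2Dify above.
func flatten(messages []msg) string {
	out := ""
	for _, m := range messages {
		switch m.Role {
		case "system":
			out += "SYSTEM: \n" + m.Content + "\n"
		case "assistant":
			out += "ASSISTANT: \n" + m.Content + "\n"
		default:
			out += "USER: \n" + m.Content + "\n"
		}
	}
	return out
}

func main() {
	fmt.Print(flatten([]msg{
		{"system", "You are terse."},
		{"user", "Hi"},
	}))
	// SYSTEM:
	// You are terse.
	// USER:
	// Hi
}
```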
@@ -9,38 +9,40 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

// A map from model name to the API version it requires
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// Look up the version for this model; fall back to info.ApiVersion or the default "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}
// Look up the version for this model; fall back to info.ApiVersion or the default "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}

action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -49,22 +51,24 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return CovertGemini2OpenAI(*request), nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = geminiChatStreamHandler(c, resp, info)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}

type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
@@ -5,12 +5,15 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"

"github.com/gin-gonic/gin"
)
@@ -160,10 +163,14 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}

func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
responseJson := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -180,20 +187,30 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
go func() {
for scanner.Scan() {
data := scanner.Text()
responseJson += data
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
@@ -205,10 +222,10 @@
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Id: id,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Created: createAt,
Model: info.UpstreamModelName,
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
@@ -219,15 +236,34 @@
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
var geminiChatResponses []GeminiChatResponse
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
log.Printf("cannot get gemini usage: %s", err.Error())
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
for _, response := range geminiChatResponses {
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, responseText
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
}
return nil, usage
}

func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -256,11 +292,10 @@
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
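The Gemini handlers above switch usage accounting from local re-tokenization to the usageMetadata Gemini returns in each chunk, keeping the counts from the last chunk of the stream. A self-contained sketch of that extraction (trimmed-down types mirroring GeminiUsageMetadata; the sample JSON is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// usageMetadata mirrors GeminiUsageMetadata above.
type usageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

type chunk struct {
	UsageMetadata usageMetadata `json:"usageMetadata"`
}

func main() {
	// Stand-in for the accumulated responseJson above: the last streamed
	// chunk carries the authoritative counts.
	raw := `[{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":0,"totalTokenCount":10}},
	         {"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":42,"totalTokenCount":52}}]`
	var chunks []chunk
	if err := json.Unmarshal([]byte(raw), &chunks); err != nil {
		panic(err)
	}
	var prompt, completion int
	for _, c := range chunks { // keep the values from the last chunk, as above
		prompt = c.UsageMetadata.PromptTokenCount
		completion = c.UsageMetadata.CandidatesTokenCount
	}
	fmt.Println(prompt, completion, prompt+completion) // 10 42 52
}
```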
||||
64
relay/channel/jina/adaptor.go
Normal file
@@ -0,0 +1,64 @@
package jina

import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
}
return
}

func (a *Adaptor) GetModelList() []string {
return ModelList
}

func (a *Adaptor) GetChannelName() string {
return ChannelName
}
8
relay/channel/jina/constant.go
Normal file
@@ -0,0 +1,8 @@
package jina

var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
}

var ChannelName = "jina"
35
relay/channel/jina/relay-jina.go
Normal file
@@ -0,0 +1,35 @@
package jina

import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)

func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}

jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}
@@ -10,12 +10,14 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

@@ -33,11 +35,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
@@ -45,15 +47,17 @@
}
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
@@ -14,7 +14,6 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)

@@ -22,6 +21,13 @@ type Adaptor struct {
ChannelType int
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
}
@@ -67,10 +73,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil
}
return request, nil
}

@@ -80,11 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage = OpenaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
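The ConvertRequest change above strips `request.StreamOptions` for every channel that is not the native OpenAI type, presumably because other OpenAI-compatible upstreams reject the field. With an `omitempty` pointer field, nilling it makes the key vanish from the serialized request entirely, as this illustrative sketch shows (the types here are stand-ins, not the repo's dto):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type streamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

type chatRequest struct {
	Model         string         `json:"model"`
	Stream        bool           `json:"stream"`
	StreamOptions *streamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := chatRequest{
		Model:         "gpt-4o",
		Stream:        true,
		StreamOptions: &streamOptions{IncludeUsage: true},
	}
	// Non-OpenAI upstreams may reject stream_options, so the adaptor nils it;
	// omitempty then drops the field from the outgoing JSON.
	req.StreamOptions = nil
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"model":"gpt-4o","stream":true}
}
```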
||||
@@ -8,42 +8,42 @@ import (
	"io"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	relayconstant "one-api/relay/constant"
	"one-api/service"
	"strings"
	"sync"
	"time"
)

func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
	//checkSensitive := constant.ShouldCheckCompletionSensitive()
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	hasStreamUsage := false
	responseId := ""
	var createAt int64 = 0
	var systemFingerprint string
	model := info.UpstreamModelName

	var responseTextBuilder strings.Builder
	var usage = &dto.Usage{}
	var streamItems []string // store stream items

	toolCount := 0
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string, 5)
	stopChan := make(chan bool, 2)
	scanner.Split(bufio.ScanLines)

	service.SetEventStreamHeaders(c)

	ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
	defer ticker.Stop()

	stopChan := make(chan bool)
	defer close(stopChan)
	defer close(dataChan)
	var wg sync.WaitGroup

	go func() {
		wg.Add(1)
		defer wg.Done()
		var streamItems []string // store stream items
		for scanner.Scan() {
			info.SetFirstResponseTime()
			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
			data := scanner.Text()
			if len(data) < 6 { // ignore blank line or wrong format
				continue
@@ -51,39 +51,47 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
			if data[:6] != "data: " && data[:6] != "[DONE]" {
				continue
			}
			common.SafeSendString(dataChan, data)
			data = data[6:]
			if !strings.HasPrefix(data, "[DONE]") {
				err := service.StringData(c, data)
				if err != nil {
					common.LogError(c, "streaming error: "+err.Error())
				}
				streamItems = append(streamItems, data)
			}
		}
		streamResp := "[" + strings.Join(streamItems, ",") + "]"
		switch info.RelayMode {
		case relayconstant.RelayModeChatCompletions:
			var streamResponses []dto.ChatCompletionsStreamResponseSimple
			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
			if err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				for _, item := range streamItems {
					var streamResponse dto.ChatCompletionsStreamResponseSimple
					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
					if err == nil {
						for _, choice := range streamResponse.Choices {
							responseTextBuilder.WriteString(choice.Delta.GetContentString())
							if choice.Delta.ToolCalls != nil {
								if len(choice.Delta.ToolCalls) > toolCount {
									toolCount = len(choice.Delta.ToolCalls)
								}
								for _, tool := range choice.Delta.ToolCalls {
									responseTextBuilder.WriteString(tool.Function.Name)
									responseTextBuilder.WriteString(tool.Function.Arguments)
								}
							}
						}
		common.SafeSendBool(stopChan, true)
	}()

	select {
	case <-ticker.C:
		// timeout handling
		common.LogError(c, "streaming timeout")
	case <-stopChan:
		// finished normally
	}

	// count tokens
	streamResp := "[" + strings.Join(streamItems, ",") + "]"
	switch info.RelayMode {
	case relayconstant.RelayModeChatCompletions:
		var streamResponses []dto.ChatCompletionsStreamResponse
		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
		if err != nil {
			// bulk parse failed, fall back to parsing item by item
			common.SysError("error unmarshalling stream response: " + err.Error())
			for _, item := range streamItems {
				var streamResponse dto.ChatCompletionsStreamResponse
				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
				if err == nil {
					responseId = streamResponse.Id
					createAt = streamResponse.Created
					systemFingerprint = streamResponse.GetSystemFingerprint()
					model = streamResponse.Model
					if service.ValidUsage(streamResponse.Usage) {
						usage = streamResponse.Usage
						hasStreamUsage = true
					}
				}
			} else {
				for _, streamResponse := range streamResponses {
					for _, choice := range streamResponse.Choices {
						responseTextBuilder.WriteString(choice.Delta.GetContentString())
						if choice.Delta.ToolCalls != nil {
@@ -98,60 +106,72 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
				}
			}
		}
	case relayconstant.RelayModeCompletions:
		var streamResponses []dto.CompletionsStreamResponse
		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
		if err != nil {
			common.SysError("error unmarshalling stream response: " + err.Error())
			for _, item := range streamItems {
				var streamResponse dto.CompletionsStreamResponse
				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
				if err == nil {
					for _, choice := range streamResponse.Choices {
						responseTextBuilder.WriteString(choice.Text)
		} else {
			for _, streamResponse := range streamResponses {
				responseId = streamResponse.Id
				createAt = streamResponse.Created
				systemFingerprint = streamResponse.GetSystemFingerprint()
				model = streamResponse.Model
				if service.ValidUsage(streamResponse.Usage) {
					usage = streamResponse.Usage
					hasStreamUsage = true
				}
				for _, choice := range streamResponse.Choices {
					responseTextBuilder.WriteString(choice.Delta.GetContentString())
					if choice.Delta.ToolCalls != nil {
						if len(choice.Delta.ToolCalls) > toolCount {
							toolCount = len(choice.Delta.ToolCalls)
						}
						for _, tool := range choice.Delta.ToolCalls {
							responseTextBuilder.WriteString(tool.Function.Name)
							responseTextBuilder.WriteString(tool.Function.Arguments)
						}
					}
				}
			} else {
				for _, streamResponse := range streamResponses {
			}
		}
	case relayconstant.RelayModeCompletions:
		var streamResponses []dto.CompletionsStreamResponse
		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
		if err != nil {
			// bulk parse failed, fall back to parsing item by item
			common.SysError("error unmarshalling stream response: " + err.Error())
			for _, item := range streamItems {
				var streamResponse dto.CompletionsStreamResponse
				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
				if err == nil {
					for _, choice := range streamResponse.Choices {
						responseTextBuilder.WriteString(choice.Text)
					}
				}
			}
		}
		if len(dataChan) > 0 {
			// wait data out
			time.Sleep(2 * time.Second)
		}
		common.SafeSendBool(stopChan, true)
	}()
	service.SetEventStreamHeaders(c)
	isFirst := true
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			if isFirst {
				isFirst = false
				info.FirstResponseTime = time.Now()
		} else {
			for _, streamResponse := range streamResponses {
				for _, choice := range streamResponse.Choices {
					responseTextBuilder.WriteString(choice.Text)
				}
			}
			if strings.HasPrefix(data, "data: [DONE]") {
				data = data[:12]
			}
			// some implementations may add \r at the end of data
			data = strings.TrimSuffix(data, "\r")
			c.Render(-1, common.CustomEvent{Data: data})
			return true
		case <-stopChan:
			return false
		}
	})
}

	if !hasStreamUsage {
		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
		usage.CompletionTokens += toolCount * 7
	}

	if info.ShouldIncludeUsage && !hasStreamUsage {
		response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
		response.SetSystemFingerprint(systemFingerprint)
		service.ObjectData(c, response)
	}

	service.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
		common.LogError(c, "close_response_body_failed: "+err.Error())
	}
	wg.Wait()
	return nil, responseTextBuilder.String(), toolCount
	return nil, usage
}

func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

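The rewritten handler above collects SSE `data:` lines, tries one bulk `json.Unmarshal` over the joined array, and only falls back to per-item parsing on failure. A minimal, self-contained sketch of that parsing strategy (not the project's code; the `chunk` type is a simplified stand-in for `dto.ChatCompletionsStreamResponse`):

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// chunk is a hypothetical stand-in for the real stream-response DTO.
type chunk struct {
	Id string `json:"id"`
}

func main() {
	body := "data: {\"id\":\"a\"}\ndata: {\"id\":\"b\"}\ndata: [DONE]\n"
	scanner := bufio.NewScanner(strings.NewReader(body))
	scanner.Split(bufio.ScanLines)

	var items []string
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank lines and malformed frames
		}
		payload := strings.TrimPrefix(line, "data: ")
		if strings.HasPrefix(payload, "[DONE]") {
			break
		}
		items = append(items, payload)
	}

	// One bulk parse over "[item,item,...]"; per-item fallback on error.
	var chunks []chunk
	if err := json.Unmarshal([]byte("["+strings.Join(items, ",")+"]"), &chunks); err != nil {
		for _, item := range items {
			var ck chunk
			if json.Unmarshal([]byte(item), &ck) == nil {
				chunks = append(chunks, ck)
			}
		}
	}
	fmt.Println(len(chunks)) // 2
}
```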
@@ -15,6 +15,11 @@ import (
type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
	//TODO implement me
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

@@ -28,13 +33,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}

@@ -10,12 +10,16 @@ import (
	"one-api/relay/channel"
	"one-api/relay/channel/openai"
	relaycommon "one-api/relay/common"
	"one-api/service"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
	//TODO implement me
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

@@ -29,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
@@ -39,15 +43,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
	return requestOpenAI2Perplexity(*request), nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	if info.IsStream {
		var responseText string
		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
		err, usage = openai.OpenaiStreamHandler(c, resp, info)
	} else {
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
	}

@@ -6,49 +6,68 @@ import (
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/common"
	"one-api/dto"
	"one-api/relay/channel"
	relaycommon "one-api/relay/common"
	"one-api/service"
	"strconv"
	"strings"
)

type Adaptor struct {
	Sign string
	Sign      string
	AppID     int64
	Action    string
	Version   string
	Timestamp int64
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
	//TODO implement me
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
	a.Action = "ChatCompletions"
	a.Version = "2023-09-01"
	a.Timestamp = common.GetTimestamp()
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
	return fmt.Sprintf("%s/", info.BaseUrl), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)
	req.Header.Set("Authorization", a.Sign)
	req.Header.Set("X-TC-Action", info.UpstreamModelName)
	req.Header.Set("X-TC-Action", a.Action)
	req.Header.Set("X-TC-Version", a.Version)
	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	apiKey := c.Request.Header.Get("Authorization")
	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
	appId, secretId, secretKey, err := parseTencentConfig(apiKey)
	a.AppID = appId
	if err != nil {
		return nil, err
	}
	tencentRequest := requestOpenAI2Tencent(*request)
	tencentRequest.AppId = appId
	tencentRequest.SecretId = secretId
	tencentRequest := requestOpenAI2Tencent(a, *request)
	// we have to calculate the sign here
	a.Sign = getTencentSign(*tencentRequest, secretKey)
	a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
	return tencentRequest, nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}

@@ -1,9 +1,10 @@
package tencent

var ModelList = []string{
	"ChatPro",
	"ChatStd",
	"hunyuan",
	"hunyuan-lite",
	"hunyuan-standard",
	"hunyuan-standard-256K",
	"hunyuan-pro",
}

var ChannelName = "tencent"

@@ -1,62 +1,75 @@
package tencent

import "one-api/dto"

type TencentMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Role    string `json:"Role"`
	Content string `json:"Content"`
}

type TencentChatRequest struct {
	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
	SecretId string `json:"secret_id"` // SecretId from the console
	// Timestamp is the current UNIX timestamp in seconds, recording when the API request was made,
	// e.g. 1529223702; if it differs too much from the current time, a signature-expired error occurs.
	Timestamp int64 `json:"timestamp"`
	// Expired is the signature expiry, a UNIX-epoch value in seconds;
	// Expired must be greater than Timestamp and Expired-Timestamp must be less than 90 days.
	Expired int64  `json:"expired"`
	QueryID string `json:"query_id"` // request id, used for troubleshooting
	// Temperature: higher values make output more random, lower values more focused and deterministic.
	// Default 1.0, range [0.0, 2.0]; not recommended unless necessary, unreasonable values hurt results.
	// Prefer setting only one of this and top_p, not both.
	Temperature float64 `json:"temperature"`
	// TopP affects output diversity; larger values produce more diverse text.
	// Default 1.0, range [0.0, 1.0]; not recommended unless necessary, unreasonable values hurt results.
	// Prefer setting only one of this and temperature, not both.
	TopP float64 `json:"top_p"`
	// Stream: 0 synchronous, 1 streaming (default, protocol: SSE).
	// Synchronous requests time out after 60s; use streaming for long content.
	Stream int `json:"stream"`
	// Messages: conversation content, at most 40 entries, ordered oldest to newest.
	// Total input content supports at most 3000 tokens.
	Messages []TencentMessage `json:"messages"`
	Model    string           `json:"model"` // model name
	// Model name; valid values include hunyuan-lite, hunyuan-standard, hunyuan-standard-256K and hunyuan-pro.
	// See the product overview (https://cloud.tencent.com/document/product/1729/104753) for model descriptions.
	//
	// Note:
	// Models are billed differently; see the purchase guide (https://cloud.tencent.com/document/product/1729/97731).
	Model *string `json:"Model"`
	// Chat context.
	// Notes:
	// 1. At most 40 entries, ordered oldest to newest.
	// 2. Message.Role can be system, user or assistant.
	//    The optional system role, if present, must come first. user and assistant must alternate (question/answer), starting and ending with user, and Content must not be empty. Role order example: [system(optional) user assistant user assistant user ...].
	// 3. The total length of Content across Messages must not exceed the model's input limit (see the product overview); if exceeded, the oldest content is truncated and only the tail is kept.
	Messages []*TencentMessage `json:"Messages"`
	// Streaming switch.
	// Notes:
	// 1. Defaults to non-streaming (false) when unset.
	// 2. When streaming, results are returned incrementally over SSE (read Choices[n].Delta and concatenate the increments to get the full result).
	// 3. When not streaming:
	//    The call behaves like an ordinary HTTP request.
	//    The response takes longer; **set this to true if you need lower latency**.
	//    Only the final result is returned, once (read Choices[n].Message).
	//
	// Note:
	// When calling through the SDK, streaming and non-streaming results are read in **different ways**; see the comments or examples in the SDK (under examples/hunyuan/v20230901/ in each language's SDK repo).
	Stream *bool `json:"Stream,omitempty"`
	// Notes:
	// 1. Affects output diversity; larger values produce more diverse text.
	// 2. Range [0.0, 1.0]; the model's recommended value is used when unset.
	// 3. Not recommended unless necessary; unreasonable values hurt results.
	TopP *float64 `json:"TopP,omitempty"`
	// Notes:
	// 1. Higher values make output more random, lower values more focused and deterministic.
	// 2. Range [0.0, 2.0]; the model's recommended value is used when unset.
	// 3. Not recommended unless necessary; unreasonable values hurt results.
	Temperature *float64 `json:"Temperature,omitempty"`
}

type TencentError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Code    int    `json:"Code"`
	Message string `json:"Message"`
}

type TencentUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
	PromptTokens     int `json:"PromptTokens"`
	CompletionTokens int `json:"CompletionTokens"`
	TotalTokens      int `json:"TotalTokens"`
}

type TencentResponseChoices struct {
	FinishReason string         `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     TencentMessage `json:"messages,omitempty"`      // content returned in synchronous mode, null when streaming; at most 1024 tokens of output content
	Delta        TencentMessage `json:"delta,omitempty"`         // content returned when streaming, null in synchronous mode; at most 1024 tokens of output content
	FinishReason string         `json:"FinishReason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     TencentMessage `json:"Message,omitempty"`      // content returned in synchronous mode, null when streaming; at most 1024 tokens of output content
	Delta        TencentMessage `json:"Delta,omitempty"`        // content returned when streaming, null in synchronous mode; at most 1024 tokens of output content
}

type TencentChatResponse struct {
	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
	Created string                   `json:"created,omitempty"` // unix timestamp string
	Id      string                   `json:"id,omitempty"`      // conversation id
	Usage   dto.Usage                `json:"usage,omitempty"`   // token counts
	Error   TencentError             `json:"error,omitempty"`   // error info; note: this field may be null when unavailable
	Note    string                   `json:"note,omitempty"`    // note
	ReqID   string                   `json:"req_id,omitempty"`  // unique request id, returned with every request; quote it when reporting issues
	Choices []TencentResponseChoices `json:"Choices,omitempty"` // results
	Created int64                    `json:"Created,omitempty"` // unix timestamp
	Id      string                   `json:"Id,omitempty"`      // conversation id
	Usage   TencentUsage             `json:"Usage,omitempty"`   // token counts
	Error   TencentError             `json:"Error,omitempty"`   // error info; note: this field may be null when unavailable
	Note    string                   `json:"Note,omitempty"`    // note
	ReqID   string                   `json:"Req_id,omitempty"`  // unique request id, returned with every request; quote it when reporting issues
}

type TencentChatResponseSB struct {
	Response TencentChatResponse `json:"Response,omitempty"`
}

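The DTO change above moves from the v1 snake_case fields to the PascalCase fields of the 2023-09-01 hunyuan API, with pointers plus `omitempty` so unset options are omitted. A minimal sketch of the resulting wire format (trimmed struct copies with illustrative values, not the project's types):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the PascalCase-tagged fields shown above.
type tencentMessage struct {
	Role    string `json:"Role"`
	Content string `json:"Content"`
}

type tencentChatRequest struct {
	Model    *string           `json:"Model"`
	Messages []*tencentMessage `json:"Messages"`
	Stream   *bool             `json:"Stream,omitempty"`
}

func main() {
	model := "hunyuan-lite"
	stream := true
	req := tencentChatRequest{
		Model:    &model,
		Stream:   &stream,
		Messages: []*tencentMessage{{Role: "user", Content: "hello"}},
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// {"Model":"hunyuan-lite","Messages":[{"Role":"user","Content":"hello"}],"Stream":true}
}
```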
@@ -3,8 +3,8 @@ package tencent
import (
	"bufio"
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
@@ -15,54 +15,46 @@ import (
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/service"
	"sort"
	"strconv"
	"strings"
	"time"
)

// https://cloud.tencent.com/document/product/1729/97732

func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
	messages := make([]TencentMessage, 0, len(request.Messages))
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
	messages := make([]*TencentMessage, 0, len(request.Messages))
	for i := 0; i < len(request.Messages); i++ {
		message := request.Messages[i]
		if message.Role == "system" {
			messages = append(messages, TencentMessage{
				Role:    "user",
				Content: message.StringContent(),
			})
			messages = append(messages, TencentMessage{
				Role:    "assistant",
				Content: "Okay",
			})
			continue
		}
		messages = append(messages, TencentMessage{
		messages = append(messages, &TencentMessage{
			Content: message.StringContent(),
			Role:    message.Role,
		})
	}
	stream := 0
	if request.Stream {
		stream = 1
	var req = TencentChatRequest{
		Stream:   &request.Stream,
		Messages: messages,
		Model:    &request.Model,
	}
	return &TencentChatRequest{
		Timestamp:   common.GetTimestamp(),
		Expired:     common.GetTimestamp() + 24*60*60,
		QueryID:     common.GetUUID(),
		Temperature: request.Temperature,
		TopP:        request.TopP,
		Stream:      stream,
		Messages:    messages,
		Model:       request.Model,
	if request.TopP != 0 {
		req.TopP = &request.TopP
	}
	if request.Temperature != 0 {
		req.Temperature = &request.Temperature
	}
	return &req
}

func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
	fullTextResponse := dto.OpenAITextResponse{
		Id:      response.Id,
		Object:  "chat.completion",
		Created: common.GetTimestamp(),
		Usage:   response.Usage,
		Usage: dto.Usage{
			PromptTokens:     response.Usage.PromptTokens,
			CompletionTokens: response.Usage.CompletionTokens,
			TotalTokens:      response.Usage.TotalTokens,
		},
	}
	if len(response.Choices) > 0 {
		content, _ := json.Marshal(response.Choices[0].Messages.Content)
@@ -99,69 +91,51 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
	var responseText string
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if len(data) < 5 { // ignore blank line or wrong format
				continue
			}
			if data[:5] != "data:" {
				continue
			}
			data = data[5:]
			dataChan <- data
		}
		stopChan <- true
	}()
	scanner.Split(bufio.ScanLines)

	service.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			var TencentResponse TencentChatResponse
			err := json.Unmarshal([]byte(data), &TencentResponse)
			if err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			response := streamResponseTencent2OpenAI(&TencentResponse)
			if len(response.Choices) != 0 {
				responseText += response.Choices[0].Delta.GetContentString()
			}
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
			continue
		}
	})
		data = strings.TrimPrefix(data, "data:")

		var tencentResponse TencentChatResponse
		err := json.Unmarshal([]byte(data), &tencentResponse)
		if err != nil {
			common.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		response := streamResponseTencent2OpenAI(&tencentResponse)
		if len(response.Choices) != 0 {
			responseText += response.Choices[0].Delta.GetContentString()
		}

		err = service.ObjectData(c, response)
		if err != nil {
			common.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		common.SysError("error reading stream: " + err.Error())
	}

	service.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}

	return nil, responseText
}

func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	var TencentResponse TencentChatResponse
	var tencentSb TencentChatResponseSB
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -170,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &TencentResponse)
	err = json.Unmarshal(responseBody, &tencentSb)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if TencentResponse.Error.Code != 0 {
	if tencentSb.Response.Error.Code != 0 {
		return &dto.OpenAIErrorWithStatusCode{
			Error: dto.OpenAIError{
				Message: TencentResponse.Error.Message,
				Code:    TencentResponse.Error.Code,
				Message: tencentSb.Response.Error.Message,
				Code:    tencentSb.Response.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseTencent2OpenAI(&TencentResponse)
	fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -206,29 +180,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
	return
}

func getTencentSign(req TencentChatRequest, secretKey string) string {
	params := make([]string, 0)
	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
	params = append(params, "secret_id="+req.SecretId)
	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
	params = append(params, "query_id="+req.QueryID)
	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
	params = append(params, "stream="+strconv.Itoa(req.Stream))
	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))

	var messageStr string
	for _, msg := range req.Messages {
		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
	}
	messageStr = strings.TrimSuffix(messageStr, ",")
	params = append(params, "messages=["+messageStr+"]")

	sort.Sort(sort.StringSlice(params))
	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
	mac := hmac.New(sha1.New, []byte(secretKey))
	signURL := url
	mac.Write([]byte(signURL))
	sign := mac.Sum([]byte(nil))
	return base64.StdEncoding.EncodeToString(sign)
func sha256hex(s string) string {
	b := sha256.Sum256([]byte(s))
	return hex.EncodeToString(b[:])
}

func hmacSha256(s, key string) string {
	hashed := hmac.New(sha256.New, []byte(key))
	hashed.Write([]byte(s))
	return string(hashed.Sum(nil))
}

func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
	// build canonical request string
	host := "hunyuan.tencentcloudapi.com"
	httpRequestMethod := "POST"
	canonicalURI := "/"
	canonicalQueryString := ""
	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
		"application/json", host, strings.ToLower(adaptor.Action))
	signedHeaders := "content-type;host;x-tc-action"
	payload, _ := json.Marshal(req)
	hashedRequestPayload := sha256hex(string(payload))
	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
		httpRequestMethod,
		canonicalURI,
		canonicalQueryString,
		canonicalHeaders,
		signedHeaders,
		hashedRequestPayload)
	// build string to sign
	algorithm := "TC3-HMAC-SHA256"
	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
	timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
	t := time.Unix(timestamp, 0).UTC()
	// must be the format 2006-01-02, ref to package time for more info
	date := t.Format("2006-01-02")
	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
	hashedCanonicalRequest := sha256hex(canonicalRequest)
	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
		algorithm,
		requestTimestamp,
		credentialScope,
		hashedCanonicalRequest)

	// sign string
	secretDate := hmacSha256(date, "TC3"+secKey)
	secretService := hmacSha256("hunyuan", secretDate)
	secretKey := hmacSha256("tc3_request", secretService)
	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))

	// build authorization
	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		algorithm,
		secId,
		credentialScope,
		signedHeaders,
		signature)
	return authorization
}

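The change above replaces the legacy HMAC-SHA1 query-string signature with Tencent Cloud's TC3-HMAC-SHA256 scheme, whose core is a chained key derivation. A self-contained sketch of that derivation chain under the same argument order as `hmacSha256(s, key)` above (placeholder secret and timestamp, and a placeholder canonical-request hash):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"
)

func hmacSHA256(key []byte, msg string) []byte {
	h := hmac.New(sha256.New, key)
	h.Write([]byte(msg))
	return h.Sum(nil)
}

func main() {
	secretKey := "example-secret" // placeholder, not a real credential
	ts := int64(1715000000)       // placeholder X-TC-Timestamp
	date := time.Unix(ts, 0).UTC().Format("2006-01-02")

	// TC3 key derivation: "TC3"+secret -> date -> service -> "tc3_request"
	secretDate := hmacSHA256([]byte("TC3"+secretKey), date)
	secretService := hmacSHA256(secretDate, "hunyuan")
	secretSigning := hmacSHA256(secretService, "tc3_request")

	stringToSign := fmt.Sprintf("TC3-HMAC-SHA256\n%d\n%s/hunyuan/tc3_request\n%s",
		ts, date, "<sha256-of-canonical-request>")
	signature := hex.EncodeToString(hmacSHA256(secretSigning, stringToSign))
	fmt.Println(signature)
}
```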
@@ -16,6 +16,11 @@ type Adaptor struct {
	request *dto.GeneralOpenAIRequest
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
	//TODO implement me
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

@@ -28,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
	return request, nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	// xunfei's request is not http request, so we don't need to do anything here
	dummyResp := &http.Response{}

@@ -6,6 +6,7 @@ var ModelList = []string{
	"SparkDesk-v2.1",
	"SparkDesk-v3.1",
	"SparkDesk-v3.5",
	"SparkDesk-v4.0",
}

var ChannelName = "xunfei"

@@ -252,6 +252,8 @@ func apiVersion2domain(apiVersion string) string {
		return "generalv3"
	case "v3.5":
		return "generalv3.5"
	case "v4.0":
		return "4.0Ultra"
	}
	return "general" + apiVersion
}

@@ -14,6 +14,11 @@ import (
type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
	//TODO implement me
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

@@ -32,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
	return requestOpenAI2Zhipu(*request), nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}

@@ -10,12 +10,16 @@ import (
	"one-api/relay/channel"
	"one-api/relay/channel/openai"
	relaycommon "one-api/relay/common"
	"one-api/service"
)

type Adaptor struct {
}

func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
	//TODO implement me
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

@@ -30,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
@@ -40,17 +44,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
	return requestOpenAI2Zhipu(*request), nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	if info.IsStream {
		var responseText string
		var toolCount int
		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
		usage.CompletionTokens += toolCount * 7
		err, usage = openai.OpenaiStreamHandler(c, resp, info)
	} else {
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
	}

@@ -1,7 +1,7 @@
package zhipu_4v

var ModelList = []string{
	"glm-4", "glm-4v", "glm-3-turbo",
	"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
}

var ChannelName = "zhipu_4v"

@@ -9,24 +9,27 @@ import (
)

type RelayInfo struct {
	ChannelType       int
	ChannelId         int
	TokenId           int
	UserId            int
	Group             string
	TokenUnlimited    bool
	StartTime         time.Time
	FirstResponseTime time.Time
	ApiType           int
	IsStream          bool
	RelayMode         int
	UpstreamModelName string
	RequestURLPath    string
	ApiVersion        string
	PromptTokens      int
	ApiKey            string
	Organization      string
	BaseUrl           string
	ChannelType          int
	ChannelId            int
	TokenId              int
	UserId               int
	Group                string
	TokenUnlimited       bool
	StartTime            time.Time
	FirstResponseTime    time.Time
	setFirstResponse     bool
	ApiType              int
	IsStream             bool
	RelayMode            int
	UpstreamModelName    string
	RequestURLPath       string
	ApiVersion           string
	PromptTokens         int
	ApiKey               string
	Organization         string
	BaseUrl              string
	SupportStreamOptions bool
	ShouldIncludeUsage   bool
}

func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -38,24 +41,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
	group := c.GetString("group")
	tokenUnlimited := c.GetBool("token_unlimited_quota")
	startTime := time.Now()
	// firstResponseTime = time.Now() - 1 second

	apiType, _ := constant.ChannelType2APIType(channelType)

	info := &RelayInfo{
		RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
		BaseUrl:        c.GetString("base_url"),
		RequestURLPath: c.Request.URL.String(),
		ChannelType:    channelType,
		ChannelId:      channelId,
		TokenId:        tokenId,
		UserId:         userId,
		Group:          group,
		TokenUnlimited: tokenUnlimited,
		StartTime:      startTime,
		ApiType:        apiType,
		ApiVersion:     c.GetString("api_version"),
		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
		Organization:   c.GetString("channel_organization"),
		RelayMode:         constant.Path2RelayMode(c.Request.URL.Path),
		BaseUrl:           c.GetString("base_url"),
		RequestURLPath:    c.Request.URL.String(),
		ChannelType:       channelType,
		ChannelId:         channelId,
		TokenId:           tokenId,
		UserId:            userId,
		Group:             group,
		TokenUnlimited:    tokenUnlimited,
		StartTime:         startTime,
		FirstResponseTime: startTime.Add(-time.Second),
		ApiType:           apiType,
		ApiVersion:        c.GetString("api_version"),
		ApiKey:            strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
		Organization:      c.GetString("channel_organization"),
	}
	if info.BaseUrl == "" {
		info.BaseUrl = common.ChannelBaseURLs[channelType]
@@ -63,6 +68,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
	if info.ChannelType == common.ChannelTypeAzure {
		info.ApiVersion = GetAPIVersion(c)
	}
	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
		info.ChannelType == common.ChannelCloudflare {
		info.SupportStreamOptions = true
	}
	return info
}

@@ -74,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
	info.IsStream = isStream
}

func (info *RelayInfo) SetFirstResponseTime() {
	if !info.setFirstResponse {
		info.FirstResponseTime = time.Now()
		info.setFirstResponse = true
	}
}

type TaskRelayInfo struct {
	ChannelType int
	ChannelId   int

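`SetFirstResponseTime` above is a write-once guard: only the first streamed chunk pins the timestamp, so first-token latency can be computed against `StartTime`. A tiny standalone sketch of the same pattern (simplified struct, not the project's type):

```go
package main

import (
	"fmt"
	"time"
)

// relayInfo mirrors the write-once guard from SetFirstResponseTime above.
type relayInfo struct {
	StartTime         time.Time
	FirstResponseTime time.Time
	setFirstResponse  bool
}

func (info *relayInfo) SetFirstResponseTime() {
	if !info.setFirstResponse { // only the first chunk wins
		info.FirstResponseTime = time.Now()
		info.setFirstResponse = true
	}
}

func main() {
	info := &relayInfo{StartTime: time.Now()}
	time.Sleep(10 * time.Millisecond)
	info.SetFirstResponseTime()
	first := info.FirstResponseTime
	info.SetFirstResponseTime() // later chunks do not move the timestamp
	fmt.Println(info.FirstResponseTime.Equal(first)) // true
}
```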
@@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
	var textResponse dto.TextResponseWithError
	err = json.Unmarshal(responseBody, &textResponse)
	if err != nil {
		OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
		return
	}
	OpenAIErrorWithStatusCode.Error = textResponse.Error

@@ -20,6 +20,9 @@ const (
	APITypePerplexity
	APITypeAws
	APITypeCohere
	APITypeDify
	APITypeJina
	APITypeCloudflare

	APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -57,6 +60,12 @@ func ChannelType2APIType(channelType int) (int, bool) {
		apiType = APITypeAws
	case common.ChannelTypeCohere:
		apiType = APITypeCohere
	case common.ChannelTypeDify:
		apiType = APITypeDify
	case common.ChannelTypeJina:
		apiType = APITypeJina
	case common.ChannelCloudflare:
		apiType = APITypeCloudflare
	}
	if apiType == -1 {
		return APITypeOpenAI, false

@@ -32,6 +32,7 @@ const (
	RelayModeSunoFetch
	RelayModeSunoFetchByID
	RelayModeSunoSubmit
	RelayModeRerank
)

func Path2RelayMode(path string) int {
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
		relayMode = RelayModeAudioTranscription
	} else if strings.HasPrefix(path, "/v1/audio/translations") {
		relayMode = RelayModeAudioTranslation
	} else if strings.HasPrefix(path, "/v1/rerank") {
		relayMode = RelayModeRerank
	}
	return relayMode
}

@@ -73,14 +73,14 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
	userQuota, err := model.CacheGetUserQuota(userId)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota-preConsumedQuota < 0 {
		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
		return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*preConsumedQuota {
		// in this case, we do not pre-consume quota
@@ -90,7 +90,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
	if preConsumedQuota > 0 {
		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
		if err != nil {
			return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}

@@ -147,7 +147,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
	quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N

	if userQuota-quota < 0 {
		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}

	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

@@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
	originTask := model.GetByMJId(userId, mjId)
	if originTask == nil {
		return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
	} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
		return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
	} else { // when the original task has Status=SUCCESS, follow-up actions such as UPSCALE and VARIATION are allowed; the original request address must be used for them to be handled correctly
		if constant.MjActionCheckSuccessEnabled {
			if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
				return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
			}
		}
		channel, err := model.GetChannelById(originTask.ChannelId, true)
		if err != nil {
			return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
@@ -500,7 +503,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
	}
	if quota != 0 {
		tokenName := c.GetString("token_name")
		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
		other := make(map[string]interface{})
		other["model_price"] = modelPrice
		other["group_ratio"] = groupRatio
@@ -544,7 +547,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
	if err != nil {
		common.SysError("get_channel_null: " + err.Error())
	}
	if channel.AutoBan != nil && *channel.AutoBan == 1 {
	if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
		model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
	}
}

@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {

	// map model name
	modelMapping := c.GetString("model_mapping")
	isModelMapped := false
	//isModelMapped := false
	if modelMapping != "" && modelMapping != "{}" {
		modelMap := make(map[string]string)
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
		if modelMap[textRequest.Model] != "" {
			textRequest.Model = modelMap[textRequest.Model]
			// set upstream model name
			isModelMapped = true
			//isModelMapped = true
		}
	}
	relayInfo.UpstreamModelName = textRequest.Model
@@ -130,33 +130,38 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
		return openaiErr
	}

	// if StreamOptions is not supported, set StreamOptions to nil
	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
		textRequest.StreamOptions = nil
	} else {
		// if StreamOptions is supported and the request does not set it, apply it from the config
		if constant.ForceStreamOption {
			textRequest.StreamOptions = &dto.StreamOptions{
				IncludeUsage: true,
			}
		}
	}

	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
		relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
	}

	adaptor := GetAdaptor(relayInfo.ApiType)
	if adaptor == nil {
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
	}
	adaptor.Init(relayInfo, *textRequest)
	var requestBody io.Reader
	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
		if isModelMapped {
			jsonStr, err := json.Marshal(textRequest)
			if err != nil {
				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
			}
			requestBody = bytes.NewBuffer(jsonStr)
		} else {
			requestBody = c.Request.Body
		}
	} else {
		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
		if err != nil {
			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
		}
		jsonData, err := json.Marshal(convertedRequest)
		if err != nil {
			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
		}
		requestBody = bytes.NewBuffer(jsonData)

	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
	}
	jsonData, err := json.Marshal(convertedRequest)
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
	}
	requestBody = bytes.NewBuffer(jsonData)

	statusCodeMappingStr := c.GetString("status_code_mapping")
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
@@ -182,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}
	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
	return nil
}

@@ -272,7 +277,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
	}
}

func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
	modelPrice float64, usePrice bool) {

@@ -281,7 +286,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
	completionTokens := usage.CompletionTokens

	tokenName := ctx.GetString("token_name")
	completionRatio := common.GetCompletionRatio(textRequest.Model)
	completionRatio := common.GetCompletionRatio(modelName)

	quota := 0
	if !usePrice {
@@ -307,7 +312,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
		logContent += fmt.Sprintf("(可能是上游超时)")
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
	} else {
		//if sensitiveResp != nil {
		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
@@ -327,13 +333,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
	}

	logModel := textRequest.Model
	logModel := modelName
	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
		logModel = "gpt-4-gizmo-*"
		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
		logContent += fmt.Sprintf(",模型 %s", modelName)
	}
	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)

	//if quota != 0 {
	//

@@ -7,8 +7,11 @@ import (
	"one-api/relay/channel/aws"
	"one-api/relay/channel/baidu"
	"one-api/relay/channel/claude"
	"one-api/relay/channel/cloudflare"
	"one-api/relay/channel/cohere"
	"one-api/relay/channel/dify"
	"one-api/relay/channel/gemini"
	"one-api/relay/channel/jina"
	"one-api/relay/channel/ollama"
	"one-api/relay/channel/openai"
	"one-api/relay/channel/palm"
@@ -53,6 +56,12 @@ func GetAdaptor(apiType int) channel.Adaptor {
		return &aws.Adaptor{}
	case constant.APITypeCohere:
		return &cohere.Adaptor{}
	case constant.APITypeDify:
		return &dify.Adaptor{}
	case constant.APITypeJina:
		return &jina.Adaptor{}
	case constant.APITypeCloudflare:
		return &cloudflare.Adaptor{}
	}
	return nil
}

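With the `SupportStreamOptions`/`ShouldIncludeUsage` wiring above, a streaming request can ask the relay to append a final usage chunk before `[DONE]`. An illustrative client payload (not project code; per the logic above, `ForceStreamOption` can also inject `stream_options` server-side):

```go
package main

import "fmt"

func main() {
	// Example body for POST /v1/chat/completions through the relay.
	body := `{
  "model": "gpt-3.5-turbo",
  "stream": true,
  "stream_options": {"include_usage": true},
  "messages": [{"role": "user", "content": "hi"}]
}`
	fmt.Println(body)
}
```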
104 relay/relay_rerank.go Normal file
@@ -0,0 +1,104 @@
package relay

import (
	"bytes"
	"encoding/json"
	"fmt"
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/service"
)

func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
	for _, document := range rerankRequest.Documents {
		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
		if err == nil {
			token += tkm
		}
	}
	return token
}

func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
	relayInfo := relaycommon.GenRelayInfo(c)

	var rerankRequest *dto.RerankRequest
	err := common.UnmarshalBodyReusable(c, &rerankRequest)
	if err != nil {
		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
	}
	if rerankRequest.Query == "" {
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
	}
	if len(rerankRequest.Documents) == 0 {
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
	}
	relayInfo.UpstreamModelName = rerankRequest.Model
	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
	groupRatio := common.GetGroupRatio(relayInfo.Group)

	var preConsumedQuota int
	var ratio float64
	var modelRatio float64

	promptToken := getRerankPromptToken(*rerankRequest)
	if !success {
		preConsumedTokens := promptToken
		modelRatio = common.GetModelRatio(rerankRequest.Model)
		ratio = modelRatio * groupRatio
		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
	} else {
		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
	}
	relayInfo.PromptTokens = promptToken

	// pre-consume quota
	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
	if openaiErr != nil {
		return openaiErr
	}
	adaptor := GetAdaptor(relayInfo.ApiType)
	if adaptor == nil {
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
	}
	adaptor.InitRerank(relayInfo, *rerankRequest)

	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
	}
	jsonData, err := json.Marshal(convertedRequest)
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
	}
	requestBody := bytes.NewBuffer(jsonData)
	statusCodeMappingStr := c.GetString("status_code_mapping")
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}
	if resp != nil {
		if resp.StatusCode != http.StatusOK {
			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
			openaiErr := service.RelayErrorHandler(resp)
			// reset status code
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
			return openaiErr
		}
	}

	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
	if openaiErr != nil {
		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
		// reset status code
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}
	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
	return nil
}
@@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 		relayV1Router.POST("/moderations", controller.Relay)
+		relayV1Router.POST("/rerank", controller.Relay)
 	}
 
 	relayMjRouter := router.Group("/mj")
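With relay_rerank.go and this route in place, a rerank call can be issued like any other relay request. A minimal client sketch, assuming the v1 router is mounted at `/v1` and that `dto.RerankRequest` uses the conventional `model`/`query`/`documents` JSON tags (both assumptions inferred from the code above, not confirmed by the diff; URL, model name, and token are placeholders):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Build a rerank request body matching the fields validated in RerankHelper:
	// a non-empty query and a non-empty documents list.
	body, _ := json.Marshal(map[string]interface{}{
		"model":     "rerank-english-v2.0", // any model routed to a rerank-capable channel
		"query":     "What is quota pre-consumption?",
		"documents": []string{"first candidate document", "second candidate document"},
	})
	req, _ := http.NewRequest("POST", "http://localhost:3000/v1/rerank", bytes.NewReader(body))
	req.Header.Set("Authorization", "Bearer sk-your-token")
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```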
@@ -24,7 +24,7 @@ func EnableChannel(channelId int, channelName string) {
 	notifyRootUser(subject, content)
 }
 
-func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	if !common.AutomaticDisableChannelEnabled {
 		return false
 	}
@@ -34,9 +34,15 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	if err.LocalError {
 		return false
 	}
-	if err.StatusCode == http.StatusUnauthorized || err.StatusCode == http.StatusForbidden {
+	if err.StatusCode == http.StatusUnauthorized {
 		return true
 	}
+	if err.StatusCode == http.StatusForbidden {
+		switch channelType {
+		case common.ChannelTypeGemini:
+			return true
+		}
+	}
 	switch err.Error.Code {
 	case "invalid_api_key":
 		return true
@@ -68,14 +74,14 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
 	return false
 }
 
-func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
 	if !common.AutomaticEnableChannelEnabled {
 		return false
 	}
 	if err != nil {
 		return false
 	}
-	if openAIErr != nil {
+	if openaiWithStatusErr != nil {
 		return false
 	}
 	if status != common.ChannelStatusAutoDisabled {
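The net effect of the `ShouldDisableChannel` change: a 401 still auto-disables any channel, while a 403 now only disables channel types that opt in (currently Gemini), since several upstreams return 403 for transient or regional reasons. A standalone restatement of that decision table (the constants here are placeholders, not the repo's values):

```go
package main

import (
	"fmt"
	"net/http"
)

// Placeholder channel-type constants for illustration only.
const (
	channelTypeOpenAI = 1
	channelTypeGemini = 24
)

// shouldDisable restates the status-code portion of the new policy:
// 401 disables unconditionally; 403 disables only opted-in channel types.
func shouldDisable(channelType, statusCode int) bool {
	if statusCode == http.StatusUnauthorized {
		return true
	}
	if statusCode == http.StatusForbidden {
		switch channelType {
		case channelTypeGemini:
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(shouldDisable(channelTypeGemini, http.StatusForbidden))    // true
	fmt.Println(shouldDisable(channelTypeOpenAI, http.StatusForbidden))    // false: 403 no longer disables other channels
	fmt.Println(shouldDisable(channelTypeOpenAI, http.StatusUnauthorized)) // true
}
```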
service/relay.go (new file, 47 lines)
@@ -0,0 +1,47 @@
+package service
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+)
+
+func SetEventStreamHeaders(c *gin.Context) {
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
+
+func StringData(c *gin.Context, str string) error {
+	//str = strings.TrimPrefix(str, "data: ")
+	//str = strings.TrimSuffix(str, "\r")
+	c.Render(-1, common.CustomEvent{Data: "data: " + str})
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	} else {
+		return errors.New("streaming error: flusher not found")
+	}
+	return nil
+}
+
+func ObjectData(c *gin.Context, object interface{}) error {
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	return StringData(c, string(jsonData))
+}
+
+func Done(c *gin.Context) {
+	_ = StringData(c, "[DONE]")
+}
+
+func GetResponseID(c *gin.Context) string {
+	logID := c.GetString("X-Oneapi-Request-Id")
+	return fmt.Sprintf("chatcmpl-%s", logID)
+}
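The new service/relay.go consolidates the SSE helpers in one place: set the event-stream headers once, write each chunk as a `data:` line, flush immediately, and finish with `[DONE]`. A self-contained gin handler showing the same wire pattern (an illustration of the protocol, not the repo's handler):

```go
package main

import (
	"fmt"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
)

func main() {
	r := gin.Default()
	r.GET("/stream", func(c *gin.Context) {
		// Same headers SetEventStreamHeaders sets above.
		c.Writer.Header().Set("Content-Type", "text/event-stream")
		c.Writer.Header().Set("Cache-Control", "no-cache")
		c.Writer.Header().Set("Connection", "keep-alive")
		c.Writer.Header().Set("X-Accel-Buffering", "no")

		for i := 0; i < 3; i++ {
			// Equivalent of StringData: one "data: ..." event, flushed right away
			// so the client sees chunks as they are produced.
			fmt.Fprintf(c.Writer, "data: chunk %d\n\n", i)
			if flusher, ok := c.Writer.(http.Flusher); ok {
				flusher.Flush()
			}
			time.Sleep(100 * time.Millisecond)
		}
		// Equivalent of Done: the OpenAI-style stream terminator.
		fmt.Fprint(c.Writer, "data: [DONE]\n\n")
	})
	_ = r.Run(":8080")
}
```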
@@ -3,7 +3,6 @@ package service
 import (
 	"errors"
 	"fmt"
-	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
 	"strings"
@@ -62,7 +61,7 @@ func SensitiveWordContains(text string) (bool, []string) {
 	}
 	checkText := strings.ToLower(text)
 	// 构建一个AC自动机
-	m := common.InitAc()
+	m := InitAc()
 	hits := m.MultiPatternSearch([]rune(checkText), false)
 	if len(hits) > 0 {
 		words := make([]string, 0)
@@ -80,7 +79,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
 		return false, nil, text
 	}
 	checkText := strings.ToLower(text)
-	m := common.InitAc()
+	m := InitAc()
 	hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
 	if len(hits) > 0 {
 		words := make([]string, 0)
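For readers unfamiliar with the flow: the text is lowercased, all configured words are matched in one pass by the Aho-Corasick automaton, and the hit list is returned. A naive stand-in using `strings.Contains` (the real `InitAc`/`MultiPatternSearch` build a proper multi-pattern matcher; this sketch only mirrors the observable behavior):

```go
package main

import (
	"fmt"
	"strings"
)

// sensitiveWordContains approximates the check above: lowercase the text,
// then collect every configured word that occurs in it. The AC automaton
// does this in a single scan regardless of how many words are configured.
func sensitiveWordContains(text string, words []string) (bool, []string) {
	checkText := strings.ToLower(text)
	hits := make([]string, 0)
	for _, w := range words {
		if strings.Contains(checkText, strings.ToLower(w)) {
			hits = append(hits, w)
		}
	}
	return len(hits) > 0, hits
}

func main() {
	ok, hits := sensitiveWordContains("Some Bad text", []string{"bad", "worse"})
	fmt.Println(ok, hits) // true [bad]
}
```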
@@ -1,11 +0,0 @@
-package service
-
-import "github.com/gin-gonic/gin"
-
-func SetEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
@@ -1,4 +1,4 @@
-package common
+package service
 
 import (
 	"bytes"
@@ -24,3 +24,19 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err
 }
+
+func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
+	return &dto.ChatCompletionsStreamResponse{
+		Id:                id,
+		Object:            "chat.completion.chunk",
+		Created:           createAt,
+		Model:             model,
+		SystemFingerprint: nil,
+		Choices:           make([]dto.ChatCompletionsStreamResponseChoice, 0),
+		Usage:             &usage,
+	}
+}
+
+func ValidUsage(usage *dto.Usage) bool {
+	return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
+}
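`GenerateFinalUsageResponse` emits one last stream chunk with an empty `choices` array and the usage totals, mirroring OpenAI's `stream_options: {"include_usage": true}` behavior. A standalone sketch of what that chunk serializes to (the dto types are re-declared here for illustration, with assumed but conventional JSON tags):

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type streamChunk struct {
	Id      string        `json:"id"`
	Object  string        `json:"object"`
	Created int64         `json:"created"`
	Model   string        `json:"model"`
	Choices []interface{} `json:"choices"`
	Usage   *usage        `json:"usage"`
}

func main() {
	// The final chunk: no choices, only the accumulated token counts.
	chunk := streamChunk{
		Id:      "chatcmpl-123",
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   "gpt-4o",
		Choices: []interface{}{},
		Usage:   &usage{PromptTokens: 12, CompletionTokens: 34, TotalTokens: 46},
	}
	b, _ := json.Marshal(chunk)
	fmt.Println(string(b))
}
```

`ValidUsage` then gives callers a cheap check for whether an upstream actually reported usage before falling back to local token counting.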
@@ -550,7 +550,7 @@ const ChannelsTable = () => {
     );
     const { success, message, data } = res.data;
     if (success) {
-      setChannels(data);
+      setChannelFormat(data);
       setActivePage(1);
     } else {
       showError(message);
@@ -367,7 +367,7 @@ const LogsTable = () => {
     dataIndex: 'content',
     render: (text, record, index) => {
       let other = getLogOther(record.other);
-      if (other == null) {
+      if (other == null || record.type !== 2) {
         return (
           <Paragraph
             ellipsis={{
@@ -42,6 +42,7 @@ const OperationSetting = () => {
     MjAccountFilterEnabled: false,
     MjModeClearEnabled: false,
     MjForwardUrlEnabled: false,
+    MjActionCheckSuccessEnabled: false,
     DrawingEnabled: false,
     DataExportEnabled: false,
     DataExportDefaultTime: 'hour',
@@ -1,16 +1,10 @@
 import React, { useEffect, useState } from 'react';
-import {
-  Button,
-  Form,
-  Grid,
-  Header,
-  Image,
-  Message,
-  Segment,
-} from 'semantic-ui-react';
 import { Link, useNavigate } from 'react-router-dom';
 import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
 import Turnstile from 'react-turnstile';
+import { Button, Card, Form, Layout } from '@douyinfe/semi-ui';
+import Title from '@douyinfe/semi-ui/lib/es/typography/title';
+import Text from '@douyinfe/semi-ui/lib/es/typography/text';
 
 const RegisterForm = () => {
   const [inputs, setInputs] = useState({
@@ -18,7 +12,7 @@ const RegisterForm = () => {
     password: '',
     password2: '',
     email: '',
-    verification_code: '',
+    verification_code: ''
   });
   const { username, password, password2 } = inputs;
   const [showEmailVerification, setShowEmailVerification] = useState(false);
@@ -46,9 +40,7 @@ const RegisterForm = () => {
 
   let navigate = useNavigate();
 
-  function handleChange(e) {
-    const { name, value } = e.target;
-    console.log(name, value);
+  function handleChange(name, value) {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
   }
 
@@ -73,7 +65,7 @@ const RegisterForm = () => {
     inputs.aff_code = affCode;
     const res = await API.post(
       `/api/user/register?turnstile=${turnstileToken}`,
-      inputs,
+      inputs
     );
     const { success, message } = res.data;
     if (success) {
@@ -94,7 +86,7 @@ const RegisterForm = () => {
     }
     setLoading(true);
     const res = await API.get(
-      `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`,
+      `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`
     );
     const { success, message } = res.data;
     if (success) {
@@ -106,96 +98,113 @@ const RegisterForm = () => {
   };
 
   return (
-    <Grid textAlign='center' style={{ marginTop: '48px' }}>
-      <Grid.Column style={{ maxWidth: 450 }}>
-        <Header as='h2' color='' textAlign='center'>
-          <Image src={logo} /> 新用户注册
-        </Header>
-        <Form size='large'>
-          <Segment>
-            <Form.Input
-              fluid
-              icon='user'
-              iconPosition='left'
-              placeholder='输入用户名,最长 12 位'
-              onChange={handleChange}
-              name='username'
-            />
-            <Form.Input
-              fluid
-              icon='lock'
-              iconPosition='left'
-              placeholder='输入密码,最短 8 位,最长 20 位'
-              onChange={handleChange}
-              name='password'
-              type='password'
-            />
-            <Form.Input
-              fluid
-              icon='lock'
-              iconPosition='left'
-              placeholder='输入密码,最短 8 位,最长 20 位'
-              onChange={handleChange}
-              name='password2'
-              type='password'
-            />
-            {showEmailVerification ? (
-              <>
-                <Form.Input
-                  fluid
-                  icon='mail'
-                  iconPosition='left'
-                  placeholder='输入邮箱地址'
-                  onChange={handleChange}
-                  name='email'
-                  type='email'
-                  action={
-                    <Button onClick={sendVerificationCode} disabled={loading}>
-                      获取验证码
-                    </Button>
-                  }
-                />
-                <Form.Input
-                  fluid
-                  icon='lock'
-                  iconPosition='left'
-                  placeholder='输入验证码'
-                  onChange={handleChange}
-                  name='verification_code'
-                />
-              </>
-            ) : (
-              <></>
-            )}
-            {turnstileEnabled ? (
-              <Turnstile
-                sitekey={turnstileSiteKey}
-                onVerify={(token) => {
-                  setTurnstileToken(token);
-                }}
-              />
-            ) : (
-              <></>
-            )}
-            <Button
-              color='green'
-              fluid
-              size='large'
-              onClick={handleSubmit}
-              loading={loading}
-            >
-              注册
-            </Button>
-          </Segment>
-        </Form>
-        <Message>
-          已有账户?
-          <Link to='/login' className='btn btn-link'>
-            点击登录
-          </Link>
-        </Message>
-      </Grid.Column>
-    </Grid>
+    <div>
+      <Layout>
+        <Layout.Header></Layout.Header>
+        <Layout.Content>
+          <div
+            style={{
+              justifyContent: 'center',
+              display: 'flex',
+              marginTop: 120
+            }}
+          >
+            <div style={{ width: 500 }}>
+              <Card>
+                <Title heading={2} style={{ textAlign: 'center' }}>
+                  新用户注册
+                </Title>
+                <Form size="large">
+                  <Form.Input
+                    field={'username'}
+                    label={'用户名'}
+                    placeholder="用户名"
+                    name="username"
+                    onChange={(value) => handleChange('username', value)}
+                  />
+                  <Form.Input
+                    field={'password'}
+                    label={'密码'}
+                    placeholder="密码,最短 8 位,最长 20 位"
+                    name="password"
+                    type="password"
+                    onChange={(value) => handleChange('password', value)}
+                  />
+                  <Form.Input
+                    field={'password2'}
+                    label={'确认密码'}
+                    placeholder="确认密码"
+                    name="password2"
+                    type="password"
+                    onChange={(value) => handleChange('password2', value)}
+                  />
+                  {showEmailVerification ? (
+                    <>
+                      <Form.Input
+                        field={'email'}
+                        label={'邮箱'}
+                        placeholder="输入邮箱地址"
+                        onChange={(value) => handleChange('email', value)}
+                        name="email"
+                        type="email"
+                        suffix={
+                          <Button onClick={sendVerificationCode} disabled={loading}>
+                            获取验证码
+                          </Button>
+                        }
+                      />
+                      <Form.Input
+                        field={'verification_code'}
+                        label={'验证码'}
+                        placeholder="输入验证码"
+                        onChange={(value) => handleChange('verification_code', value)}
+                        name="verification_code"
+                      />
+                    </>
+                  ) : (
+                    <></>
+                  )}
+                  <Button
+                    theme='solid'
+                    style={{ width: '100%' }}
+                    type={'primary'}
+                    size='large'
+                    htmlType={'submit'}
+                    onClick={handleSubmit}
+                  >
+                    注册
+                  </Button>
+                </Form>
+                <div
+                  style={{
+                    display: 'flex',
+                    justifyContent: 'space-between',
+                    marginTop: 20
+                  }}
+                >
+                  <Text>
+                    已有账户?
+                    <Link to="/login">
+                      点击登录
+                    </Link>
+                  </Text>
+                </div>
+              </Card>
+              {turnstileEnabled ? (
+                <Turnstile
+                  sitekey={turnstileSiteKey}
+                  onVerify={(token) => {
+                    setTurnstileToken(token);
+                  }}
+                />
+              ) : (
+                <></>
+              )}
+            </div>
+          </div>
+        </Layout.Content>
+      </Layout>
+    </div>
   );
 };
@@ -99,11 +99,14 @@ export const CHANNEL_OPTIONS = [
     color: 'orange',
     label: 'Google PaLM2',
   },
+  { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
   { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
   { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
   { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
   { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
   { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
+  { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
+  { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
   {
     key: 22,
@@ -144,28 +144,29 @@ export function renderModelPrice(
 ) {
   // 1 ratio = $0.002 / 1K tokens
   if (modelPrice !== -1) {
-    return '模型价格:$' + modelPrice * groupRatio;
+    return '模型价格:$' + modelPrice + ' * 分组倍率:' + groupRatio + ' = $' + modelPrice * groupRatio;
   } else {
     if (completionRatio === undefined) {
       completionRatio = 0;
     }
     // 这里的 *2 是因为 1倍率=0.002刀,请勿删除
-    let inputRatioPrice = modelRatio * 2.0 * groupRatio;
-    let completionRatioPrice = modelRatio * 2.0 * completionRatio * groupRatio;
+    let inputRatioPrice = modelRatio * 2.0;
+    let completionRatioPrice = modelRatio * 2.0 * completionRatio;
     let price =
-      (inputTokens / 1000000) * inputRatioPrice +
-      (completionTokens / 1000000) * completionRatioPrice;
+      (inputTokens / 1000000) * inputRatioPrice * groupRatio +
+      (completionTokens / 1000000) * completionRatioPrice * groupRatio;
     return (
       <>
         <article>
-          <p>提示 ${inputRatioPrice} / 1M tokens</p>
-          <p>补全 ${completionRatioPrice} / 1M tokens</p>
+          <p>提示:${inputRatioPrice} * {groupRatio} = ${inputRatioPrice * groupRatio} / 1M tokens</p>
+          <p>补全:${completionRatioPrice} * {groupRatio} = ${completionRatioPrice * groupRatio} / 1M tokens</p>
           <p></p>
           <p>
             提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
-            {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $
-            {price.toFixed(6)}
+            {completionTokens} tokens / 1M tokens * ${completionRatioPrice} * 分组 {groupRatio} =
+            ${price.toFixed(6)}
           </p>
           <p>仅供参考,以实际扣费为准</p>
         </article>
       </>
     );
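This refactor changes only where `groupRatio` is applied in the display (it is now shown as an explicit factor rather than folded into the per-token rates); the total is unchanged: price = inputTokens/1M × modelRatio×2 × groupRatio + completionTokens/1M × modelRatio×2×completionRatio × groupRatio. A worked example of the arithmetic (the numbers are illustrative, not taken from the repo's ratio tables):

```go
package main

import "fmt"

// price restates the formula above. The * 2.0 comes from the documented
// convention that 1 ratio = $0.002 / 1K tokens, i.e. $2 / 1M tokens.
func price(inputTokens, completionTokens int, modelRatio, completionRatio, groupRatio float64) float64 {
	inputRatioPrice := modelRatio * 2.0                        // $ per 1M prompt tokens, before group ratio
	completionRatioPrice := modelRatio * 2.0 * completionRatio // $ per 1M completion tokens, before group ratio
	return float64(inputTokens)/1000000*inputRatioPrice*groupRatio +
		float64(completionTokens)/1000000*completionRatioPrice*groupRatio
}

func main() {
	// e.g. modelRatio 15, completionRatio 3, groupRatio 1:
	// prompt $30 / 1M tokens and completion $90 / 1M tokens, so
	// 1000 prompt + 500 completion tokens cost $0.030000 + $0.045000.
	fmt.Printf("$%.6f\n", price(1000, 500, 15, 3, 1)) // $0.075000
}
```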
@@ -605,6 +605,24 @@ const EditChannel = (props) => {
           />
         </>
       )}
+      {inputs.type === 39 && (
+        <>
+          <div style={{ marginTop: 10 }}>
+            <Typography.Text strong>Account ID:</Typography.Text>
+          </div>
+          <Input
+            name='other'
+            placeholder={
+              '请输入Account ID,例如:d6b5da8hk1awo8nap34ube6gh'
+            }
+            onChange={(value) => {
+              handleInputChange('other', value);
+            }}
+            value={inputs.other}
+            autoComplete='new-password'
+          />
+        </>
+      )}
       <div style={{ marginTop: 10 }}>
         <Typography.Text strong>模型:</Typography.Text>
       </div>
@@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
     MjAccountFilterEnabled: false,
     MjForwardUrlEnabled: false,
     MjModeClearEnabled: false,
+    MjActionCheckSuccessEnabled: false,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -156,6 +157,25 @@ export default function SettingsDrawing(props) {
             }
           />
         </Col>
+        <Col span={8}>
+          <Form.Switch
+            field={'MjActionCheckSuccessEnabled'}
+            label={
+              <>
+                检测必须等待绘图成功才能进行放大等操作
+              </>
+            }
+            size='large'
+            checkedText='|'
+            uncheckedText='〇'
+            onChange={(value) =>
+              setInputs({
+                ...inputs,
+                MjActionCheckSuccessEnabled: value,
+              })
+            }
+          />
+        </Col>
       </Row>
       <Row>
         <Button size='large' onClick={onSubmit}>