Mirror of https://github.com/linux-do/new-api.git, synced 2025-09-17 16:06:38 +08:00

commit 7a249b206d

    merge upstream

    Signed-off-by: wozulong <>
README.md | 70
@@ -2,15 +2,21 @@
 # New API
 
 > [!NOTE]
-> This is an open-source project, developed on top of [One API](https://github.com/songquanpeng/one-api); thanks to the original author for their selfless contribution.
-> Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and all **laws and regulations**; use for illegal purposes is prohibited.
-
-> This project is for personal study only; stability is not guaranteed and no technical support is provided. Users must comply with OpenAI's Terms of Use and all laws and regulations; use for illegal purposes is prohibited.
+> This is an open-source project, developed on top of [One API](https://github.com/songquanpeng/one-api)
+
+> [!IMPORTANT]
+> Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and all **laws and regulations**; use for illegal purposes is prohibited.
+> This project is for personal study only; stability is not guaranteed and no technical support is provided.
 > Per the [Interim Measures for the Administration of Generative AI Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI service to the public in mainland China.
 
-> [!NOTE]
-> Latest Docker image: calciumion/new-api:latest
-> Update command: docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
+> [!TIP]
+> Latest Docker image: `calciumion/new-api:latest`
+> Default account: root, password: 123456
+> Update command:
+> ```
+> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
+> ```
 
 ## Major changes
 
 The main changes in this fork are as follows:
@@ -18,9 +24,9 @@
 1. Brand-new UI (some pages are still pending updates)
 2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API; see the [integration docs](Midjourney.md)
 3. Online top-up support, configurable in system settings; currently supported payment gateway:
     + [x] Yipay (易支付)
 4. Query used quota by key:
     + together with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool), usage can be queried by key
 5. Channels display their used quota and support restricting access to a specified organization
 6. Pagination supports choosing the number of items per page
 7. Compatible with the original One API database; the original database (one-api.db) can be used directly
@@ -51,29 +57,14 @@
 
 You can add the custom models gpt-4-gizmo-* or g-* in a channel. These are not official OpenAI models but third-party ones, so an official key cannot call them.
 
-## Channel retry
-
-Channel retry is implemented; the retry count can be set in `Settings -> Operation Settings -> General Settings`. Enabling the **cache** is recommended.
-
-If retry is enabled, the first retry uses the same priority, the second retry uses the next priority, and so on.
-
-### How to configure the cache
-
-1. `REDIS_CONN_STRING`: once set, Redis is used as the cache.
-    + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
-2. `MEMORY_CACHE_ENABLED`: enables the in-memory cache (no need to set it manually if `REDIS_CONN_STRING` is set). This can delay user quota updates somewhat. Options are `true` and `false`; defaults to `false` when unset.
-    + Example: `MEMORY_CACHE_ENABLED=true`
-
-### Why are some requests not retried?
-
-These error codes are not retried: 400, 504, 524
-
-### I want 400 to be retried too
-
-In `Channel -> Edit`, set `Status Code Rewrite` to
-
-```json
-{
-  "400": "500"
-}
-```
-
-This turns a 400 error into a 500 error, which is then retried.
-
 ## Extra configuration compared with the original One API
 
 - `STREAMING_TIMEOUT`: timeout for a single streamed reply, in seconds; defaults to 30
-- `DIFY_DEBUG`: whether the Dify channel outputs workflow and node information to the client; defaults to `true`, options are `true` and `false`
-- `FORCE_STREAM_OPTION`: override the client's stream_options parameter and request that the upstream return usage in stream mode; currently only the `OpenAI` channel type is supported
+- `DIFY_DEBUG`: whether the Dify channel outputs workflow and node information to the client; defaults to `true`
+- `FORCE_STREAM_OPTION`: whether to override the client's stream_options parameter and request that the upstream return usage in stream mode; defaults to `true`
+- `GET_MEDIA_TOKEN`: whether to count image tokens; defaults to `true`. When disabled, image tokens are no longer counted locally, which may diverge from upstream billing. This option overrides `GET_MEDIA_TOKEN_NOT_STREAM`.
+- `GET_MEDIA_TOKEN_NOT_STREAM`: whether to count image tokens in the non-stream (`stream=false`) case; defaults to `true`
+- `UPDATE_TASK`: whether to update asynchronous tasks (Midjourney, Suno); defaults to `true`. When disabled, task progress is not updated.
 
 ## Deployment
 
 ### Deployment requirements
 
 - Local database (default): SQLite (the Docker deployment uses SQLite by default; the `/data` directory must be mounted to the host)
@@ -96,8 +87,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
 docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(Baota server address:Baota database port)/Baota database name" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
 # Note: the database must allow remote access, restricted to the server's IP only
 ```
-### Default account and password
-
-Default account: root, password: 123456
+## Channel retry
+
+Channel retry is implemented; the retry count can be set in `Settings -> Operation Settings -> General Settings`. Enabling the **cache** is recommended.
+
+If retry is enabled, the first retry uses the same priority, the second retry uses the next priority, and so on.
+
+### How to configure the cache
+
+1. `REDIS_CONN_STRING`: once set, Redis is used as the cache.
+    + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+2. `MEMORY_CACHE_ENABLED`: enables the in-memory cache (no need to set it manually if `REDIS_CONN_STRING` is set). This can delay user quota updates somewhat. Options are `true` and `false`; defaults to `false` when unset.
+    + Example: `MEMORY_CACHE_ENABLED=true`
+
+### Why are some requests not retried?
+
+These error codes are not retried: 400, 504, 524
+
+### I want 400 to be retried too
+
+In `Channel -> Edit`, set `Status Code Rewrite` to
+
+```json
+{
+  "400": "500"
+}
+```
+
+This turns a 400 error into a 500 error, which is then retried.
 
 ## Midjourney API setup docs
 
 [Integration docs](Midjourney.md)
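The status-code rewrite described above can be sketched as a small lookup step. This is a hypothetical helper (`rewriteStatusCode` is not the project's actual function name), showing only the behavior the README describes:

```go
package main

import (
	"fmt"
	"strconv"
)

// rewriteStatusCode applies a channel's status-code rewrite map, e.g.
// {"400": "500"}, so a normally non-retryable 400 becomes a retryable 500.
// Hypothetical sketch; the real logic lives in the project's relay code.
func rewriteStatusCode(rewrite map[string]string, status int) int {
	if target, ok := rewrite[strconv.Itoa(status)]; ok {
		if n, err := strconv.Atoi(target); err == nil {
			return n
		}
	}
	return status
}

func main() {
	rewrite := map[string]string{"400": "500"}
	fmt.Println(rewriteStatusCode(rewrite, 400)) // 500: now eligible for retry
	fmt.Println(rewriteStatusCode(rewrite, 429)) // 429: unchanged
}
```

Unmapped codes pass through untouched, so the rewrite only affects the entries you list.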
@@ -235,6 +235,7 @@ const (
 	ChannelTypeSunoAPI = 36
 	ChannelTypeDify    = 37
 	ChannelTypeJina    = 38
+	ChannelCloudflare  = 39
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 
@@ -280,4 +281,5 @@ var ChannelBaseURLs = []string{
 	"",                    //36
 	"",                    //37
 	"https://api.jina.ai", //38
+	"https://api.cloudflare.com", //39
 }
@@ -30,6 +30,8 @@ var defaultModelRatio = map[string]float64{
 	"gpt-4-32k":      30,
 	"gpt-4-32k-0314": 30,
 	"gpt-4-32k-0613": 30,
+	"gpt-4o-mini":            0.075, // $0.00015 / 1K tokens
+	"gpt-4o-mini-2024-07-18": 0.075,
 	"gpt-4o":            2.5, // $0.005 / 1K tokens
 	"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
 	"gpt-4-turbo":       5,   // $0.01 / 1K tokens
@@ -104,12 +106,13 @@ var defaultModelRatio = map[string]float64{
 	"gemini-1.0-pro-latest":        1,
 	"gemini-1.0-pro-vision-latest": 1,
 	"gemini-ultra":                 1,
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
 	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens
 	"glm-4":  7.143, // ¥0.1 / 1k tokens
-	"glm-4v": 7.143, // ¥0.1 / 1k tokens
+	"glm-4v":         0.05 * RMB, // ¥0.05 / 1k tokens
+	"glm-4-alltools": 0.1 * RMB,  // ¥0.1 / 1k tokens
 	"glm-3-turbo": 0.3572,
 	"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
 	"qwen-plus":  10,     // ¥0.14 / 1k tokens
@@ -157,6 +160,8 @@ var defaultModelRatio = map[string]float64{
 }
 
 var defaultModelPrice = map[string]float64{
+	"suno_music":  0.1,
+	"suno_lyrics": 0.01,
 	"dall-e-2":      0.02,
 	"dall-e-3":      0.04,
 	"gpt-4-gizmo-*": 0.1,
@@ -313,6 +318,10 @@ func GetCompletionRatio(name string) float64 {
 		return 4.0 / 3.0
 	}
 	if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
+		if strings.HasPrefix(name, "gpt-4o-mini") {
+			return 4
+		}
+
 		if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
 			return 3
 		}
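The new branch in GetCompletionRatio must run before the `gpt-4o` prefix check, otherwise `gpt-4o-mini` would match `gpt-4o` and get ratio 3 instead of 4. A standalone sketch of just this branch (simplified; the fallback value of 2 for other gpt-4 models is an assumption of this sketch, and the real function handles many more model families):

```go
package main

import (
	"fmt"
	"strings"
)

// completionRatio mirrors the gpt-4 branch of GetCompletionRatio after this
// change: gpt-4o-mini models return 4, other turbo/preview/4o models return 3.
// Simplified sketch only; non-gpt-4 names fall through to 1 here.
func completionRatio(name string) float64 {
	if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
		// mini check first: "gpt-4o-mini" also has the "gpt-4o" prefix
		if strings.HasPrefix(name, "gpt-4o-mini") {
			return 4
		}
		if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
			return 3
		}
		return 2 // assumed fallback for this sketch
	}
	return 1
}

func main() {
	fmt.Println(completionRatio("gpt-4o-mini-2024-07-18")) // 4
	fmt.Println(completionRatio("gpt-4o-2024-05-13"))      // 3
}
```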
common/str.go | 73 (new file)
@@ -0,0 +1,73 @@
+package common
+
+import (
+	"encoding/json"
+	"math/rand"
+	"strconv"
+	"unsafe"
+)
+
+func GetStringIfEmpty(str string, defaultValue string) string {
+	if str == "" {
+		return defaultValue
+	}
+	return str
+}
+
+func GetRandomString(length int) string {
+	//rand.Seed(time.Now().UnixNano())
+	key := make([]byte, length)
+	for i := 0; i < length; i++ {
+		key[i] = keyChars[rand.Intn(len(keyChars))]
+	}
+	return string(key)
+}
+
+func MapToJsonStr(m map[string]interface{}) string {
+	bytes, err := json.Marshal(m)
+	if err != nil {
+		return ""
+	}
+	return string(bytes)
+}
+
+func MapToJsonStrFloat(m map[string]float64) string {
+	bytes, err := json.Marshal(m)
+	if err != nil {
+		return ""
+	}
+	return string(bytes)
+}
+
+func StrToMap(str string) map[string]interface{} {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(str), &m)
+	if err != nil {
+		return nil
+	}
+	return m
+}
+
+func String2Int(str string) int {
+	num, err := strconv.Atoi(str)
+	if err != nil {
+		return 0
+	}
+	return num
+}
+
+func StringsContains(strs []string, str string) bool {
+	for _, s := range strs {
+		if s == str {
+			return true
+		}
+	}
+	return false
+}
+
+// StringToByteSlice []byte only read, panic on append
+func StringToByteSlice(s string) []byte {
+	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
+	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
+	return *(*[]byte)(unsafe.Pointer(&tmp2))
+}
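StringToByteSlice above returns a []byte that aliases the string's backing array without copying, by reinterpreting the string header as a slice header. As its comment warns, the result is read-only: strings may live in read-only memory, and appending can fault. A minimal self-contained demo of the same trick:

```go
package main

import (
	"fmt"
	"unsafe"
)

// stringToByteSlice is the same zero-copy conversion as common.StringToByteSlice:
// the string header {data, len} becomes a slice header {data, len, cap}.
// The returned bytes MUST be treated as read-only.
func stringToByteSlice(s string) []byte {
	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
	return *(*[]byte)(unsafe.Pointer(&tmp2))
}

func main() {
	b := stringToByteSlice("hello")
	// Reading is fine; writing b[0] = 'H' could crash the process.
	fmt.Println(len(b), string(b)) // 5 hello
}
```

This avoids an allocation on hot paths that only need to read the bytes; when in doubt, the safe `[]byte(s)` copy is the right default.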
@@ -2,7 +2,6 @@ package common
 
 import (
 	"context"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/google/uuid"
@@ -18,7 +17,6 @@ import (
 	"strconv"
 	"strings"
 	"time"
-	"unsafe"
 )
 
 func OpenBrowser(url string) {
@@ -164,15 +162,6 @@ func GenerateKey() string {
 	return string(key)
 }
 
-func GetRandomString(length int) string {
-	//rand.Seed(time.Now().UnixNano())
-	key := make([]byte, length)
-	for i := 0; i < length; i++ {
-		key[i] = keyChars[rand.Intn(len(keyChars))]
-	}
-	return string(key)
-}
-
 func GetRandomInt(max int) int {
 	//rand.Seed(time.Now().UnixNano())
 	return rand.Intn(max)
@@ -199,60 +188,11 @@ func MessageWithRequestId(message string, id string) string {
 	return fmt.Sprintf("%s (request id: %s)", message, id)
 }
 
-func String2Int(str string) int {
-	num, err := strconv.Atoi(str)
-	if err != nil {
-		return 0
-	}
-	return num
-}
-
-func StringsContains(strs []string, str string) bool {
-	for _, s := range strs {
-		if s == str {
-			return true
-		}
-	}
-	return false
-}
-
-// StringToByteSlice []byte only read, panic on append
-func StringToByteSlice(s string) []byte {
-	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
-	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
-	return *(*[]byte)(unsafe.Pointer(&tmp2))
-}
-
 func RandomSleep() {
 	// Sleep for 0-3000 ms
 	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
 }
 
-func MapToJsonStr(m map[string]interface{}) string {
-	bytes, err := json.Marshal(m)
-	if err != nil {
-		return ""
-	}
-	return string(bytes)
-}
-
-func MapToJsonStrFloat(m map[string]float64) string {
-	bytes, err := json.Marshal(m)
-	if err != nil {
-		return ""
-	}
-	return string(bytes)
-}
-
-func StrToMap(str string) map[string]interface{} {
-	m := make(map[string]interface{})
-	err := json.Unmarshal([]byte(str), &m)
-	if err != nil {
-		return nil
-	}
-	return m
-}
-
 func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
 	if "" == proxyUrl {
 		return &http.Client{}, nil
@@ -9,3 +9,9 @@ var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
 
 // ForceStreamOption overrides the request parameter to force usage info to be returned
 var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
+
+var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
+
+var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
+
+var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"math"
 	"net/http"
@@ -12,6 +13,7 @@ import (
 	"net/url"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/middleware"
 	"one-api/model"
 	"one-api/relay"
 	relaycommon "one-api/relay/common"
@@ -24,7 +26,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 	tik := time.Now()
 	if channel.Type == common.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
@@ -40,29 +42,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		Body:   nil,
 		Header: make(http.Header),
 	}
-	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
-	c.Request.Header.Set("Content-Type", "application/json")
-	c.Set("channel", channel.Type)
-	c.Set("base_url", channel.GetBaseURL())
-	switch channel.Type {
-	case common.ChannelTypeAzure:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeXunfei:
-		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
-	case common.ChannelTypeGemini:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeAli:
-		c.Set("plugin", channel.Other)
-	}
-
-	meta := relaycommon.GenRelayInfo(c)
-	apiType, _ := constant.ChannelType2APIType(channel.Type)
-	adaptor := relay.GetAdaptor(apiType)
-	if adaptor == nil {
-		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
-	}
 	if testModel == "" {
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
@@ -79,8 +59,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
-			return err, &openaiErr
+			return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[testModel] != "" {
 			testModel = modelMap[testModel]
@@ -88,14 +67,28 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		}
 	}
 
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+
+	middleware.SetupContextForSelectedChannel(c, channel, testModel)
+
+	meta := relaycommon.GenRelayInfo(c)
+	apiType, _ := constant.ChannelType2APIType(channel.Type)
+	adaptor := relay.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+	}
+
 	request := buildTestRequest()
 	request.Model = testModel
 	meta.UpstreamModelName = testModel
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 
-	adaptor.Init(meta, *request)
+	adaptor.Init(meta)
 
-	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+	convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
 	if err != nil {
 		return err, nil
 	}
@@ -110,12 +103,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		return err, nil
 	}
 	if resp != nil && resp.StatusCode != http.StatusOK {
-		err := relaycommon.RelayErrorHandler(resp)
-		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+		err := service.RelayErrorHandler(resp)
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
 	}
 	usage, respErr := adaptor.DoResponse(c, resp, meta)
 	if respErr != nil {
-		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+		return fmt.Errorf("%s", respErr.Error.Message), respErr
 	}
 	if usage == nil {
 		return errors.New("usage is nil"), nil
@@ -225,11 +218,11 @@ func testAllChannels(notify bool) error {
 	if disableThreshold == 0 {
 		disableThreshold = 10000000 // a impossible value
 	}
-	go func() {
+	gopool.Go(func() {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
-			err, openaiErr := testChannel(channel, "")
+			err, openaiWithStatusErr := testChannel(channel, "")
 			tok := time.Now()
 			milliseconds := tok.Sub(tik).Milliseconds()
 
@@ -238,27 +231,29 @@ func testAllChannels(notify bool) error {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 				ban = true
 			}
-			if openaiErr != nil {
-				err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
-				ban = true
+			// request error disables the channel
+			if openaiWithStatusErr != nil {
+				oaiErr := openaiWithStatusErr.Error
+				err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
+				ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 			}
 
 			// parse *int to bool
 			if channel.AutoBan != nil && *channel.AutoBan == 0 {
 				ban = false
 			}
-			if openaiErr != nil {
-				openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
-					StatusCode: -1,
-					Error:      *openaiErr,
-					LocalError: false,
-				}
-				if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
-					service.DisableChannel(channel.Id, channel.Name, err.Error())
-				}
-				if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
-					service.EnableChannel(channel.Id, channel.Name)
-				}
-			}
+			// disable channel
+			if ban && isChannelEnabled {
+				service.DisableChannel(channel.Id, channel.Name, err.Error())
+			}
+
+			// enable channel
+			if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
+				service.EnableChannel(channel.Id, channel.Name)
+			}
+
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(common.RequestInterval)
 		}
@@ -271,7 +266,7 @@ func testAllChannels(notify bool) error {
 				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
 			}
 		}
-	}()
+	})
 	return nil
 }
 
@@ -146,28 +146,26 @@ func UpdateMidjourneyTaskBulk() {
 			buttonStr, _ := json.Marshal(responseItem.Buttons)
 			task.Buttons = string(buttonStr)
 		}
+		shouldReturnQuota := false
 		if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
 			common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
 			task.Progress = "100%"
-			err = model.CacheUpdateUserQuota(task.UserId)
-			if err != nil {
-				common.LogError(ctx, "error update user quota cache: "+err.Error())
-			} else {
-				quota := task.Quota
-				if quota != 0 {
-					err = model.IncreaseUserQuota(task.UserId, quota)
-					if err != nil {
-						common.LogError(ctx, "fail to increase user quota: "+err.Error())
-					}
-					logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
-					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-				}
+			if task.Quota != 0 {
+				shouldReturnQuota = true
 			}
 		}
 		err = task.Update()
 		if err != nil {
 			common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+		} else {
+			if shouldReturnQuota {
+				err = model.IncreaseUserQuota(task.UserId, task.Quota)
+				if err != nil {
+					common.LogError(ctx, "fail to increase user quota: "+err.Error())
+				}
+				logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+				model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+			}
 		}
 	}
 }
@@ -131,7 +131,7 @@ func init() {
 		}
 		meta := &relaycommon.RelayInfo{ChannelType: i}
 		adaptor := relay.GetAdaptor(apiType)
-		adaptor.Init(meta, dto.GeneralOpenAIRequest{})
+		adaptor.Init(meta)
 		channelId2Models[i] = adaptor.GetModelList()
 	}
 }
@@ -22,13 +22,13 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
 	case relayconstant.RelayModeImagesGenerations:
-		err = relay.RelayImageHelper(c, relayMode)
+		err = relay.ImageHelper(c, relayMode)
 	case relayconstant.RelayModeAudioSpeech:
 		fallthrough
 	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
 	case relayconstant.RelayModeAudioTranscription:
-		err = relay.AudioHelper(c, relayMode)
+		err = relay.AudioHelper(c)
 	case relayconstant.RelayModeRerank:
 		err = relay.RerankHelper(c, relayMode)
 	default:
dto/audio.go | 33
@@ -1,13 +1,34 @@
 package dto
 
-type TextToSpeechRequest struct {
-	Model          string  `json:"model" binding:"required"`
-	Input          string  `json:"input" binding:"required"`
-	Voice          string  `json:"voice" binding:"required"`
-	Speed          float64 `json:"speed"`
-	ResponseFormat string  `json:"response_format"`
+type AudioRequest struct {
+	Model          string  `json:"model"`
+	Input          string  `json:"input"`
+	Voice          string  `json:"voice"`
+	Speed          float64 `json:"speed,omitempty"`
+	ResponseFormat string  `json:"response_format,omitempty"`
 }
 
 type AudioResponse struct {
 	Text string `json:"text"`
 }
+
+type WhisperVerboseJSONResponse struct {
+	Task     string    `json:"task,omitempty"`
+	Language string    `json:"language,omitempty"`
+	Duration float64   `json:"duration,omitempty"`
+	Text     string    `json:"text,omitempty"`
+	Segments []Segment `json:"segments,omitempty"`
+}
+
+type Segment struct {
+	Id               int     `json:"id"`
+	Seek             int     `json:"seek"`
+	Start            float64 `json:"start"`
+	End              float64 `json:"end"`
+	Text             string  `json:"text"`
+	Tokens           []int   `json:"tokens"`
+	Temperature      float64 `json:"temperature"`
+	AvgLogprob       float64 `json:"avg_logprob"`
+	CompressionRatio float64 `json:"compression_ratio"`
+	NoSpeechProb     float64 `json:"no_speech_prob"`
+}
dto/dalle.go | 12
@@ -12,9 +12,11 @@ type ImageRequest struct {
 }
 
 type ImageResponse struct {
-	Created int `json:"created"`
-	Data    []struct {
-		Url     string `json:"url"`
-		B64Json string `json:"b64_json"`
-	}
+	Data    []ImageData `json:"data"`
+	Created int64       `json:"created"`
+}
+
+type ImageData struct {
+	Url           string `json:"url"`
+	B64Json       string `json:"b64_json"`
+	RevisedPrompt string `json:"revised_prompt"`
 }
@@ -29,12 +29,13 @@ type GeneralOpenAIRequest struct {
 	PresencePenalty float64 `json:"presence_penalty,omitempty"`
 	ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
 	Seed float64 `json:"seed,omitempty"`
-	Tools any `json:"tools,omitempty"`
+	Tools []ToolCall `json:"tools,omitempty"`
 	ToolChoice any `json:"tool_choice,omitempty"`
 	User string `json:"user,omitempty"`
 	LogitBias any `json:"logit_bias,omitempty"`
 	LogProbs any `json:"logprobs,omitempty"`
 	TopLogProbs int `json:"top_logprobs,omitempty"`
+	Dimensions int `json:"dimensions,omitempty"`
 }
 
 type OpenAITools struct {
@@ -52,8 +53,8 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
-func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
-	return int64(r.MaxTokens)
+func (r GeneralOpenAIRequest) GetMaxTokens() int {
+	return int(r.MaxTokens)
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
@@ -107,6 +108,11 @@ func (m Message) StringContent() string {
 	return string(m.Content)
 }
 
+func (m *Message) SetStringContent(content string) {
+	jsonContent, _ := json.Marshal(content)
+	m.Content = jsonContent
+}
+
 func (m Message) IsStringContent() bool {
 	var stringContent string
 	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@@ -146,7 +152,7 @@ func (m Message) ParseContent() []MediaMessage {
 			if ok {
 				subObj["detail"] = detail.(string)
 			} else {
-				subObj["detail"] = "auto"
+				subObj["detail"] = "high"
 			}
 			contentList = append(contentList, MediaMessage{
 				Type: ContentTypeImageURL,
@@ -155,7 +161,16 @@ func (m Message) ParseContent() []MediaMessage {
 					Detail: subObj["detail"].(string),
 				},
 			})
+		} else if url, ok := contentMap["image_url"].(string); ok {
+			contentList = append(contentList, MediaMessage{
+				Type: ContentTypeImageURL,
+				ImageUrl: MessageImageUrl{
+					Url:    url,
+					Detail: "high",
+				},
+			})
 		}
 
 		}
 	}
 	return contentList
@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
-func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
-	return c.Content == nil && len(c.ToolCalls) == 0
-}
-
 func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
 	c.Content = &s
 }
@@ -90,9 +86,11 @@ type ToolCall struct {
 }
 
 type FunctionCall struct {
-	Name string `json:"name,omitempty"`
+	Description string `json:"description,omitempty"`
+	Name string `json:"name,omitempty"`
 	// call function with arguments in JSON format
-	Arguments string `json:"arguments,omitempty"`
+	Parameters any `json:"parameters,omitempty"` // request
+	Arguments string `json:"arguments,omitempty"`
 }
 
 type ChatCompletionsStreamResponse struct {
@@ -105,6 +103,17 @@ type ChatCompletionsStreamResponse struct {
 	Usage *Usage `json:"usage"`
 }
 
+func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
+	if c.SystemFingerprint == nil {
+		return ""
+	}
+	return *c.SystemFingerprint
+}
+
+func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
+	c.SystemFingerprint = &s
+}
+
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
 	Usage *Usage `json:"usage"`
go.mod
@@ -38,6 +38,7 @@ require (
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
 	github.com/aws/smithy-go v1.20.2 // indirect
+	github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
 	github.com/bytedance/sonic v1.9.1 // indirect
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
go.sum
@@ -16,6 +16,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
 github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
 github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -205,6 +207,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
 golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
 golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -214,6 +217,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
main.go
@@ -3,12 +3,14 @@ package main
 import (
 	"embed"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-contrib/sessions/cookie"
 	"github.com/gin-gonic/gin"
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/model"
@@ -89,11 +91,11 @@ func main() {
 		}
 		go controller.AutomaticallyTestChannels(frequency)
 	}
-	if common.IsMasterNode {
-		common.SafeGoroutine(func() {
+	if common.IsMasterNode && constant.UpdateTask {
+		gopool.Go(func() {
 			controller.UpdateMidjourneyTaskBulk()
 		})
-		common.SafeGoroutine(func() {
+		gopool.Go(func() {
 			controller.UpdateTaskBulk()
 		})
 	}
@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"errors"
 	"fmt"
 	"net/http"
 	"one-api/common"
@@ -25,6 +26,10 @@ func Distribute() func(c *gin.Context) {
 		var channel *model.Channel
 		channelId, ok := c.Get("specific_channel_id")
 		modelRequest, shouldSelectChannel, err := getModelRequest(c)
+		if err != nil {
+			abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
+			return
+		}
 		userGroup, _ := model.CacheGetUserGroup(userId)
 		c.Set("group", userGroup)
 		if ok {
@@ -141,7 +146,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	}
 	if err != nil {
 		abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
-		return nil, false, err
+		return nil, false, errors.New("无效的请求, " + err.Error())
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 		if modelRequest.Model == "" {
@@ -154,18 +159,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-		if modelRequest.Model == "" {
-			modelRequest.Model = "dall-e"
-		}
+		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
-		if modelRequest.Model == "" {
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
-				modelRequest.Model = "tts-1"
-			} else {
-				modelRequest.Model = "whisper-1"
-			}
-		}
+		relayMode := relayconstant.RelayModeAudioSpeech
+		if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
+		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
+			relayMode = relayconstant.RelayModeAudioTranslation
+		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
+			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
+			relayMode = relayconstant.RelayModeAudioTranscription
+		}
+		c.Set("relay_mode", relayMode)
 	}
 	return &modelRequest, shouldSelectChannel, nil
 }
@@ -198,11 +207,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeXunfei:
 		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
 	case common.ChannelTypeGemini:
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeAli:
 		c.Set("plugin", channel.Other)
+	case common.ChannelCloudflare:
+		c.Set("api_version", channel.Other)
 	}
 }
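The chained `common.GetStringIfEmpty` calls above implement a fallback priority: keep the model from the request body, else take the `model` form field, else a hard default such as `whisper-1`. A minimal sketch of such a helper (the real signature lives in `one-api/common` and may differ; `getStringIfEmpty` here is a local stand-in):

```go
package main

import "fmt"

// getStringIfEmpty returns str unless it is empty, in which case it
// falls back to defaultValue; chaining calls yields a priority list.
func getStringIfEmpty(str string, defaultValue string) string {
	if str == "" {
		return defaultValue
	}
	return str
}

func main() {
	model := ""
	model = getStringIfEmpty(model, "")          // body and form field both empty
	model = getStringIfEmpty(model, "whisper-1") // final default wins
	fmt.Println(model) // whisper-1
}
```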
@@ -3,6 +3,7 @@ package model
 import (
 	"context"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"gorm.io/gorm"
 	"one-api/common"
 	"strings"
@@ -87,7 +88,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
 		common.LogError(ctx, "failed to record log: "+err.Error())
 	}
 	if common.DataExportEnabled {
-		common.SafeGoroutine(func() {
+		gopool.Go(func() {
 			LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
 		})
 	}
@@ -2,6 +2,7 @@ package model
 
 import (
 	"errors"
+	"github.com/bytedance/gopkg/util/gopool"
 	"gorm.io/gorm"
 	"one-api/common"
 	"sync"
@@ -28,12 +29,12 @@ func init() {
 }
 
 func InitBatchUpdater() {
-	go func() {
+	gopool.Go(func() {
 		for {
 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
 			batchUpdate()
 		}
-	}()
+	})
 }
 
 func addNewRecord(type_ int, id int, value int) {
@@ -10,12 +10,13 @@ import (
 
 type Adaptor interface {
 	// Init IsStream bool
-	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
-	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
+	Init(info *relaycommon.RelayInfo)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
-	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
+	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
+	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 )
@@ -15,17 +16,18 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
-	if info.RelayMode == constant.RelayModeEmbeddings {
+	var fullRequestURL string
+	switch info.RelayMode {
+	case constant.RelayModeEmbeddings:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+	case constant.RelayModeImagesGenerations:
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
+	default:
+		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
 	}
 	return fullRequestURL, nil
 }
@@ -42,22 +44,32 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
 		return baiduEmbeddingRequest, nil
 	default:
-		baiduRequest := requestOpenAI2Ali(*request)
-		return baiduRequest, nil
+		aliReq := requestOpenAI2Ali(*request)
+		return aliReq, nil
 	}
 }
 
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	aliRequest := oaiImage2Ali(request)
+	return aliRequest, nil
+}
+
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, nil
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -65,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = aliStreamHandler(c, resp)
-	} else {
-		switch info.RelayMode {
-		case constant.RelayModeEmbeddings:
-			err, usage = aliEmbeddingHandler(c, resp)
-		default:
-			err, usage = aliHandler(c, resp)
-		}
+	switch info.RelayMode {
+	case constant.RelayModeImagesGenerations:
+		err, usage = aliImageHandler(c, resp, info)
+	case constant.RelayModeEmbeddings:
+		err, usage = aliEmbeddingHandler(c, resp)
+	default:
+		if info.IsStream {
+			err, usage = openai.OaiStreamHandler(c, resp, info)
+		} else {
+			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		}
 	}
 	return
@@ -60,13 +60,40 @@ type AliUsage struct {
 	TotalTokens int `json:"total_tokens"`
 }
 
-type AliOutput struct {
-	Text string `json:"text"`
-	FinishReason string `json:"finish_reason"`
+type TaskResult struct {
+	B64Image string `json:"b64_image,omitempty"`
+	Url      string `json:"url,omitempty"`
+	Code     string `json:"code,omitempty"`
+	Message  string `json:"message,omitempty"`
 }
 
-type AliChatResponse struct {
+type AliOutput struct {
+	TaskId       string       `json:"task_id,omitempty"`
+	TaskStatus   string       `json:"task_status,omitempty"`
+	Text         string       `json:"text"`
+	FinishReason string       `json:"finish_reason"`
+	Message      string       `json:"message,omitempty"`
+	Code         string       `json:"code,omitempty"`
+	Results      []TaskResult `json:"results,omitempty"`
+}
+
+type AliResponse struct {
 	Output AliOutput `json:"output"`
 	Usage AliUsage `json:"usage"`
 	AliError
 }
+
+type AliImageRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Prompt         string `json:"prompt"`
+		NegativePrompt string `json:"negative_prompt,omitempty"`
+	} `json:"input"`
+	Parameters struct {
+		Size  string `json:"size,omitempty"`
+		N     int    `json:"n,omitempty"`
+		Steps string `json:"steps,omitempty"`
+		Scale string `json:"scale,omitempty"`
+	} `json:"parameters,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+}
relay/channel/ali/image.go (new file)
@@ -0,0 +1,177 @@
+package ali
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+	"time"
+)
+
+func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
+	var imageRequest AliImageRequest
+	imageRequest.Input.Prompt = request.Prompt
+	imageRequest.Model = request.Model
+	imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
+	imageRequest.Parameters.N = request.N
+	imageRequest.ResponseFormat = request.ResponseFormat
+
+	return &imageRequest
+}
+
+func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
+	url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
+
+	var aliResponse AliResponse
+
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return &aliResponse, err, nil
+	}
+
+	req.Header.Set("Authorization", "Bearer "+key)
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		common.SysError("updateTask client.Do err: " + err.Error())
+		return &aliResponse, err, nil
+	}
+	defer resp.Body.Close()
+
+	responseBody, err := io.ReadAll(resp.Body)
+
+	var response AliResponse
+	err = json.Unmarshal(responseBody, &response)
+	if err != nil {
+		common.SysError("updateTask NewDecoder err: " + err.Error())
+		return &aliResponse, err, nil
+	}
+
+	return &response, nil, responseBody
+}
+
+func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
+	waitSeconds := 3
+	step := 0
+	maxStep := 20
+
+	var taskResponse AliResponse
+	var responseBody []byte
+
+	for {
+		step++
+		rsp, err, body := updateTask(info, taskID, key)
+		responseBody = body
+		if err != nil {
+			return &taskResponse, responseBody, err
+		}
+
+		if rsp.Output.TaskStatus == "" {
+			return &taskResponse, responseBody, nil
+		}
+
+		switch rsp.Output.TaskStatus {
+		case "FAILED":
+			fallthrough
+		case "CANCELED":
+			fallthrough
+		case "SUCCEEDED":
+			fallthrough
+		case "UNKNOWN":
+			return rsp, responseBody, nil
+		}
+		if step >= maxStep {
+			break
+		}
+		time.Sleep(time.Duration(waitSeconds) * time.Second)
+	}
+
+	return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
+}
+
+func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
+	imageResponse := dto.ImageResponse{
+		Created: info.StartTime.Unix(),
+	}
+
+	for _, data := range response.Output.Results {
+		var b64Json string
+		if responseFormat == "b64_json" {
+			_, b64, err := common.GetImageFromUrl(data.Url)
+			if err != nil {
+				common.LogError(c, "get_image_data_failed: "+err.Error())
+				continue
+			}
+			b64Json = b64
+		} else {
+			b64Json = data.B64Image
+		}
+
+		imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+			Url:           data.Url,
+			B64Json:       b64Json,
+			RevisedPrompt: "",
+		})
+	}
+	return &imageResponse
+}
+
+func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	apiKey := c.Request.Header.Get("Authorization")
+	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+	responseFormat := c.GetString("response_format")
+
+	var aliTaskResponse AliResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &aliTaskResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliTaskResponse.Message != "" {
+		common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
+		return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
+	}
+
+	aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Output.TaskStatus != "SUCCEEDED" {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: aliResponse.Output.Message,
+				Type:    "ali_error",
+				Param:   "",
+				Code:    aliResponse.Output.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, nil
+}
@@ -16,34 +16,13 @@ import (
 
 const EnableSearchModelSuffix = "-internet"
 
-func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
-	messages := make([]AliMessage, 0, len(request.Messages))
-	//prompt := ""
-	for i := 0; i < len(request.Messages); i++ {
-		message := request.Messages[i]
-		messages = append(messages, AliMessage{
-			Content: message.StringContent(),
-			Role:    strings.ToLower(message.Role),
-		})
-	}
-	enableSearch := false
-	aliModel := request.Model
-	if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
-		enableSearch = true
-		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
-	}
-	return &AliChatRequest{
-		Model: request.Model,
-		Input: AliInput{
-			//Prompt: prompt,
-			Messages: messages,
-		},
-		Parameters: AliParameters{
-			IncrementalOutput: request.Stream,
-			Seed:              uint64(request.Seed),
-			EnableSearch:      enableSearch,
-		},
-	}
-}
+func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+	if request.TopP >= 1 {
+		request.TopP = 0.999
+	} else if request.TopP <= 0 {
+		request.TopP = 0.001
+	}
+	return &request
+}
 
 func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
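The rewritten `requestOpenAI2Ali` now forwards the OpenAI-shaped request as-is and only nudges `top_p` into the open interval (0, 1), since the upstream API rejects the boundary values. A standalone sketch of that clamping behavior (local helper name, not the project's code):

```go
package main

import "fmt"

// clampTopP mirrors the TopP handling added to requestOpenAI2Ali:
// values at or beyond the [0, 1] boundaries are pulled just inside them,
// in-range values pass through unchanged.
func clampTopP(topP float64) float64 {
	if topP >= 1 {
		return 0.999
	} else if topP <= 0 {
		return 0.001
	}
	return topP
}

func main() {
	fmt.Println(clampTopP(1.5), clampTopP(-0.2), clampTopP(0.7))
}
```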
@@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
 	return &openAIEmbeddingResponse
 }
 
-func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
+func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(response.Output.Text)
 	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
@@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
+func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.SetContentString(aliResponse.Output.Text)
 	if aliResponse.Output.FinishReason != "null" {
@@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions
 func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var usage dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
+	scanner.Split(bufio.ScanLines)
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
 	go func() {
@@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
-			var aliResponse AliChatResponse
+			var aliResponse AliResponse
 			err := json.Unmarshal([]byte(data), &aliResponse)
 			if err != nil {
 				common.SysError("error unmarshalling stream response: " + err.Error())
@@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
 }
 
 func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	var aliResponse AliChatResponse
+	var aliResponse AliResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -7,14 +7,19 @@ import (
 	"io"
 	"net/http"
 	"one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 )
 
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	if info.IsStream && c.Request.Header.Get("Accept") == "" {
-		req.Header.Set("Accept", "text/event-stream")
+	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+		// multipart/form-data: leave the client's headers untouched
+	} else {
+		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		if info.IsStream && c.Request.Header.Get("Accept") == "" {
+			req.Header.Set("Accept", "text/event-stream")
+		}
 	}
 }
 
@@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return resp, nil
 }
 
+func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	fullRequestURL, err := a.GetRequestURL(info)
+	if err != nil {
+		return nil, fmt.Errorf("get request url failed: %w", err)
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	// set form data
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+
+	err = a.SetupRequestHeader(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(c, req)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
+
 func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
@@ -20,12 +20,17 @@ type Adaptor struct {
 	RequestMode int
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
 	} else {
@@ -41,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -16,12 +16,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 
 }
 
@@ -99,11 +104,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
 		return baiduEmbeddingRequest, nil
@@ -21,12 +21,17 @@ type Adaptor struct {
 	RequestMode int
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
 	} else {
@@ -53,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -5,11 +5,18 @@ type ClaudeMetadata struct {
 }
 
 type ClaudeMediaMessage struct {
 	Type       string               `json:"type"`
 	Text       string               `json:"text,omitempty"`
 	Source     *ClaudeMessageSource `json:"source,omitempty"`
 	Usage      *ClaudeUsage         `json:"usage,omitempty"`
 	StopReason *string              `json:"stop_reason,omitempty"`
+	PartialJson string              `json:"partial_json,omitempty"`
+	// tool_calls
+	Id        string `json:"id,omitempty"`
+	Name      string `json:"name,omitempty"`
+	Input     any    `json:"input,omitempty"`
+	Content   string `json:"content,omitempty"`
+	ToolUseId string `json:"tool_use_id,omitempty"`
 }
 
 type ClaudeMessageSource struct {
@@ -23,6 +30,18 @@ type ClaudeMessage struct {
 	Content any `json:"content"`
 }
 
+type Tool struct {
+	Name        string      `json:"name"`
+	Description string      `json:"description,omitempty"`
+	InputSchema InputSchema `json:"input_schema"`
+}
+
+type InputSchema struct {
+	Type       string `json:"type"`
+	Properties any    `json:"properties,omitempty"`
+	Required   any    `json:"required,omitempty"`
+}
+
 type ClaudeRequest struct {
 	Model  string `json:"model"`
 	Prompt string `json:"prompt,omitempty"`
@@ -35,7 +54,9 @@ type ClaudeRequest struct {
 	TopP float64 `json:"top_p,omitempty"`
 	TopK int     `json:"top_k,omitempty"`
 	//ClaudeMetadata `json:"metadata,omitempty"`
 	Stream bool `json:"stream,omitempty"`
+	Tools      []Tool `json:"tools,omitempty"`
+	ToolChoice any    `json:"tool_choice,omitempty"`
 }
 
 type ClaudeError struct {
@@ -44,24 +65,20 @@ type ClaudeError struct {
 }
 
 type ClaudeResponse struct {
 	Id         string               `json:"id"`
 	Type       string               `json:"type"`
 	Content    []ClaudeMediaMessage `json:"content"`
 	Completion string               `json:"completion"`
 	StopReason string               `json:"stop_reason"`
 	Model      string               `json:"model"`
 	Error      ClaudeError          `json:"error"`
 	Usage      ClaudeUsage          `json:"usage"`
 	Index      int                  `json:"index"` // stream only
-	Delta      *ClaudeMediaMessage  `json:"delta"`   // stream only
-	Message    *ClaudeResponse      `json:"message"` // stream only: message_start
+	ContentBlock *ClaudeMediaMessage `json:"content_block"`
+	Delta        *ClaudeMediaMessage `json:"delta"`   // stream only
+	Message      *ClaudeResponse     `json:"message"` // stream only: message_start
 }
 
-//type ClaudeResponseChoice struct {
-//	Index int    `json:"index"`
-//	Type  string `json:"type"`
-//}
-
 type ClaudeUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
@@ -8,12 +8,10 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
-	"time"
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
@@ -30,6 +28,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 }
 
 func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
+
 	claudeRequest := ClaudeRequest{
 		Model:  textRequest.Model,
 		Prompt: "",
@@ -60,6 +59,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
 }
 
 func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
+	claudeTools := make([]Tool, 0, len(textRequest.Tools))
+
+	for _, tool := range textRequest.Tools {
+		if params, ok := tool.Function.Parameters.(map[string]any); ok {
+			claudeTools = append(claudeTools, Tool{
+				Name:        tool.Function.Name,
+				Description: tool.Function.Description,
+				InputSchema: InputSchema{
+					Type:       params["type"].(string),
+					Properties: params["properties"],
+					Required:   params["required"],
+				},
+			})
+		}
+	}
+
 	claudeRequest := ClaudeRequest{
 		Model:     textRequest.Model,
 		MaxTokens: textRequest.MaxTokens,
@@ -68,10 +83,24 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		TopP:      textRequest.TopP,
 		TopK:      textRequest.TopK,
 		Stream:    textRequest.Stream,
+		Tools:     claudeTools,
 	}
 	if claudeRequest.MaxTokens == 0 {
 		claudeRequest.MaxTokens = 4096
 	}
+	if textRequest.Stop != nil {
+		// stop may be a string or an array of strings; convert it to []string
+		switch textRequest.Stop.(type) {
+		case string:
+			claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
+		case []interface{}:
+			stopSequences := make([]string, 0)
+			for _, stop := range textRequest.Stop.([]interface{}) {
+				stopSequences = append(stopSequences, stop.(string))
+			}
+			claudeRequest.StopSequences = stopSequences
+		}
+	}
 	formatMessages := make([]dto.Message, 0)
 	var lastMessage *dto.Message
 	for i, message := range textRequest.Messages {
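The hunk above normalizes OpenAI's `stop` field, which a JSON decoder surfaces as either a `string` or a `[]interface{}`, into the `[]string` that Claude's `stop_sequences` expects. A self-contained sketch of that conversion (local helper name, not the project's code):

```go
package main

import "fmt"

// normalizeStop mirrors the stop-sequence conversion added in
// RequestOpenAI2ClaudeMessage: a bare string becomes a one-element slice,
// a decoded JSON array is copied element by element, anything else yields nil.
func normalizeStop(stop any) []string {
	switch v := stop.(type) {
	case string:
		return []string{v}
	case []interface{}:
		out := make([]string, 0, len(v))
		for _, s := range v {
			if str, ok := s.(string); ok {
				out = append(out, str)
			}
		}
		return out
	}
	return nil
}

func main() {
	fmt.Println(normalizeStop("END"))
	fmt.Println(normalizeStop([]interface{}{"a", "b"}))
}
```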
@@ -171,6 +200,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
+	tools := make([]dto.ToolCall, 0)
 	var choice dto.ChatCompletionsStreamResponseChoice
 	if reqMode == RequestModeCompletion {
 		choice.Delta.SetContentString(claudeResponse.Completion)
@@ -186,10 +216,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 		choice.Delta.SetContentString("")
 		choice.Delta.Role = "assistant"
 	} else if claudeResponse.Type == "content_block_start" {
-		return nil, nil
+		if claudeResponse.ContentBlock != nil {
+			//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
+			if claudeResponse.ContentBlock.Type == "tool_use" {
+				tools = append(tools, dto.ToolCall{
+					ID:   claudeResponse.ContentBlock.Id,
+					Type: "function",
+					Function: dto.FunctionCall{
+						Name:      claudeResponse.ContentBlock.Name,
+						Arguments: "",
+					},
+				})
+			}
+		} else {
+			return nil, nil
+		}
 	} else if claudeResponse.Type == "content_block_delta" {
-		choice.Index = claudeResponse.Index
-		choice.Delta.SetContentString(claudeResponse.Delta.Text)
+		if claudeResponse.Delta != nil {
+			choice.Index = claudeResponse.Index
+			choice.Delta.SetContentString(claudeResponse.Delta.Text)
+			if claudeResponse.Delta.Type == "input_json_delta" {
+				tools = append(tools, dto.ToolCall{
+					Function: dto.FunctionCall{
+						Arguments: claudeResponse.Delta.PartialJson,
+					},
+				})
+			}
+		}
 	} else if claudeResponse.Type == "message_delta" {
 		finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
 		if finishReason != "null" {
@@ -205,6 +258,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 	if claudeUsage == nil {
 		claudeUsage = &ClaudeUsage{}
 	}
+	if len(tools) > 0 {
+		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
+		choice.Delta.ToolCalls = tools
+	}
 	response.Choices = append(response.Choices, choice)
 
 	return &response, claudeUsage
@@ -217,6 +274,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 	}
+	var responseText string
+	if len(claudeResponse.Content) > 0 {
+		responseText = claudeResponse.Content[0].Text
+	}
+	tools := make([]dto.ToolCall, 0)
 	if reqMode == RequestModeCompletion {
 		content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
 		choice := dto.OpenAITextResponseChoice{
@@ -231,20 +293,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 		choices = append(choices, choice)
 	} else {
 		fullTextResponse.Id = claudeResponse.Id
-		for i, message := range claudeResponse.Content {
-			content, _ := json.Marshal(message.Text)
-			choice := dto.OpenAITextResponseChoice{
-				Index: i,
-				Message: dto.Message{
-					Role:    "assistant",
-					Content: content,
-				},
-				FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+		for _, message := range claudeResponse.Content {
+			if message.Type == "tool_use" {
+				args, _ := json.Marshal(message.Input)
+				tools = append(tools, dto.ToolCall{
+					ID:   message.Id,
+					Type: "function", // compatible with other OpenAI derivative applications
+					Function: dto.FunctionCall{
+						Name:      message.Name,
+						Arguments: string(args),
+					},
+				})
 			}
-			choices = append(choices, choice)
 		}
 	}
+	choice := dto.OpenAITextResponseChoice{
+		Index: 0,
+		Message: dto.Message{
+			Role: "assistant",
+		},
+		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+	}
+	choice.SetStringContent(responseText)
+	if len(tools) > 0 {
+		choice.Message.ToolCalls = tools
+	}
+	choices = append(choices, choice)
 	fullTextResponse.Choices = choices
 	return &fullTextResponse
 }
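The rewritten loop above maps each Claude `tool_use` content block onto an OpenAI-style tool call, serializing the block's structured `input` back into a JSON `arguments` string. A standalone sketch of the shape of that mapping (local types, not the project's `dto` package):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// toolCall is a stand-in for the OpenAI-style tool call structure.
type toolCall struct {
	ID        string
	Type      string
	Name      string
	Arguments string
}

// toToolCall mirrors the tool_use conversion: Claude supplies the tool input
// as a decoded object, while OpenAI clients expect a JSON-encoded string.
func toToolCall(id, name string, input any) toolCall {
	args, _ := json.Marshal(input)
	return toolCall{ID: id, Type: "function", Name: name, Arguments: string(args)}
}

func main() {
	tc := toToolCall("toolu_1", "get_weather", map[string]any{"city": "Paris"})
	fmt.Println(tc.Arguments)
}
```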
@@ -256,89 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	responseText := ""
 	createdTime := common.GetTimestamp()
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string, 5)
-	stopChan := make(chan bool, 2)
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			if !strings.HasPrefix(data, "data: ") {
-				continue
-			}
-			data = strings.TrimPrefix(data, "data: ")
-			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-				// send data timeout, stop the stream
-				common.LogError(c, "send data timeout, stop the stream")
-				break
-			}
-		}
-		stopChan <- true
-	}()
-	isFirst := true
-	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			// some implementations may add \r at the end of data
-			data = strings.TrimSuffix(data, "\r")
-			var claudeResponse ClaudeResponse
-			err := json.Unmarshal([]byte(data), &claudeResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return true
-			}
-			response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
-			if response == nil {
-				return true
-			}
-			if requestMode == RequestModeCompletion {
-				responseText += claudeResponse.Completion
-				responseId = response.Id
-			} else {
-				if claudeResponse.Type == "message_start" {
-					// message_start: capture usage
-					responseId = claudeResponse.Message.Id
-					info.UpstreamModelName = claudeResponse.Message.Model
-					usage.PromptTokens = claudeUsage.InputTokens
-				} else if claudeResponse.Type == "content_block_delta" {
-					responseText += claudeResponse.Delta.Text
-				} else if claudeResponse.Type == "message_delta" {
-					usage.CompletionTokens = claudeUsage.OutputTokens
-					usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
-				} else {
-					return true
-				}
-			}
-			//response.Id = responseId
-			response.Id = responseId
-			response.Created = createdTime
-			response.Model = info.UpstreamModelName
-			err = service.ObjectData(c, response)
-			if err != nil {
-				common.SysError(err.Error())
-			}
-			return true
-		case <-stopChan:
-			return false
-		}
-	})
+	scanner.Split(bufio.ScanLines)
+	service.SetEventStreamHeaders(c)
+	for scanner.Scan() {
+		data := scanner.Text()
+		info.SetFirstResponseTime()
+		if len(data) < 6 || !strings.HasPrefix(data, "data:") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data:")
+		data = strings.TrimSpace(data)
+		var claudeResponse ClaudeResponse
+		err := json.Unmarshal([]byte(data), &claudeResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			continue
+		}
+		response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
+		if response == nil {
+			continue
+		}
+		if requestMode == RequestModeCompletion {
+			responseText += claudeResponse.Completion
+			responseId = response.Id
+		} else {
+			if claudeResponse.Type == "message_start" {
+				// message_start: capture usage
+				responseId = claudeResponse.Message.Id
+				info.UpstreamModelName = claudeResponse.Message.Model
+				usage.PromptTokens = claudeUsage.InputTokens
+			} else if claudeResponse.Type == "content_block_delta" {
+				responseText += claudeResponse.Delta.Text
+			} else if claudeResponse.Type == "message_delta" {
+				usage.CompletionTokens = claudeUsage.OutputTokens
+				usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
+			} else if claudeResponse.Type == "content_block_start" {
+
+			} else {
+				continue
+			}
+		}
+		//response.Id = responseId
+		response.Id = responseId
+		response.Created = createdTime
+		response.Model = info.UpstreamModelName
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "send_stream_response_failed: "+err.Error())
+		}
+	}
 	if requestMode == RequestModeCompletion {
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
@ -357,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
service.Done(c)
|
service.Done(c)
|
||||||
err := resp.Body.Close()
|
resp.Body.Close()
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
relay/channel/cloudflare/adaptor.go (new file, 105 lines):

```go
package cloudflare

import (
	"bytes"
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/dto"
	"one-api/relay/channel"
	relaycommon "one-api/relay/common"
	"one-api/relay/constant"
)

type Adaptor struct {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	switch info.RelayMode {
	case constant.RelayModeChatCompletions:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
	case constant.RelayModeEmbeddings:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
	default:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
	}
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	switch info.RelayMode {
	case constant.RelayModeCompletions:
		return convertCf2CompletionsRequest(*request), nil
	default:
		return request, nil
	}
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return request, nil
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	// read the uploaded "file" form field
	file, _, err := c.Request.FormFile("file")
	if err != nil {
		return nil, errors.New("file is required")
	}
	defer file.Close()
	// buffer that holds the uploaded file content
	requestBody := &bytes.Buffer{}

	// copy the uploaded file content into the buffer
	if _, err := io.Copy(requestBody, file); err != nil {
		return nil, err
	}
	return requestBody, nil
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
	//TODO implement me
	return nil, errors.New("not implemented")
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	switch info.RelayMode {
	case constant.RelayModeEmbeddings:
		fallthrough
	case constant.RelayModeChatCompletions:
		if info.IsStream {
			err, usage = cfStreamHandler(c, resp, info)
		} else {
			err, usage = cfHandler(c, resp, info)
		}
	case constant.RelayModeAudioTranslation:
		fallthrough
	case constant.RelayModeAudioTranscription:
		err, usage = cfSTTHandler(c, resp, info)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return ChannelName
}
```
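The routing logic in `GetRequestURL` can be sketched standalone. This is a minimal, hypothetical reduction (the names `buildCfURL`, `modeChatCompletions`, etc. are illustrative and not from the repository); note that the adaptor reuses the `ApiVersion` field to carry the Cloudflare account ID:

```go
package main

import "fmt"

// Hypothetical stand-ins for the relay-mode constants used by the adaptor.
const (
	modeChatCompletions = iota
	modeEmbeddings
	modeOther
)

// buildCfURL mirrors the routing idea in GetRequestURL: OpenAI-compatible
// endpoints for chat and embeddings, the raw `ai/run` endpoint otherwise.
func buildCfURL(base, accountID string, mode int, model string) string {
	switch mode {
	case modeChatCompletions:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", base, accountID)
	case modeEmbeddings:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", base, accountID)
	default:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", base, accountID, model)
	}
}

func main() {
	fmt.Println(buildCfURL("https://api.cloudflare.com", "acct123", modeOther, "@cf/meta/llama-3-8b-instruct"))
}
```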
relay/channel/cloudflare/constant.go (new file, 38 lines):

```go
package cloudflare

var ModelList = []string{
	"@cf/meta/llama-2-7b-chat-fp16",
	"@cf/meta/llama-2-7b-chat-int8",
	"@cf/mistral/mistral-7b-instruct-v0.1",
	"@hf/thebloke/deepseek-coder-6.7b-base-awq",
	"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
	"@cf/deepseek-ai/deepseek-math-7b-base",
	"@cf/deepseek-ai/deepseek-math-7b-instruct",
	"@cf/thebloke/discolm-german-7b-v1-awq",
	"@cf/tiiuae/falcon-7b-instruct",
	"@cf/google/gemma-2b-it-lora",
	"@hf/google/gemma-7b-it",
	"@cf/google/gemma-7b-it-lora",
	"@hf/nousresearch/hermes-2-pro-mistral-7b",
	"@hf/thebloke/llama-2-13b-chat-awq",
	"@cf/meta-llama/llama-2-7b-chat-hf-lora",
	"@cf/meta/llama-3-8b-instruct",
	"@hf/thebloke/llamaguard-7b-awq",
	"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
	"@hf/mistralai/mistral-7b-instruct-v0.2",
	"@cf/mistral/mistral-7b-instruct-v0.2-lora",
	"@hf/thebloke/neural-chat-7b-v3-1-awq",
	"@cf/openchat/openchat-3.5-0106",
	"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
	"@cf/microsoft/phi-2",
	"@cf/qwen/qwen1.5-0.5b-chat",
	"@cf/qwen/qwen1.5-1.8b-chat",
	"@cf/qwen/qwen1.5-14b-chat-awq",
	"@cf/qwen/qwen1.5-7b-chat-awq",
	"@cf/defog/sqlcoder-7b-2",
	"@hf/nexusflow/starling-lm-7b-beta",
	"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
	"@hf/thebloke/zephyr-7b-beta-awq",
}

var ChannelName = "cloudflare"
```
relay/channel/cloudflare/dto.go (new file, 21 lines):

```go
package cloudflare

import "one-api/dto"

type CfRequest struct {
	Messages    []dto.Message `json:"messages,omitempty"`
	Lora        string        `json:"lora,omitempty"`
	MaxTokens   int           `json:"max_tokens,omitempty"`
	Prompt      string        `json:"prompt,omitempty"`
	Raw         bool          `json:"raw,omitempty"`
	Stream      bool          `json:"stream,omitempty"`
	Temperature float64       `json:"temperature,omitempty"`
}

type CfAudioResponse struct {
	Result CfSTTResult `json:"result"`
}

type CfSTTResult struct {
	Text string `json:"text"`
}
```
relay/channel/cloudflare/relay_cloudflare.go (new file, 156 lines):

```go
package cloudflare

import (
	"bufio"
	"encoding/json"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/service"
	"strings"
	"time"
)

func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
	p, _ := textRequest.Prompt.(string)
	return &CfRequest{
		Prompt:      p,
		MaxTokens:   textRequest.GetMaxTokens(),
		Stream:      textRequest.Stream,
		Temperature: textRequest.Temperature,
	}
}

func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	service.SetEventStreamHeaders(c)
	id := service.GetResponseID(c)
	var responseText string
	isFirst := true

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < len("data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")
		data = strings.TrimSuffix(data, "\r")

		if data == "[DONE]" {
			break
		}

		var response dto.ChatCompletionsStreamResponse
		err := json.Unmarshal([]byte(data), &response)
		if err != nil {
			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
			continue
		}
		for _, choice := range response.Choices {
			choice.Delta.Role = "assistant"
			responseText += choice.Delta.GetContentString()
		}
		response.Id = id
		response.Model = info.UpstreamModelName
		err = service.ObjectData(c, response)
		if isFirst {
			isFirst = false
			info.FirstResponseTime = time.Now()
		}
		if err != nil {
			common.LogError(c, "error_rendering_stream_response: "+err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		common.LogError(c, "error_scanning_stream_response: "+err.Error())
	}
	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	if info.ShouldIncludeUsage {
		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
		err := service.ObjectData(c, response)
		if err != nil {
			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
		}
	}
	service.Done(c)

	err := resp.Body.Close()
	if err != nil {
		common.LogError(c, "close_response_body_failed: "+err.Error())
	}

	return nil, usage
}

func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var response dto.TextResponse
	err = json.Unmarshal(responseBody, &response)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	response.Model = info.UpstreamModelName
	var responseText string
	for _, choice := range response.Choices {
		responseText += choice.Message.StringContent()
	}
	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	response.Usage = *usage
	response.Id = service.GetResponseID(c)
	jsonResponse, err := json.Marshal(response)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)
	return nil, usage
}

func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	var cfResp CfAudioResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &cfResp)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	audioResp := &dto.AudioResponse{
		Text: cfResp.Result.Text,
	}

	jsonResponse, err := json.Marshal(audioResp)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)

	usage := &dto.Usage{}
	usage.PromptTokens = info.PromptTokens
	usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens

	return nil, usage
}
```
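The core of `cfStreamHandler` is a plain SSE line loop: strip the `data: ` prefix and trailing `\r`, stop at `[DONE]`, and accumulate the delta text. A self-contained sketch of just that loop (the per-chunk JSON decoding is elided behind a stand-in `extractDelta` callback, which is an illustrative name, not repository API):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// collectSSE mirrors the scanning loop in cfStreamHandler over an in-memory
// body: skip short keep-alive lines, trim SSE framing, stop at [DONE], and
// concatenate whatever extractDelta pulls out of each chunk.
func collectSSE(body string, extractDelta func(string) string) string {
	var out string
	scanner := bufio.NewScanner(strings.NewReader(body))
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) < len("data: ") {
			continue // blank separator or keep-alive line
		}
		line = strings.TrimPrefix(line, "data: ")
		line = strings.TrimSuffix(line, "\r")
		if line == "[DONE]" {
			break
		}
		out += extractDelta(line)
	}
	return out
}

func main() {
	body := "data: hel\n\ndata: lo\n\ndata: [DONE]\n"
	fmt.Println(collectSSE(body, func(chunk string) string { return chunk }))
}
```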
Cohere adaptor:

```diff
@@ -1,6 +1,7 @@
 package cohere
 
 import (
+	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -14,10 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -34,7 +42,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	return requestOpenAI2Cohere(*request), nil
 }
```
Cohere DTO:

```diff
@@ -7,7 +7,7 @@ type CohereRequest struct {
 	ChatHistory []ChatHistory `json:"chat_history"`
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
-	MaxTokens   int64         `json:"max_tokens"`
+	MaxTokens   int           `json:"max_tokens"`
 }
 
 type ChatHistory struct {
```
Another channel adaptor (filename not shown in this excerpt) gets the same interface change:

```diff
@@ -14,12 +14,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -32,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
```
Gemini adaptor (streaming now requests SSE framing via `alt=sse`):

```diff
@@ -14,10 +14,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 // map of model names to their corresponding API versions
@@ -40,7 +47,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 	action := "generateContent"
 	if info.IsStream {
-		action = "streamGenerateContent"
+		action = "streamGenerateContent?alt=sse"
 	}
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }
@@ -51,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
```
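The `alt=sse` switch is what lets the stream handler below drop its custom split function and scan plain `data: ` lines. A self-contained sketch of the URL construction (the helper name `geminiURL` is illustrative):

```go
package main

import "fmt"

// geminiURL mirrors GetRequestURL: streaming requests append alt=sse so the
// upstream responds with server-sent events instead of a JSON array.
func geminiURL(base, version, model string, stream bool) string {
	action := "generateContent"
	if stream {
		action = "streamGenerateContent?alt=sse"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", base, version, model, action)
}

func main() {
	fmt.Println(geminiURL("https://generativelanguage.googleapis.com", "v1beta", "gemini-pro", true))
}
```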
Gemini DTO (adds function-call support):

```diff
@@ -12,9 +12,15 @@ type GeminiInlineData struct {
 	Data string `json:"data"`
 }
 
+type FunctionCall struct {
+	FunctionName string `json:"name"`
+	Arguments    any    `json:"args"`
+}
+
 type GeminiPart struct {
 	Text         string            `json:"text,omitempty"`
 	InlineData   *GeminiInlineData `json:"inlineData,omitempty"`
+	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
 }
 
 type GeminiChatContent struct {
```
@ -4,18 +4,14 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
|||||||
MaxOutputTokens: textRequest.MaxTokens,
|
MaxOutputTokens: textRequest.MaxTokens,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if textRequest.Functions != nil {
|
if textRequest.Tools != nil {
|
||||||
|
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
|
||||||
|
for _, tool := range textRequest.Tools {
|
||||||
|
functions = append(functions, tool.Function)
|
||||||
|
}
|
||||||
|
geminiRequest.Tools = []GeminiChatTools{
|
||||||
|
{
|
||||||
|
FunctionDeclarations: functions,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else if textRequest.Functions != nil {
|
||||||
geminiRequest.Tools = []GeminiChatTools{
|
geminiRequest.Tools = []GeminiChatTools{
|
||||||
{
|
{
|
||||||
FunctionDeclarations: textRequest.Functions,
|
FunctionDeclarations: textRequest.Functions,
|
||||||
@ -126,6 +132,30 @@ func (g *GeminiChatResponse) GetResponseText() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||||
|
var toolCalls []dto.ToolCall
|
||||||
|
|
||||||
|
item := candidate.Content.Parts[0]
|
||||||
|
if item.FunctionCall == nil {
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
//common.SysError("getToolCalls failed: " + err.Error())
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
toolCall := dto.ToolCall{
|
||||||
|
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||||
|
Type: "function",
|
||||||
|
Function: dto.FunctionCall{
|
||||||
|
Arguments: string(argsBytes),
|
||||||
|
Name: item.FunctionCall.FunctionName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
@ -144,8 +174,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|||||||
FinishReason: relaycommon.StopFinishReason,
|
FinishReason: relaycommon.StopFinishReason,
|
||||||
}
|
}
|
||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) > 0 {
|
||||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||||
choice.Message.Content = content
|
choice.Message.ToolCalls = getToolCalls(&candidate)
|
||||||
|
} else {
|
||||||
|
choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
}
|
}
|
||||||
@ -154,7 +187,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|||||||
|
|
||||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
|
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||||
var choice dto.ChatCompletionsStreamResponseChoice
|
var choice dto.ChatCompletionsStreamResponseChoice
|
||||||
choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
||||||
|
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
|
||||||
|
respFirst := geminiResponse.Candidates[0].Content.Parts[0]
|
||||||
|
if respFirst.FunctionCall != nil {
|
||||||
|
// function response
|
||||||
|
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
|
||||||
|
} else {
|
||||||
|
// text response
|
||||||
|
choice.Delta.SetContentString(respFirst.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
choice.FinishReason = &relaycommon.StopFinishReason
|
choice.FinishReason = &relaycommon.StopFinishReason
|
||||||
var response dto.ChatCompletionsStreamResponse
|
var response dto.ChatCompletionsStreamResponse
|
||||||
response.Object = "chat.completion.chunk"
|
response.Object = "chat.completion.chunk"
|
||||||
@ -165,92 +208,47 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
|||||||
|
|
||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
responseJson := ""
|
|
||||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
createAt := common.GetTimestamp()
|
createAt := common.GetTimestamp()
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
dataChan := make(chan string, 5)
|
|
||||||
stopChan := make(chan bool, 2)
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(bufio.ScanLines)
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
responseJson += data
|
|
||||||
data = strings.TrimSpace(data)
|
|
||||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
|
||||||
data = strings.TrimSuffix(data, "\"")
|
|
||||||
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
|
||||||
// send data timeout, stop the stream
|
|
||||||
common.LogError(c, "send data timeout, stop the stream")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
isFirst := true
|
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
for scanner.Scan() {
|
||||||
select {
|
data := scanner.Text()
|
||||||
case data := <-dataChan:
|
info.SetFirstResponseTime()
|
||||||
if isFirst {
|
data = strings.TrimSpace(data)
|
||||||
isFirst = false
|
if !strings.HasPrefix(data, "data: ") {
|
||||||
info.FirstResponseTime = time.Now()
|
continue
|
||||||
}
|
|
||||||
// this is used to prevent annoying \ related format bug
|
|
||||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
|
||||||
type dummyStruct struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
var dummy dummyStruct
|
|
||||||
err := json.Unmarshal([]byte(data), &dummy)
|
|
||||||
responseText += dummy.Content
|
|
||||||
var choice dto.ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.SetContentString(dummy.Content)
|
|
||||||
response := dto.ChatCompletionsStreamResponse{
|
|
||||||
Id: id,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: createAt,
|
|
||||||
Model: info.UpstreamModelName,
|
|
||||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            return false
         }
-    })
-    var geminiChatResponses []GeminiChatResponse
-    err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
-    if err != nil {
-        log.Printf("cannot get gemini usage: %s", err.Error())
-        usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-    } else {
-        for _, response := range geminiChatResponses {
-            usage.PromptTokens = response.UsageMetadata.PromptTokenCount
-            usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+        data = strings.TrimPrefix(data, "data: ")
+        data = strings.TrimSuffix(data, "\"")
+        var geminiResponse GeminiChatResponse
+        err := json.Unmarshal([]byte(data), &geminiResponse)
+        if err != nil {
+            common.LogError(c, "error unmarshalling stream response: "+err.Error())
+            continue
+        }
+        response := streamResponseGeminiChat2OpenAI(&geminiResponse)
+        if response == nil {
+            continue
+        }
+        response.Id = id
+        response.Created = createAt
+        responseText += response.Choices[0].Delta.GetContentString()
+        if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+            usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+            usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+        }
+        err = service.ObjectData(c, response)
+        if err != nil {
+            common.LogError(c, err.Error())
         }
-        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
     }
+    usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
     if info.ShouldIncludeUsage {
         response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
         err := service.ObjectData(c, response)
@@ -259,10 +257,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
         }
     }
     service.Done(c)
-    err = resp.Body.Close()
-    if err != nil {
-        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
-    }
+    resp.Body.Close()
     return nil, usage
 }
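The rewritten Gemini handler above consumes the upstream SSE feed line by line: strip the `data: ` framing, skip chunks that fail to unmarshal, and accumulate the delta text. A minimal self-contained sketch of that per-line flow (the `chunk` type and `decodeSSE` helper are illustrative stand-ins, not part of the codebase):

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// chunk is a toy stand-in for GeminiChatResponse; the field is illustrative.
type chunk struct {
	Text string `json:"text"`
}

// decodeSSE mirrors the per-line flow of the handler: scan lines, strip the
// "data: " framing, skip lines that fail to unmarshal, accumulate the text.
func decodeSSE(stream string) string {
	var sb strings.Builder
	sc := bufio.NewScanner(strings.NewReader(stream))
	for sc.Scan() {
		data := sc.Text()
		if !strings.HasPrefix(data, "data: ") {
			continue // ignore blank lines and other framing
		}
		data = strings.TrimPrefix(data, "data: ")
		var ck chunk
		if err := json.Unmarshal([]byte(data), &ck); err != nil {
			continue // the real handler logs the error and continues
		}
		sb.WriteString(ck.Text)
	}
	return sb.String()
}

func main() {
	fmt.Println(decodeSSE("data: {\"text\":\"Hel\"}\n\ndata: {\"text\":\"lo\"}\n"))
}
```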
@@ -15,10 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+    //TODO implement me
+    return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+    //TODO implement me
+    return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -36,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
     return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
     return request, nil
 }
 
@@ -10,16 +10,22 @@ import (
     "one-api/relay/channel/openai"
     relaycommon "one-api/relay/common"
     relayconstant "one-api/relay/constant"
-    "one-api/service"
 )
 
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+    //TODO implement me
+    return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+    //TODO implement me
+    return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -36,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
     return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
     if request == nil {
         return nil, errors.New("request is nil")
     }
-    switch relayMode {
+    switch info.RelayMode {
     case relayconstant.RelayModeEmbeddings:
         return requestOpenAI2Embeddings(*request), nil
     default:
@@ -58,11 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
     if info.IsStream {
-        var responseText string
-        err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-        if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
-            usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-        }
+        err, usage = openai.OaiStreamHandler(c, resp, info)
     } else {
         if info.RelayMode == relayconstant.RelayModeEmbeddings {
             err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
@@ -1,10 +1,13 @@
 package openai
 
 import (
+    "bytes"
+    "encoding/json"
     "errors"
     "fmt"
     "github.com/gin-gonic/gin"
     "io"
+    "mime/multipart"
     "net/http"
     "one-api/common"
     "one-api/dto"
@@ -14,22 +17,16 @@ import (
     "one-api/relay/channel/minimax"
     "one-api/relay/channel/moonshot"
     relaycommon "one-api/relay/common"
-    "one-api/service"
+    "one-api/relay/constant"
     "strings"
 )
 
 type Adaptor struct {
     ChannelType int
+    ResponseFormat string
 }
 
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-    return nil, nil
-}
-
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
     a.ChannelType = info.ChannelType
 }
 
@@ -74,28 +71,84 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
     return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
     if request == nil {
         return nil, errors.New("request is nil")
     }
+    if info.ChannelType != common.ChannelTypeOpenAI {
+        request.StreamOptions = nil
+    }
+    return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+    return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+    a.ResponseFormat = request.ResponseFormat
+    if info.RelayMode == constant.RelayModeAudioSpeech {
+        jsonData, err := json.Marshal(request)
+        if err != nil {
+            return nil, fmt.Errorf("error marshalling object: %w", err)
+        }
+        return bytes.NewReader(jsonData), nil
+    } else {
+        var requestBody bytes.Buffer
+        writer := multipart.NewWriter(&requestBody)
+
+        writer.WriteField("model", request.Model)
+
+        // add the file field
+        file, header, err := c.Request.FormFile("file")
+        if err != nil {
+            return nil, errors.New("file is required")
+        }
+        defer file.Close()
+
+        part, err := writer.CreateFormFile("file", header.Filename)
+        if err != nil {
+            return nil, errors.New("create form file failed")
+        }
+        if _, err := io.Copy(part, file); err != nil {
+            return nil, errors.New("copy file failed")
+        }
+
+        // close the multipart writer so the terminating boundary is written
+        writer.Close()
+        c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+        return &requestBody, nil
+    }
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
     return request, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
-    return channel.DoApiRequest(a, c, info, requestBody)
+    if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+        return channel.DoFormRequest(a, c, info, requestBody)
+    } else {
+        return channel.DoApiRequest(a, c, info, requestBody)
+    }
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-    if info.IsStream {
-        var responseText string
-        var toolCount int
-        err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
-        if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
-            usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-            usage.CompletionTokens += toolCount * 7
-        }
-    } else {
-        err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+    switch info.RelayMode {
+    case constant.RelayModeAudioSpeech:
+        err, usage = OpenaiTTSHandler(c, resp, info)
+    case constant.RelayModeAudioTranslation:
+        fallthrough
+    case constant.RelayModeAudioTranscription:
+        err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
+    case constant.RelayModeImagesGenerations:
+        err, usage = OpenaiTTSHandler(c, resp, info)
+    default:
+        if info.IsStream {
+            err, usage = OaiStreamHandler(c, resp, info)
+        } else {
+            err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+        }
     }
     return
 }
@@ -9,6 +9,7 @@ var ModelList = []string{
     "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
     "gpt-4-vision-preview",
     "gpt-4o", "gpt-4o-2024-05-13",
+    "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
     "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
     "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
     "text-moderation-latest", "text-moderation-stable",
@@ -4,6 +4,8 @@ import (
     "bufio"
     "bytes"
     "encoding/json"
+    "fmt"
+    "github.com/bytedance/gopkg/util/gopool"
     "github.com/gin-gonic/gin"
     "io"
     "net/http"
@@ -14,38 +16,36 @@ import (
     relayconstant "one-api/relay/constant"
     "one-api/service"
     "strings"
-    "sync"
     "time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
-    //checkSensitive := constant.ShouldCheckCompletionSensitive()
+func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+    containStreamUsage := false
+    responseId := ""
+    var createAt int64 = 0
+    var systemFingerprint string
+    model := info.UpstreamModelName
+
     var responseTextBuilder strings.Builder
-    var usage dto.Usage
+    var usage = &dto.Usage{}
+    var streamItems []string // store stream items
+
     toolCount := 0
     scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string, 5)
-    stopChan := make(chan bool, 2)
+    scanner.Split(bufio.ScanLines)
+
+    service.SetEventStreamHeaders(c)
+
+    ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
+    defer ticker.Stop()
+
+    stopChan := make(chan bool)
     defer close(stopChan)
-    defer close(dataChan)
-    var wg sync.WaitGroup
-    go func() {
-        wg.Add(1)
-        defer wg.Done()
-        var streamItems []string // store stream items
+    gopool.Go(func() {
         for scanner.Scan() {
+            info.SetFirstResponseTime()
+            ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
             data := scanner.Text()
             if len(data) < 6 { // ignore blank line or wrong format
                 continue
@@ -53,54 +53,46 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
             if data[:6] != "data: " && data[:6] != "[DONE]" {
                 continue
             }
-            if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-                // send data timeout, stop the stream
-                common.LogError(c, "send data timeout, stop the stream")
-                break
-            }
             data = data[6:]
             if !strings.HasPrefix(data, "[DONE]") {
+                err := service.StringData(c, data)
+                if err != nil {
+                    common.LogError(c, "streaming error: "+err.Error())
+                }
                 streamItems = append(streamItems, data)
             }
         }
-        // count tokens
-        streamResp := "[" + strings.Join(streamItems, ",") + "]"
-        switch info.RelayMode {
-        case relayconstant.RelayModeChatCompletions:
-            var streamResponses []dto.ChatCompletionsStreamResponseSimple
-            err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
-            if err != nil {
-                // batch parse failed, parse items one by one
-                common.SysError("error unmarshalling stream response: " + err.Error())
-                for _, item := range streamItems {
-                    var streamResponse dto.ChatCompletionsStreamResponseSimple
-                    err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
-                    if err == nil {
-                        if streamResponse.Usage != nil {
-                            if streamResponse.Usage.TotalTokens != 0 {
-                                usage = *streamResponse.Usage
-                            }
-                        }
-                        for _, choice := range streamResponse.Choices {
-                            responseTextBuilder.WriteString(choice.Delta.GetContentString())
-                            if choice.Delta.ToolCalls != nil {
-                                if len(choice.Delta.ToolCalls) > toolCount {
-                                    toolCount = len(choice.Delta.ToolCalls)
-                                }
-                                for _, tool := range choice.Delta.ToolCalls {
-                                    responseTextBuilder.WriteString(tool.Function.Name)
-                                    responseTextBuilder.WriteString(tool.Function.Arguments)
-                                }
-                            }
-                        }
-                    }
-                }
-            } else {
-                for _, streamResponse := range streamResponses {
-                    if streamResponse.Usage != nil {
-                        if streamResponse.Usage.TotalTokens != 0 {
-                            usage = *streamResponse.Usage
-                        }
-                    }
+        common.SafeSendBool(stopChan, true)
+    })
+
+    select {
+    case <-ticker.C:
+        // timeout handling
+        common.LogError(c, "streaming timeout")
+    case <-stopChan:
+        // finished normally
+    }
+
+    // count tokens
+    streamResp := "[" + strings.Join(streamItems, ",") + "]"
+    switch info.RelayMode {
+    case relayconstant.RelayModeChatCompletions:
+        var streamResponses []dto.ChatCompletionsStreamResponse
+        err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+        if err != nil {
+            // batch parse failed, parse items one by one
+            common.SysError("error unmarshalling stream response: " + err.Error())
+            for _, item := range streamItems {
+                var streamResponse dto.ChatCompletionsStreamResponse
+                err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+                if err == nil {
+                    responseId = streamResponse.Id
+                    createAt = streamResponse.Created
+                    systemFingerprint = streamResponse.GetSystemFingerprint()
+                    model = streamResponse.Model
+                    if service.ValidUsage(streamResponse.Usage) {
+                        usage = streamResponse.Usage
+                        containStreamUsage = true
+                    }
                     for _, choice := range streamResponse.Choices {
                         responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -116,67 +108,69 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
                     }
                 }
             }
-    case relayconstant.RelayModeCompletions:
-        var streamResponses []dto.CompletionsStreamResponse
-        err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
-        if err != nil {
-            // batch parse failed, parse items one by one
-            common.SysError("error unmarshalling stream response: " + err.Error())
-            for _, item := range streamItems {
-                var streamResponse dto.CompletionsStreamResponse
-                err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
-                if err == nil {
-                    for _, choice := range streamResponse.Choices {
-                        responseTextBuilder.WriteString(choice.Text)
-                    }
-                }
-            }
-        } else {
-            for _, streamResponse := range streamResponses {
+        } else {
+            for _, streamResponse := range streamResponses {
+                responseId = streamResponse.Id
+                createAt = streamResponse.Created
+                systemFingerprint = streamResponse.GetSystemFingerprint()
+                model = streamResponse.Model
+                if service.ValidUsage(streamResponse.Usage) {
+                    usage = streamResponse.Usage
+                    containStreamUsage = true
+                }
                 for _, choice := range streamResponse.Choices {
-                    responseTextBuilder.WriteString(choice.Text)
+                    responseTextBuilder.WriteString(choice.Delta.GetContentString())
+                    if choice.Delta.ToolCalls != nil {
+                        if len(choice.Delta.ToolCalls) > toolCount {
+                            toolCount = len(choice.Delta.ToolCalls)
+                        }
+                        for _, tool := range choice.Delta.ToolCalls {
+                            responseTextBuilder.WriteString(tool.Function.Name)
+                            responseTextBuilder.WriteString(tool.Function.Arguments)
+                        }
+                    }
                 }
             }
         }
-        if len(dataChan) > 0 {
-            // wait data out
-            time.Sleep(2 * time.Second)
-        }
-        common.SafeSendBool(stopChan, true)
-    }()
-    service.SetEventStreamHeaders(c)
-    isFirst := true
-    ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
-    defer ticker.Stop()
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case <-ticker.C:
-            common.LogError(c, "reading data from upstream timeout")
-            return false
-        case data := <-dataChan:
-            if isFirst {
-                isFirst = false
-                info.FirstResponseTime = time.Now()
-            }
-            ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
-            if strings.HasPrefix(data, "data: [DONE]") {
-                data = data[:12]
-            }
-            // some implementations may add \r at the end of data
-            data = strings.TrimSuffix(data, "\r")
-            c.Render(-1, common.CustomEvent{Data: data})
-            return true
-        case <-stopChan:
-            return false
-        }
-    })
-    err := resp.Body.Close()
-    if err != nil {
-        return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
-    }
-    wg.Wait()
-    return nil, &usage, responseTextBuilder.String(), toolCount
+    case relayconstant.RelayModeCompletions:
+        var streamResponses []dto.CompletionsStreamResponse
+        err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+        if err != nil {
+            // batch parse failed, parse items one by one
+            common.SysError("error unmarshalling stream response: " + err.Error())
+            for _, item := range streamItems {
+                var streamResponse dto.CompletionsStreamResponse
+                err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+                if err == nil {
+                    for _, choice := range streamResponse.Choices {
+                        responseTextBuilder.WriteString(choice.Text)
+                    }
+                }
+            }
+        } else {
+            for _, streamResponse := range streamResponses {
+                for _, choice := range streamResponse.Choices {
+                    responseTextBuilder.WriteString(choice.Text)
+                }
+            }
+        }
+    }
+
+    if !containStreamUsage {
+        usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+        usage.CompletionTokens += toolCount * 7
+    }
+
+    if info.ShouldIncludeUsage && !containStreamUsage {
+        response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
+        response.SetSystemFingerprint(systemFingerprint)
+        service.ObjectData(c, response)
+    }
+
+    service.Done(c)
+
+    resp.Body.Close()
+    return nil, usage
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -213,11 +207,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
     if err != nil {
         return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
     }
-    err = resp.Body.Close()
-    if err != nil {
-        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-    }
+    resp.Body.Close()
 
     if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
         completionTokens := 0
         for _, choice := range simpleResponse.Choices {
@@ -232,3 +222,134 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
     }
     return nil, &simpleResponse.Usage
 }
+
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    // Reset response body
+    resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+    // We shouldn't set the header before we parse the response body, because the parse part may fail.
+    // And then we will have to send an error response, but in this case, the header has already been set.
+    // So the httpClient will be confused by the response.
+    // For example, Postman will report error, and we cannot check the response at all.
+    for k, v := range resp.Header {
+        c.Writer.Header().Set(k, v[0])
+    }
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = io.Copy(c.Writer, resp.Body)
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    usage := &dto.Usage{}
+    usage.PromptTokens = info.PromptTokens
+    usage.TotalTokens = info.PromptTokens
+    return nil, usage
+}
+
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+    var audioResp dto.AudioResponse
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = json.Unmarshal(responseBody, &audioResp)
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    // Reset response body
+    resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+    // We shouldn't set the header before we parse the response body, because the parse part may fail.
+    // And then we will have to send an error response, but in this case, the header has already been set.
+    // So the httpClient will be confused by the response.
+    // For example, Postman will report error, and we cannot check the response at all.
+    for k, v := range resp.Header {
+        c.Writer.Header().Set(k, v[0])
+    }
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = io.Copy(c.Writer, resp.Body)
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+    }
+    resp.Body.Close()
+
+    var text string
+    switch responseFormat {
+    case "json":
+        text, err = getTextFromJSON(responseBody)
+    case "text":
+        text, err = getTextFromText(responseBody)
+    case "srt":
+        text, err = getTextFromSRT(responseBody)
+    case "verbose_json":
+        text, err = getTextFromVerboseJSON(responseBody)
+    case "vtt":
+        text, err = getTextFromVTT(responseBody)
+    }
+
+    usage := &dto.Usage{}
+    usage.PromptTokens = info.PromptTokens
+    usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
+    usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+    return nil, usage
+}
+
+func getTextFromVTT(body []byte) (string, error) {
+    return getTextFromSRT(body)
+}
+
+func getTextFromVerboseJSON(body []byte) (string, error) {
+    var whisperResponse dto.WhisperVerboseJSONResponse
+    if err := json.Unmarshal(body, &whisperResponse); err != nil {
+        return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+    }
+    return whisperResponse.Text, nil
+}
+
+func getTextFromSRT(body []byte) (string, error) {
+    scanner := bufio.NewScanner(strings.NewReader(string(body)))
+    var builder strings.Builder
+    var textLine bool
+    for scanner.Scan() {
+        line := scanner.Text()
+        if textLine {
+            builder.WriteString(line)
+            textLine = false
+            continue
+        } else if strings.Contains(line, "-->") {
+            textLine = true
+            continue
+        }
+    }
+    if err := scanner.Err(); err != nil {
+        return "", err
+    }
+    return builder.String(), nil
+}
+
+func getTextFromText(body []byte) (string, error) {
+    return strings.TrimSuffix(string(body), "\n"), nil
+}
+
+func getTextFromJSON(body []byte) (string, error) {
+    var whisperResponse dto.AudioResponse
+    if err := json.Unmarshal(body, &whisperResponse); err != nil {
+        return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
+    }
+    return whisperResponse.Text, nil
+}
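`getTextFromSRT` above recovers plain text from SRT (and, via `getTextFromVTT`, VTT) bodies by keeping only the single line that follows each `-->` timestamp. A runnable sketch of the same state machine (`extractSRTText` is a hypothetical stand-in):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// extractSRTText walks the subtitle file line by line: a line containing
// "-->" marks a timestamp, and only the one line directly after it is
// treated as subtitle text; index numbers and blank lines are skipped.
func extractSRTText(body string) string {
	sc := bufio.NewScanner(strings.NewReader(body))
	var b strings.Builder
	textLine := false
	for sc.Scan() {
		line := sc.Text()
		if textLine {
			b.WriteString(line)
			textLine = false
		} else if strings.Contains(line, "-->") {
			textLine = true
		}
	}
	return b.String()
}

func main() {
	srt := "1\n00:00:00,000 --> 00:00:01,000\nHello\n\n2\n00:00:01,000 --> 00:00:02,000\nworld\n"
	fmt.Println(extractSRTText(srt))
}
```

Note that, like the original, this keeps only the first line of a multi-line cue and joins cues without separators; it is a token-counting heuristic, not a full SRT parser.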
@@ -15,12 +15,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
     //TODO implement me
+    return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+    //TODO implement me
+    return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -33,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
     return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
     if request == nil {
         return nil, errors.New("request is nil")
     }
@@ -10,18 +10,22 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 )
 
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -34,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -54,11 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
-			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		}
+		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -23,12 +23,17 @@ type Adaptor struct {
 	Timestamp int64
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	a.Action = "ChatCompletions"
 	a.Version = "2023-09-01"
 	a.Timestamp = common.GetTimestamp()
@@ -47,7 +52,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -16,12 +16,17 @@ type Adaptor struct {
 	request *dto.GeneralOpenAIRequest
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -33,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -14,12 +14,17 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -37,7 +42,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
 func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var usage *dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
-			return i + 2, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
+	scanner.Split(bufio.ScanLines)
 	dataChan := make(chan string)
 	metaChan := make(chan string)
 	stopChan := make(chan bool)
@@ -10,18 +10,22 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 )
 
 type Adaptor struct {
 }
 
-func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
+	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -35,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -55,13 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		var toolCount int
-		err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
-		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
-			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-			usage.CompletionTokens += toolCount * 7
-		}
+		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -1,7 +1,7 @@
 package zhipu_4v
 
 var ModelList = []string{
-	"glm-4", "glm-4v", "glm-3-turbo",
+	"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
 }
 
 var ChannelName = "zhipu_4v"
@@ -17,6 +17,7 @@ type RelayInfo struct {
 	TokenUnlimited    bool
 	StartTime         time.Time
 	FirstResponseTime time.Time
+	setFirstResponse  bool
 	ApiType           int
 	IsStream          bool
 	RelayMode         int
@@ -68,7 +69,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		info.ApiVersion = GetAPIVersion(c)
 	}
 	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
-		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
+		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
+		info.ChannelType == common.ChannelCloudflare {
 		info.SupportStreamOptions = true
 	}
 	return info
@@ -82,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
 	info.IsStream = isStream
 }
 
+func (info *RelayInfo) SetFirstResponseTime() {
+	if !info.setFirstResponse {
+		info.FirstResponseTime = time.Now()
+		info.setFirstResponse = true
+	}
+}
+
 type TaskRelayInfo struct {
 	ChannelType int
 	ChannelId   int
@@ -1,50 +1,17 @@
 package common
 
 import (
-	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	_ "image/gif"
 	_ "image/jpeg"
 	_ "image/png"
-	"io"
-	"net/http"
 	"one-api/common"
-	"one-api/dto"
-	"strconv"
 	"strings"
 )
 
 var StopFinishReason = "stop"
 
-func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
-	OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
-		StatusCode: resp.StatusCode,
-		Error: dto.OpenAIError{
-			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
-			Type:    "upstream_error",
-			Code:    "bad_response_status_code",
-			Param:   strconv.Itoa(resp.StatusCode),
-		},
-	}
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return
-	}
-	var textResponse dto.TextResponseWithError
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
-		return
-	}
-	OpenAIErrorWithStatusCode.Error = textResponse.Error
-	return
-}
-
 func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 
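The deleted `RelayErrorHandler` built a generic error from the upstream status code, then replaced it with the error parsed from the response body when that body parsed; call sites elsewhere in this diff now use `service.RelayErrorHandler` in the same role. A self-contained sketch of the pattern (the types and the `upstreamError` name are mine, not the project's):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code"`
	Param   string `json:"param"`
}

// upstreamError starts from a generic status-code error and overwrites it
// with the error carried in the upstream body, if that body parses.
func upstreamError(statusCode int, body []byte) openAIError {
	e := openAIError{
		Message: fmt.Sprintf("bad response status code %d", statusCode),
		Type:    "upstream_error",
		Code:    "bad_response_status_code",
		Param:   strconv.Itoa(statusCode),
	}
	var parsed struct {
		Error openAIError `json:"error"`
	}
	if err := json.Unmarshal(body, &parsed); err == nil && parsed.Error.Message != "" {
		e = parsed.Error
	}
	return e
}

func main() {
	fmt.Println(upstreamError(429, []byte(`{"error":{"message":"rate limited","type":"requests"}}`)).Message)
	fmt.Println(upstreamError(500, []byte("not json")).Code)
}
```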
@@ -22,6 +22,7 @@ const (
 	APITypeCohere
 	APITypeDify
 	APITypeJina
+	APITypeCloudflare
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeDify
 	case common.ChannelTypeJina:
 		apiType = APITypeJina
+	case common.ChannelCloudflare:
+		apiType = APITypeCloudflare
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false
@@ -13,6 +13,7 @@ const (
 	RelayModeModerations
 	RelayModeImagesGenerations
 	RelayModeEdits
+
 	RelayModeMidjourneyImagine
 	RelayModeMidjourneyDescribe
 	RelayModeMidjourneyBlend
@@ -22,16 +23,19 @@ const (
 	RelayModeMidjourneyTaskFetch
 	RelayModeMidjourneyTaskImageSeed
 	RelayModeMidjourneyTaskFetchByCondition
-	RelayModeAudioSpeech
-	RelayModeAudioTranscription
-	RelayModeAudioTranslation
 	RelayModeMidjourneyAction
 	RelayModeMidjourneyModal
 	RelayModeMidjourneyShorten
 	RelayModeSwapFace
+
+	RelayModeAudioSpeech        // tts
+	RelayModeAudioTranscription // whisper
+	RelayModeAudioTranslation   // whisper
+
 	RelayModeSunoFetch
 	RelayModeSunoFetchByID
 	RelayModeSunoSubmit
+
 	RelayModeRerank
 )
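The relay modes form an `iota` sequence, so moving the audio modes below `RelayModeSwapFace` changes their numeric values; anything that persisted the old numbers would now disagree. A minimal illustration of why position, not name, determines the value (the constant names here are hypothetical):

```go
package main

import "fmt"

// Moving a name inside an iota block changes its value — the reason the
// audio relay modes get new numbers after the reorder above.
const (
	ModeUnknown = iota // 0
	ModeChat           // 1
	ModeAudioSpeech    // 2: value depends on position, not on the name
)

func main() {
	fmt.Println(ModeUnknown, ModeChat, ModeAudioSpeech) // prints "0 1 2"
}
```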
@@ -1,13 +1,10 @@
 package relay
 
 import (
-	"bytes"
-	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -16,69 +13,71 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
-	"strings"
-	"time"
 )
 
-func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
-	tokenId := c.GetInt("token_id")
-	channelType := c.GetInt("channel")
-	channelId := c.GetInt("channel_id")
-	userId := c.GetInt("id")
-	group := c.GetString("group")
-	startTime := time.Now()
-
-	var audioRequest dto.TextToSpeechRequest
-	if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-		err := common.UnmarshalBodyReusable(c, &audioRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-		}
-	} else {
-		audioRequest = dto.TextToSpeechRequest{
-			Model: "whisper-1",
-		}
+func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
+	audioRequest := &dto.AudioRequest{}
+	err := common.UnmarshalBodyReusable(c, audioRequest)
+	if err != nil {
+		return nil, err
 	}
-	//err := common.UnmarshalBodyReusable(c, &audioRequest)
-
-	// request validation
-	if audioRequest.Model == "" {
-		return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
-	}
-
-	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-		if audioRequest.Voice == "" {
-			return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
+	switch info.RelayMode {
+	case relayconstant.RelayModeAudioSpeech:
+		if audioRequest.Model == "" {
+			return nil, errors.New("model is required")
 		}
-	}
-	var err error
-	promptTokens := 0
-	preConsumedTokens := common.PreConsumedQuota
-	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 		if constant.ShouldCheckPromptSensitive() {
-			err = service.CheckSensitiveInput(audioRequest.Input)
+			err := service.CheckSensitiveInput(audioRequest.Input)
 			if err != nil {
-				return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+				return nil, err
 			}
 		}
+	default:
+		if audioRequest.Model == "" {
+			audioRequest.Model = c.PostForm("model")
+		}
+		if audioRequest.Model == "" {
+			return nil, errors.New("model is required")
+		}
+		if audioRequest.ResponseFormat == "" {
+			audioRequest.ResponseFormat = "json"
+		}
+	}
+	return audioRequest, nil
+}
+
+func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
+
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
+	}
+
+	promptTokens := 0
+	preConsumedTokens := common.PreConsumedQuota
+	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
 		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
 		preConsumedTokens = promptTokens
+		relayInfo.PromptTokens = promptTokens
 	}
 
 	modelRatio := common.GetModelRatio(audioRequest.Model)
-	groupRatio := common.GetGroupRatio(group)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
-	userQuota, err := model.CacheGetUserQuota(userId)
+	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota-preConsumedQuota < 0 {
 		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
-	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
@@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 		preConsumedQuota = 0
 	}
 	if preConsumedQuota > 0 {
-		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 
-	succeed := false
-	defer func() {
-		if succeed {
-			return
-		}
-		if preConsumedQuota > 0 {
-			// we need to roll back the pre-consumed quota
-			defer func() {
-				go func() {
-					// negative means add quota back for token & user
-					returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
-				}()
-			}()
-		}
-	}()
-
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	if modelMapping != "" {
@@ -122,133 +105,44 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 			audioRequest.Model = modelMap[audioRequest.Model]
 		}
 	}
+	relayInfo.UpstreamModelName = audioRequest.Model
 
-	baseURL := common.ChannelBaseURLs[channelType]
-	requestURL := c.Request.URL.String()
-	if c.GetString("base_url") != "" {
-		baseURL = c.GetString("base_url")
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
+	adaptor.Init(relayInfo)
 
-	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
-	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
-		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiVersion := relaycommon.GetAPIVersion(c)
-		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
-	}
-
-	requestBody := c.Request.Body
-
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
 
-	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
-		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		req.Header.Set("api-key", apiKey)
-		req.ContentLength = c.Request.ContentLength
-	} else {
-		req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
-	}
-
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-
-	resp, err := service.GetHttpClient().Do(req)
+	resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 
-	err = req.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		return relaycommon.RelayErrorHandler(resp)
-	}
-	succeed = true
-
-	var audioResponse dto.AudioResponse
-
-	defer func(ctx context.Context) {
-		go func() {
-			useTimeSeconds := time.Now().Unix() - startTime.Unix()
-			quota := 0
-			if strings.HasPrefix(audioRequest.Model, "tts-1") {
-				quota = promptTokens
-			} else {
-				quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
-			}
-			quota = int(float64(quota) * ratio)
-			if ratio != 0 && quota <= 0 {
-				quota = 1
-			}
-			quotaDelta := quota - preConsumedQuota
-			err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
-			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
-			}
-			err = model.CacheUpdateUserQuota(userId)
-			if err != nil {
-				common.SysError("error update user quota cache: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f", modelRatio, groupRatio)
-				other := make(map[string]interface{})
-				other["model_ratio"] = modelRatio
-				other["group_ratio"] = groupRatio
-				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				channelId := c.GetInt("channel_id")
-				model.UpdateChannelUsedQuota(channelId, quota)
-			}
-		}()
-	}(c.Request.Context())
-
-	responseBody, err := io.ReadAll(resp.Body)
-
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
-	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-
-	} else {
-		err = json.Unmarshal(responseBody, &audioResponse)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
-		}
-		contains, words := service.SensitiveWordContains(audioResponse.Text)
-		if contains {
-			return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	if resp != nil {
+		if resp.StatusCode != http.StatusOK {
+			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
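The deleted inline goroutine above settled audio billing by hand before `postConsumeQuota` took over. Its core arithmetic — scale tokens by the combined ratio, charge at least 1 for a non-free model, then settle the difference against the pre-consumed amount — can be sketched as follows (`settleQuota` is an illustrative name, not a repository function):

```go
package main

import "fmt"

// settleQuota mirrors the billing arithmetic the old inline code performed
// (and postConsumeQuota now centralizes): scale the token count by the
// combined ratio, never bill zero for a non-free model, then return the
// delta against the pre-consumed amount (a negative delta means a refund).
func settleQuota(tokens int, ratio float64, preConsumed int) (quota, delta int) {
	quota = int(float64(tokens) * ratio)
	if ratio != 0 && quota <= 0 {
		quota = 1 // minimum charge for non-free models
	}
	return quota, quota - preConsumed
}

func main() {
	q, d := settleQuota(120, 0.5, 100)
	fmt.Println(q, d) // charged 60, so 40 units of the pre-consumed quota are refunded
}
```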
@@ -2,7 +2,6 @@ package relay
 
 import (
 	"bytes"
-	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -14,72 +13,71 @@ import (
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
-	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"strings"
-	"time"
 )
 
-func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
-	tokenId := c.GetInt("token_id")
-	channelType := c.GetInt("channel")
-	channelId := c.GetInt("channel_id")
-	userId := c.GetInt("id")
-	group := c.GetString("group")
-	startTime := time.Now()
-
-	var imageRequest dto.ImageRequest
-	err := common.UnmarshalBodyReusable(c, &imageRequest)
+func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
+	imageRequest := &dto.ImageRequest{}
+	err := common.UnmarshalBodyReusable(c, imageRequest)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+		return nil, err
 	}
-	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e-3"
+	if imageRequest.Prompt == "" {
+		return nil, errors.New("prompt is required")
 	}
-	if imageRequest.Size == "" {
-		imageRequest.Size = "1024x1024"
+	if strings.Contains(imageRequest.Size, "×") {
+		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
 	}
 	if imageRequest.N == 0 {
 		imageRequest.N = 1
 	}
-	// Prompt validation
-	if imageRequest.Prompt == "" {
-		return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
+	if imageRequest.Size == "" {
+		imageRequest.Size = "1024x1024"
 	}
-
-	if constant.ShouldCheckPromptSensitive() {
-		err = service.CheckSensitiveInput(imageRequest.Prompt)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
-		}
+	if imageRequest.Model == "" {
+		imageRequest.Model = "dall-e-2"
 	}
-
-	if strings.Contains(imageRequest.Size, "×") {
-		return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
+	if imageRequest.Quality == "" {
+		imageRequest.Quality = "standard"
 	}
 	// Not "256x256", "512x512", or "1024x1024"
 	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
 		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
-			return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
+			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
 		}
 	} else if imageRequest.Model == "dall-e-3" {
 		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
-			return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
+			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
 		}
-		if imageRequest.N != 1 {
-			return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
+		//if imageRequest.N != 1 {
+		//	return nil, errors.New("n must be 1")
+		//}
+	}
+	// N should between 1 and 10
+	//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
+	//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
+	//}
+	if constant.ShouldCheckPromptSensitive() {
+		err := service.CheckSensitiveInput(imageRequest.Prompt)
+		if err != nil {
+			return nil, err
 		}
 	}
+	return imageRequest, nil
+}
 
-	// N should between 1 and 10
-	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-		return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
+func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	imageRequest, err := getAndValidImageRequest(c, relayInfo)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
 	}
 
 	// map model name
 	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
 	if modelMapping != "" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -88,31 +86,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		}
 		if modelMap[imageRequest.Model] != "" {
 			imageRequest.Model = modelMap[imageRequest.Model]
-			isModelMapped = true
 		}
 	}
-	baseURL := common.ChannelBaseURLs[channelType]
-	requestURL := c.Request.URL.String()
-	if c.GetString("base_url") != "" {
-		baseURL = c.GetString("base_url")
-	}
-	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
-	if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations {
-		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
-		apiVersion := relaycommon.GetAPIVersion(c)
-		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
-		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
-	}
-	var requestBody io.Reader
-	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
-		jsonStr, err := json.Marshal(imageRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	} else {
-		requestBody = c.Request.Body
-	}
+	relayInfo.UpstreamModelName = imageRequest.Model
 
 	modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
 	if !success {
@@ -121,8 +97,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		// per 1 modelRatio = $0.04 / 16
 		modelPrice = 0.0025 * modelRatio
 	}
-	groupRatio := common.GetGroupRatio(group)
-	userQuota, err := model.CacheGetUserQuota(userId)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 
 	sizeRatio := 1.0
 	// Size
@@ -150,98 +127,60 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+
+	var requestBody io.Reader
+
+	convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
 
-	token := c.Request.Header.Get("Authorization")
-	if channelType == common.ChannelTypeAzure { // Azure authentication
-		token = strings.TrimPrefix(token, "Bearer ")
-		req.Header.Set("api-key", token)
-	} else {
-		req.Header.Set("Authorization", token)
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
 	}
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	requestBody = bytes.NewBuffer(jsonData)
 
-	resp, err := service.GetHttpClient().Do(req)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 
-	err = req.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		return relaycommon.RelayErrorHandler(resp)
-	}
-
-	var textResponse dto.ImageResponse
-	defer func(ctx context.Context) {
-		useTimeSeconds := time.Now().Unix() - startTime.Unix()
+	if resp != nil {
+		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 		if resp.StatusCode != http.StatusOK {
-			return
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code 重置状态码
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
 		}
-		err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
-		if err != nil {
-			common.SysError("error consuming token remain quota: " + err.Error())
-		}
-		err = model.CacheUpdateUserQuota(userId)
-		if err != nil {
-			common.SysError("error update user quota cache: " + err.Error())
-		}
-		if quota != 0 {
-			tokenName := c.GetString("token_name")
-			quality := "normal"
-			if imageRequest.Quality == "hd" {
-				quality = "hd"
-			}
-			logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality)
-			other := make(map[string]interface{})
-			other["model_price"] = modelPrice
-			other["group_ratio"] = groupRatio
-			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
-			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-			channelId := c.GetInt("channel_id")
-			model.UpdateChannelUsedQuota(channelId, quota)
-		}
-	}(c.Request.Context())
-
-	responseBody, err := io.ReadAll(resp.Body)
-
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
+	_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	usage := &dto.Usage{
+		PromptTokens: imageRequest.N,
+		TotalTokens:  imageRequest.N,
 	}
-	c.Writer.WriteHeader(resp.StatusCode)
 
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	quality := "standard"
+	if imageRequest.Quality == "hd" {
+		quality = "hd"
 	}
 
+	logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+	postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent)
+
 	return nil
 }
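The new `getAndValidImageRequest` follows a validate-then-default pattern: hard errors come first, defaults afterwards, so the helper either returns a fully populated request or an error and the caller never sees zero values. A self-contained sketch of that pattern, using a simplified stand-in for `dto.ImageRequest`:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// imageRequest is a simplified stand-in for dto.ImageRequest.
type imageRequest struct {
	Model, Prompt, Size, Quality string
	N                            int
}

// validateImageRequest sketches the pattern: reject bad input early, then
// fill defaults (dall-e-2, 1024x1024, n=1, standard quality) so downstream
// code can rely on every field being set.
func validateImageRequest(r *imageRequest) (*imageRequest, error) {
	if r.Prompt == "" {
		return nil, errors.New("prompt is required")
	}
	if strings.Contains(r.Size, "×") {
		return nil, errors.New("use 'x' instead of the multiplication sign '×'")
	}
	if r.N == 0 {
		r.N = 1
	}
	if r.Size == "" {
		r.Size = "1024x1024"
	}
	if r.Model == "" {
		r.Model = "dall-e-2"
	}
	if r.Quality == "" {
		r.Quality = "standard"
	}
	return r, nil
}

func main() {
	req, err := validateImageRequest(&imageRequest{Prompt: "a cat"})
	fmt.Println(req.Model, req.Size, req.N, err) // dall-e-2 1024x1024 1 <nil>
}
```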
@@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
-	modelPrice, success := common.GetModelPrice(textRequest.Model, false)
+	modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
 	groupRatio := common.GetGroupRatio(relayInfo.Group)
 
 	var preConsumedQuota int
@@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
 
-	if !success {
+	if !getModelPriceSuccess {
 		preConsumedTokens := common.PreConsumedQuota
 		if textRequest.MaxTokens != 0 {
 			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
@@ -150,10 +150,10 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
-	adaptor.Init(relayInfo, *textRequest)
+	adaptor.Init(relayInfo)
 	var requestBody io.Reader
 
-	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
@@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	return nil
 }
 
@@ -288,7 +288,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, usePrice bool) {
+	modelPrice float64, usePrice bool, extraContent string) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
@@ -309,7 +309,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	}
 	totalTokens := promptTokens + completionTokens
 	var logContent string
-	if modelPrice == -1 {
+	if !usePrice {
 		logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
 	} else {
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
@@ -350,6 +350,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 		logModel = "g-*"
 		logContent += fmt.Sprintf(",模型 %s", modelName)
 	}
+	if extraContent != "" {
+		logContent += ", " + extraContent
+	}
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
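The `postConsumeQuota` hunks replace the `modelPrice == -1` sentinel with the explicit `usePrice` flag and thread a new `extraContent` parameter into the log line. A standalone sketch of the resulting log-content logic (`buildLogContent` is a hypothetical name; the real code builds the string inline):

```go
package main

import "fmt"

// buildLogContent mirrors the branch in postConsumeQuota: ratio-based billing
// logs the three ratios, fixed-price billing logs the price, and any
// caller-supplied extra detail (e.g. image size/quality) is appended last.
func buildLogContent(usePrice bool, modelRatio, groupRatio, completionRatio, modelPrice float64, extra string) string {
	var s string
	if !usePrice {
		s = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
	} else {
		s = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
	}
	if extra != "" {
		s += ", " + extra
	}
	return s
}

func main() {
	fmt.Println(buildLogContent(true, 0, 1, 0, 0.04, "大小 1024x1024, 品质 standard"))
}
```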
@@ -7,6 +7,7 @@ import (
 	"one-api/relay/channel/aws"
 	"one-api/relay/channel/baidu"
 	"one-api/relay/channel/claude"
+	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
 	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
@@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &dify.Adaptor{}
 	case constant.APITypeJina:
 		return &jina.Adaptor{}
+	case constant.APITypeCloudflare:
+		return &cloudflare.Adaptor{}
 	}
 	return nil
 }
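`GetAdaptor` is a plain switch-based factory: each API-type constant maps to one concrete adaptor struct behind a shared interface, so registering Cloudflare is one import plus one case. A minimal sketch of the shape (the types and constants below are illustrative, not the repository's):

```go
package main

import "fmt"

// adaptor is a stand-in for the repository's channel.Adaptor interface.
type adaptor interface{ Name() string }

type openaiAdaptor struct{}

func (openaiAdaptor) Name() string { return "openai" }

type cloudflareAdaptor struct{}

func (cloudflareAdaptor) Name() string { return "cloudflare" }

const (
	apiTypeOpenAI = iota
	apiTypeCloudflare
)

// getAdaptor maps an API-type constant to a concrete adaptor; callers must
// treat a nil result as "invalid api type", as the helpers above do.
func getAdaptor(apiType int) adaptor {
	switch apiType {
	case apiTypeOpenAI:
		return openaiAdaptor{}
	case apiTypeCloudflare: // newly registered channel, as in the hunk above
		return cloudflareAdaptor{}
	}
	return nil
}

func main() {
	fmt.Println(getAdaptor(apiTypeCloudflare).Name()) // cloudflare
}
```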
@@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
-	adaptor.InitRerank(relayInfo, *rerankRequest)
+	adaptor.Init(relayInfo)
 
 	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
 	if err != nil {
@@ -99,6 +99,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
 	return nil
 }
@@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
 	return false
 }
 
-func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
 	if !common.AutomaticEnableChannelEnabled {
 		return false
 	}
 	if err != nil {
 		return false
 	}
-	if openAIErr != nil {
+	if openaiWithStatusErr != nil {
 		return false
 	}
 	if status != common.ChannelStatusAutoDisabled {
@@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
 	errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 		StatusCode: resp.StatusCode,
 		Error: dto.OpenAIError{
-			Message: "",
 			Type:  "upstream_error",
 			Code:  "bad_response_status_code",
 			Param: strconv.Itoa(resp.StatusCode),
 		},
 	}
 	responseBody, err := io.ReadAll(resp.Body)
@@ -2,10 +2,11 @@ package service
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"net/http"
 	"one-api/common"
-	"strings"
 )
 
 func SetEventStreamHeaders(c *gin.Context) {
@@ -16,11 +17,16 @@ func SetEventStreamHeaders(c *gin.Context) {
 	c.Writer.Header().Set("X-Accel-Buffering", "no")
 }
 
-func StringData(c *gin.Context, str string) {
-	str = strings.TrimPrefix(str, "data: ")
-	str = strings.TrimSuffix(str, "\r")
+func StringData(c *gin.Context, str string) error {
+	//str = strings.TrimPrefix(str, "data: ")
+	//str = strings.TrimSuffix(str, "\r")
 	c.Render(-1, common.CustomEvent{Data: "data: " + str})
-	c.Writer.Flush()
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	} else {
+		return errors.New("streaming error: flusher not found")
+	}
+	return nil
 }
 
 func ObjectData(c *gin.Context, object interface{}) error {
@@ -28,10 +34,14 @@ func ObjectData(c *gin.Context, object interface{}) error {
 	if err != nil {
 		return fmt.Errorf("error marshalling object: %w", err)
 	}
-	StringData(c, string(jsonData))
-	return nil
+	return StringData(c, string(jsonData))
 }
 
 func Done(c *gin.Context) {
-	StringData(c, "[DONE]")
+	_ = StringData(c, "[DONE]")
+}
+
+func GetResponseID(c *gin.Context) string {
+	logID := c.GetString("X-Oneapi-Request-Id")
+	return fmt.Sprintf("chatcmpl-%s", logID)
 }
@@ -9,6 +9,7 @@ import (
 	"log"
 	"math"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"strings"
 	"unicode/utf8"
@@ -71,13 +72,20 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 }
 
 func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
-	// TODO: 非流模式下不计算图片token数量
 	if model == "glm-4v" {
 		return 1047, nil
 	}
 	if imageUrl.Detail == "low" {
 		return 85, nil
 	}
+	// TODO: 非流模式下不计算图片token数量
+	if !constant.GetMediaTokenNotStream && !stream {
+		return 1000, nil
+	}
+	// 是否统计图片token
+	if !constant.GetMediaToken {
+		return 1000, nil
+	}
 	// 同步One API的图片计费逻辑
 	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
 		imageUrl.Detail = "high"
|
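The gating added to `getImageToken` short-circuits to a flat estimate before the detailed per-tile calculation. A minimal sketch of just that control flow, with the two `constant` settings passed as parameters (an assumption for illustration):

```go
package main

import "fmt"

// imageTokenEstimate sketches the new gating in getImageToken: when
// counting is disabled for non-stream requests, or disabled entirely,
// a flat 1000-token estimate is charged instead of the detailed
// calculation. countMedia / countMediaNotStream stand in for
// constant.GetMediaToken and constant.GetMediaTokenNotStream.
func imageTokenEstimate(countMedia, countMediaNotStream, stream bool) (tokens int, flat bool) {
	if !countMediaNotStream && !stream {
		return 1000, true // non-stream request with non-stream counting off
	}
	if !countMedia {
		return 1000, true // media token counting disabled entirely
	}
	return 0, false // caller proceeds with the detailed calculation
}

func main() {
	fmt.Println(imageTokenEstimate(true, false, false)) // flat estimate
	fmt.Println(imageTokenEstimate(false, true, true))  // flat estimate
	fmt.Println(imageTokenEstimate(true, true, true))   // detailed path
}
```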
@@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
 		Usage: &usage,
 	}
 }
+
+func ValidUsage(usage *dto.Usage) bool {
+	return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
+}
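The new `ValidUsage` predicate lets callers trust upstream usage data only when it is non-nil and reports at least one token. A self-contained sketch (the `Usage` type here mirrors only the two fields the check inspects; the real `dto.Usage` has more):

```go
package main

import "fmt"

// Usage mirrors the two fields of dto.Usage that ValidUsage inspects
// (an assumption for this sketch).
type Usage struct {
	PromptTokens     int
	CompletionTokens int
}

// ValidUsage restates the added helper: a usage record counts as valid
// only when it exists and reports at least one prompt or completion
// token, so callers can fall back to local token counting otherwise.
func ValidUsage(usage *Usage) bool {
	return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
}

func main() {
	fmt.Println(ValidUsage(nil))                      // missing usage
	fmt.Println(ValidUsage(&Usage{}))                 // zero tokens
	fmt.Println(ValidUsage(&Usage{PromptTokens: 12})) // usable
}
```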
@@ -367,7 +367,7 @@ const LogsTable = () => {
       dataIndex: 'content',
       render: (text, record, index) => {
         let other = getLogOther(record.other);
-        if (other == null) {
+        if (other == null || record.type !== 2) {
           return (
             <Paragraph
               ellipsis={{
@@ -1,16 +1,10 @@
 import React, { useEffect, useState } from 'react';
-import {
-  Button,
-  Form,
-  Grid,
-  Header,
-  Image,
-  Message,
-  Segment,
-} from 'semantic-ui-react';
 import { Link, useNavigate } from 'react-router-dom';
 import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
 import Turnstile from 'react-turnstile';
+import { Button, Card, Form, Layout } from '@douyinfe/semi-ui';
+import Title from '@douyinfe/semi-ui/lib/es/typography/title';
+import Text from '@douyinfe/semi-ui/lib/es/typography/text';
 
 const RegisterForm = () => {
   const [inputs, setInputs] = useState({
@@ -46,9 +40,7 @@ const RegisterForm = () => {
 
   let navigate = useNavigate();
 
-  function handleChange(e) {
-    const { name, value } = e.target;
-    console.log(name, value);
+  function handleChange(name, value) {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
   }
 
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid textAlign='center' style={{ marginTop: '48px' }}>
|
<div>
|
||||||
<Grid.Column style={{ maxWidth: 450 }}>
|
<Layout>
|
||||||
<Header as='h2' color='' textAlign='center'>
|
<Layout.Header></Layout.Header>
|
||||||
<Image src={logo} /> 新用户注册
|
<Layout.Content>
|
||||||
</Header>
|
<div
|
||||||
<Form size='large'>
|
style={{
|
||||||
<Segment>
|
justifyContent: 'center',
|
||||||
<Form.Input
|
display: 'flex',
|
||||||
fluid
|
marginTop: 120,
|
||||||
icon='user'
|
}}
|
||||||
iconPosition='left'
|
>
|
||||||
placeholder='输入用户名,最长 12 位'
|
<div style={{ width: 500 }}>
|
||||||
onChange={handleChange}
|
<Card>
|
||||||
name='username'
|
<Title heading={2} style={{ textAlign: 'center' }}>
|
||||||
/>
|
新用户注册
|
||||||
<Form.Input
|
</Title>
|
||||||
fluid
|
<Form size='large'>
|
||||||
icon='lock'
|
<Form.Input
|
||||||
iconPosition='left'
|
field={'username'}
|
||||||
placeholder='输入密码,最短 8 位,最长 20 位'
|
label={'用户名'}
|
||||||
onChange={handleChange}
|
placeholder='用户名'
|
||||||
name='password'
|
name='username'
|
||||||
type='password'
|
onChange={(value) => handleChange('username', value)}
|
||||||
/>
|
/>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
fluid
|
field={'password'}
|
||||||
icon='lock'
|
label={'密码'}
|
||||||
iconPosition='left'
|
placeholder='密码,最短 8 位,最长 20 位'
|
||||||
placeholder='输入密码,最短 8 位,最长 20 位'
|
name='password'
|
||||||
onChange={handleChange}
|
type='password'
|
||||||
name='password2'
|
onChange={(value) => handleChange('password', value)}
|
||||||
type='password'
|
/>
|
||||||
/>
|
<Form.Input
|
||||||
{showEmailVerification ? (
|
field={'password2'}
|
||||||
<>
|
label={'确认密码'}
|
||||||
<Form.Input
|
placeholder='确认密码'
|
||||||
fluid
|
name='password2'
|
||||||
icon='mail'
|
type='password'
|
||||||
iconPosition='left'
|
onChange={(value) => handleChange('password2', value)}
|
||||||
placeholder='输入邮箱地址'
|
/>
|
||||||
onChange={handleChange}
|
{showEmailVerification ? (
|
||||||
name='email'
|
<>
|
||||||
type='email'
|
<Form.Input
|
||||||
action={
|
field={'email'}
|
||||||
<Button onClick={sendVerificationCode} disabled={loading}>
|
label={'邮箱'}
|
||||||
获取验证码
|
placeholder='输入邮箱地址'
|
||||||
</Button>
|
onChange={(value) => handleChange('email', value)}
|
||||||
}
|
name='email'
|
||||||
|
type='email'
|
||||||
|
suffix={
|
||||||
|
<Button
|
||||||
|
onClick={sendVerificationCode}
|
||||||
|
disabled={loading}
|
||||||
|
>
|
||||||
|
获取验证码
|
||||||
|
</Button>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<Form.Input
|
||||||
|
field={'verification_code'}
|
||||||
|
label={'验证码'}
|
||||||
|
placeholder='输入验证码'
|
||||||
|
onChange={(value) =>
|
||||||
|
handleChange('verification_code', value)
|
||||||
|
}
|
||||||
|
name='verification_code'
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<></>
|
||||||
|
)}
|
||||||
|
<Button
|
||||||
|
theme='solid'
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
type={'primary'}
|
||||||
|
size='large'
|
||||||
|
htmlType={'submit'}
|
||||||
|
onClick={handleSubmit}
|
||||||
|
>
|
||||||
|
注册
|
||||||
|
</Button>
|
||||||
|
</Form>
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
marginTop: 20,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text>
|
||||||
|
已有账户?
|
||||||
|
<Link to='/login'>点击登录</Link>
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
{turnstileEnabled ? (
|
||||||
|
<Turnstile
|
||||||
|
sitekey={turnstileSiteKey}
|
||||||
|
onVerify={(token) => {
|
||||||
|
setTurnstileToken(token);
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
<Form.Input
|
) : (
|
||||||
fluid
|
<></>
|
||||||
icon='lock'
|
)}
|
||||||
iconPosition='left'
|
</div>
|
||||||
placeholder='输入验证码'
|
</div>
|
||||||
onChange={handleChange}
|
</Layout.Content>
|
||||||
name='verification_code'
|
</Layout>
|
||||||
/>
|
</div>
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<></>
|
|
||||||
)}
|
|
||||||
{turnstileEnabled ? (
|
|
||||||
<Turnstile
|
|
||||||
sitekey={turnstileSiteKey}
|
|
||||||
onVerify={(token) => {
|
|
||||||
setTurnstileToken(token);
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<></>
|
|
||||||
)}
|
|
||||||
<Button
|
|
||||||
color='green'
|
|
||||||
fluid
|
|
||||||
size='large'
|
|
||||||
onClick={handleSubmit}
|
|
||||||
loading={loading}
|
|
||||||
>
|
|
||||||
注册
|
|
||||||
</Button>
|
|
||||||
</Segment>
|
|
||||||
</Form>
|
|
||||||
<Message>
|
|
||||||
已有账户?
|
|
||||||
<Link to='/login' className='btn btn-link'>
|
|
||||||
点击登录
|
|
||||||
</Link>
|
|
||||||
</Message>
|
|
||||||
</Grid.Column>
|
|
||||||
</Grid>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -99,6 +99,13 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'orange',
|
color: 'orange',
|
||||||
label: 'Google PaLM2',
|
label: 'Google PaLM2',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 39,
|
||||||
|
text: 'Cloudflare',
|
||||||
|
value: 39,
|
||||||
|
color: 'grey',
|
||||||
|
label: 'Cloudflare',
|
||||||
|
},
|
||||||
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
|
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
|
||||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
|
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
|
||||||
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
|
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
|
||||||
|
@@ -601,6 +601,24 @@ const EditChannel = (props) => {
           />
         </>
       )}
+      {inputs.type === 39 && (
+        <>
+          <div style={{ marginTop: 10 }}>
+            <Typography.Text strong>Account ID:</Typography.Text>
+          </div>
+          <Input
+            name='other'
+            placeholder={
+              '请输入Account ID,例如:d6b5da8hk1awo8nap34ube6gh'
+            }
+            onChange={(value) => {
+              handleInputChange('other', value);
+            }}
+            value={inputs.other}
+            autoComplete='new-password'
+          />
+        </>
+      )}
       <div style={{ marginTop: 10 }}>
         <Typography.Text strong>模型:</Typography.Text>
       </div>