mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-24 18:33:41 +08:00
Compare commits
34 Commits
v0.5.0-alp
...
v0.5.2-alp
Author | SHA1 | Date | |
---|---|---|---|
|
d96cf2e84d | ||
|
446337c329 | ||
|
1dfa190e79 | ||
|
2d49ca6a07 | ||
|
89bcaaf989 | ||
|
afcd1bd27b | ||
|
c2c455c980 | ||
|
30a7f1a1c7 | ||
|
c9d2e42a9e | ||
|
3fca6ff534 | ||
|
8cbbeb784f | ||
|
ec88c0c240 | ||
|
065147b440 | ||
|
fe8f216dd9 | ||
|
b7d0616ae0 | ||
|
ce9c8024a6 | ||
|
8a866078b2 | ||
|
3e81d8af45 | ||
|
b8cb86c2c1 | ||
|
f45d586400 | ||
|
50dec03ff3 | ||
|
f31d400b6f | ||
|
130e6bfd83 | ||
|
d1335ebc01 | ||
|
e92da7928b | ||
|
d1b6f492b6 | ||
|
b9f6461dd4 | ||
|
0a39521a3d | ||
|
c134604cee | ||
|
929e43ef81 | ||
|
dce8bbe1ca | ||
|
bc2f48b1f2 | ||
|
889af8b2db | ||
|
4eea096654 |
25
README.en.md
25
README.en.md
@@ -57,15 +57,13 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
|
||||
> **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
|
||||
|
||||
## Features
|
||||
1. Supports multiple API access channels:
|
||||
+ [x] Official OpenAI channel (support proxy configuration)
|
||||
+ [x] **Azure OpenAI API**
|
||||
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
|
||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||
+ [x] [API2D](https://api2d.com/r/197971)
|
||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
|
||||
+ [x] Custom channel: Various third-party proxy services not included in the list
|
||||
1. Support for multiple large models:
|
||||
+ [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||
+ [x] [Anthropic Claude Series Models](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
|
||||
+ [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
|
||||
2. Supports access to multiple channels through **load balancing**.
|
||||
3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
|
||||
4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
|
||||
@@ -139,7 +137,7 @@ The initial account username is `root` and password is `123456`.
|
||||
cd one-api/web
|
||||
npm install
|
||||
npm run build
|
||||
|
||||
|
||||
# Build the backend
|
||||
cd ..
|
||||
go mod download
|
||||
@@ -175,7 +173,12 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
|
||||
<summary><strong>Deploy on Sealos</strong></summary>
|
||||
<div>
|
||||
|
||||
Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md).
|
||||
> Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users.
|
||||
|
||||
> Click the button below to deploy with one click.👇
|
||||
|
||||
[](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
|
||||
|
||||
|
||||
</div>
|
||||
</details>
|
||||
|
14
README.md
14
README.md
@@ -63,9 +63,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
|
||||
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||
2. 支持配置镜像以及众多第三方代理服务:
|
||||
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
|
||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||
+ [x] [API2D](https://api2d.com/r/197971)
|
||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||
@@ -93,7 +94,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
19. 支持通过系统访问令牌访问管理 API。
|
||||
20. 支持 Cloudflare Turnstile 用户校验。
|
||||
21. 支持用户管理,支持**多种用户登录注册方式**:
|
||||
+ 邮箱登录注册以及通过邮箱进行密码重置。
|
||||
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||
|
||||
@@ -152,7 +153,7 @@ sudo service nginx restart
|
||||
cd one-api/web
|
||||
npm install
|
||||
npm run build
|
||||
|
||||
|
||||
# 构建后端
|
||||
cd ..
|
||||
go mod download
|
||||
@@ -210,9 +211,11 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
||||
<summary><strong>部署到 Sealos </strong></summary>
|
||||
<div>
|
||||
|
||||
> Sealos 可视化部署,仅需 1 分钟。
|
||||
> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩。
|
||||
|
||||
参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。
|
||||
点击以下按钮一键部署:
|
||||
|
||||
[](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
|
||||
|
||||
</div>
|
||||
</details>
|
||||
@@ -313,6 +316,7 @@ https://openai.justsong.cn
|
||||
+ 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率)
|
||||
+ 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。
|
||||
+ 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。
|
||||
+ 注意,One API 的默认倍率就是官方倍率,是已经调整过的。
|
||||
2. 账户额度足够为什么提示额度不足?
|
||||
+ 请检查你的令牌额度是否足够,这个和账户额度是分开的。
|
||||
+ 令牌额度仅供用户设置最大使用量,用户可自由设置。
|
||||
|
@@ -42,6 +42,19 @@ var WeChatAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
|
||||
var EmailDomainRestrictionEnabled = false
|
||||
var EmailDomainWhitelist = []string{
|
||||
"gmail.com",
|
||||
"163.com",
|
||||
"126.com",
|
||||
"qq.com",
|
||||
"outlook.com",
|
||||
"hotmail.com",
|
||||
"icloud.com",
|
||||
"yahoo.com",
|
||||
"foxmail.com",
|
||||
}
|
||||
|
||||
var LogConsumeEnabled = true
|
||||
|
||||
var SMTPServer = ""
|
||||
@@ -77,6 +90,8 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
||||
|
||||
const (
|
||||
RoleGuestUser = 0
|
||||
RoleCommonUser = 1
|
||||
@@ -154,24 +169,28 @@ const (
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
}
|
||||
|
@@ -42,10 +42,14 @@ var ModelRatio = map[string]float64{
|
||||
"claude-2": 30,
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
|
||||
"qwen-plus-v1": 0.5715, // Same as above
|
||||
"SparkDesk": 0.8572, // TBD
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
|
@@ -16,6 +16,16 @@ import (
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
case common.ChannelTypeAnthropic:
|
||||
fallthrough
|
||||
case common.ChannelTypeBaidu:
|
||||
fallthrough
|
||||
case common.ChannelTypeZhipu:
|
||||
fallthrough
|
||||
case common.ChannelTypeXunfei:
|
||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||
case common.ChannelTypeAzure:
|
||||
request.Model = "gpt-35-turbo"
|
||||
default:
|
||||
|
@@ -3,10 +3,12 @@ package controller
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
@@ -78,6 +80,22 @@ func SendEmailVerification(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if common.EmailDomainRestrictionEnabled {
|
||||
allowed := false
|
||||
for _, domain := range common.EmailDomainWhitelist {
|
||||
if strings.HasSuffix(email, "@"+domain) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
@@ -288,6 +288,15 @@ func init() {
|
||||
Root: "ERNIE-Bot-turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "Embedding-V1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "Embedding-V1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "PaLM-2",
|
||||
Object: "model",
|
||||
@@ -324,6 +333,33 @@ func init() {
|
||||
Root: "chatglm_lite",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-v1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-v1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-plus-v1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-plus-v1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "SparkDesk",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "xunfei",
|
||||
Permission: permission,
|
||||
Root: "SparkDesk",
|
||||
Parent: nil,
|
||||
},
|
||||
}
|
||||
openAIModelsMap = make(map[string]OpenAIModels)
|
||||
for _, model := range openAIModels {
|
||||
|
@@ -2,11 +2,12 @@ package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetOptions(c *gin.Context) {
|
||||
@@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "WeChatAuthEnabled":
|
||||
if option.Value == "true" && common.WeChatServerAddress == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
240
controller/relay-ali.go
Normal file
240
controller/relay-ali.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.Content
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: request.Messages[i+1].Content,
|
||||
})
|
||||
i++
|
||||
}
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
},
|
||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
||||
// TopP: request.TopP,
|
||||
// TopK: 50,
|
||||
// //Seed: 0,
|
||||
// //EnableSearch: false,
|
||||
//},
|
||||
}
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: response.Output.Text,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Usage: Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
choice.FinishReason = aliResponse.Output.FinishReason
|
||||
response := ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "ernie-bot",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
usage.PromptTokens += aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens += aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var aliResponse AliChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
@@ -54,13 +54,43 @@ type BaiduChatStreamResponse struct {
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
@@ -101,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
|
||||
return &response
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||
baiduEmbeddingRequest := BaiduEmbeddingRequest{
|
||||
Input: nil,
|
||||
}
|
||||
switch request.Input.(type) {
|
||||
case string:
|
||||
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
|
||||
case []string:
|
||||
baiduEmbeddingRequest.Input = request.Input.([]string)
|
||||
}
|
||||
return &baiduEmbeddingRequest
|
||||
}
|
||||
|
||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "baidu-embedding",
|
||||
Usage: response.Usage,
|
||||
}
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -201,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
@@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else {
|
||||
// ignore other roles
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
claudeRequest.Prompt = prompt
|
||||
return &claudeRequest
|
||||
}
|
||||
|
@@ -34,6 +34,9 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
@@ -43,7 +46,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
continue // just ignore the error
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
@@ -53,7 +56,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += choice.Text
|
||||
@@ -89,7 +92,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var textResponse TextResponse
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
@@ -129,5 +132,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += countTokenText(choice.Message.Content, model)
|
||||
}
|
||||
textResponse.Usage = Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
return nil, &textResponse.Usage
|
||||
}
|
||||
|
@@ -20,6 +20,8 @@ const (
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
@@ -73,7 +75,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
@@ -85,14 +87,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
}
|
||||
apiType := APITypeOpenAI
|
||||
if strings.HasPrefix(textRequest.Model, "claude") {
|
||||
switch channelType {
|
||||
case common.ChannelTypeAnthropic:
|
||||
apiType = APITypeClaude
|
||||
} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
|
||||
case common.ChannelTypeBaidu:
|
||||
apiType = APITypeBaidu
|
||||
} else if strings.HasPrefix(textRequest.Model, "PaLM") {
|
||||
case common.ChannelTypePaLM:
|
||||
apiType = APITypePaLM
|
||||
} else if strings.HasPrefix(textRequest.Model, "chatglm_") {
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
case common.ChannelTypeAli:
|
||||
apiType = APITypeAli
|
||||
case common.ChannelTypeXunfei:
|
||||
apiType = APITypeXunfei
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
@@ -134,12 +141,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
|
||||
case APITypePaLM:
|
||||
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
||||
if baseURL != "" {
|
||||
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
@@ -149,6 +161,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||
case APITypeAli:
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
}
|
||||
var promptTokens int
|
||||
var completionTokens int
|
||||
@@ -202,12 +216,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeBaidu:
|
||||
baiduRequest := requestOpenAI2Baidu(textRequest)
|
||||
jsonStr, err := json.Marshal(baiduRequest)
|
||||
var jsonData []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
|
||||
jsonData, err = json.Marshal(baiduEmbeddingRequest)
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Baidu(textRequest)
|
||||
jsonData, err = json.Marshal(baiduRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
case APITypePaLM:
|
||||
palmRequest := requestOpenAI2PaLM(textRequest)
|
||||
jsonStr, err := json.Marshal(palmRequest)
|
||||
@@ -222,49 +244,68 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", apiKey)
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
case APITypeAli:
|
||||
aliRequest := requestOpenAI2Ali(textRequest)
|
||||
jsonStr, err := json.Marshal(aliRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
case APITypeClaude:
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
var resp *http.Response
|
||||
isStream := textRequest.Stream
|
||||
|
||||
if apiType != APITypeXunfei { // cause xunfei use websocket
|
||||
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
case APITypeZhipu:
|
||||
token := getZhipuToken(apiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", apiKey)
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
}
|
||||
case APITypeClaude:
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
case APITypeZhipu:
|
||||
token := getZhipuToken(apiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
case APITypeAli:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if textRequest.Stream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
||||
resp, err = httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
}
|
||||
|
||||
var textResponse TextResponse
|
||||
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
var streamResponseText string
|
||||
|
||||
defer func() {
|
||||
if consumeQuota {
|
||||
@@ -276,16 +317,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
||||
completionRatio = 2
|
||||
}
|
||||
if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
|
||||
completionTokens = countTokenText(streamResponseText, textRequest.Model)
|
||||
} else {
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
if apiType == APITypeZhipu {
|
||||
// zhipu's API does not return prompt tokens & completion tokens
|
||||
promptTokens = textResponse.Usage.TotalTokens
|
||||
}
|
||||
}
|
||||
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
@@ -323,10 +358,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
streamResponseText = responseText
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := openaiHandler(c, resp, consumeQuota)
|
||||
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -341,7 +377,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
streamResponseText = responseText
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
||||
@@ -364,7 +401,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
err, usage := baiduHandler(c, resp)
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -379,7 +423,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
streamResponseText = responseText
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
||||
@@ -400,6 +445,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
// zhipu's API does not return prompt tokens & completion tokens
|
||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
||||
return nil
|
||||
} else {
|
||||
err, usage := zhipuHandler(c, resp)
|
||||
@@ -409,8 +456,49 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
// zhipu's API does not return prompt tokens & completion tokens
|
||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
||||
return nil
|
||||
}
|
||||
case APITypeAli:
|
||||
if isStream {
|
||||
err, usage := aliStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
err, usage := aliHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeXunfei:
|
||||
if isStream {
|
||||
auth := c.Request.Header.Get("Authorization")
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
splits := strings.Split(auth, "|")
|
||||
if len(splits) != 3 {
|
||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
default:
|
||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||
}
|
||||
|
278
controller/relay-xunfei.go
Normal file
278
controller/relay-xunfei.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://console.xfyun.cn/services/cbm
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
|
||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
xunfeiRequest := XunfeiChatRequest{}
|
||||
xunfeiRequest.Header.AppId = xunfeiAppId
|
||||
xunfeiRequest.Parameter.Chat.Domain = "general"
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||
xunfeiRequest.Payload.Message.Text = messages
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: response.Payload.Choices.Text[0].Content,
|
||||
},
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Usage: response.Payload.Usage.Text,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
response := ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "SparkDesk",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
encodeData := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(encodeData)
|
||||
}
|
||||
ul, err := url.Parse(hostUrl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
date := time.Now().UTC().Format(time.RFC1123)
|
||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||
sign := strings.Join(signString, "\n")
|
||||
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||
"hmac-sha256", "host date request-line", sha)
|
||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||
v := url.Values{}
|
||||
v.Add("host", ul.Host)
|
||||
v.Add("date", date)
|
||||
v.Add("authorization", authorization)
|
||||
callUrl := hostUrl + "?" + v.Encode()
|
||||
return callUrl
|
||||
}
|
||||
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
hostUrl := "wss://aichat.xf-yun.com/v1/chat"
|
||||
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
data := requestOpenAI2Xunfei(textRequest, appId)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
var response XunfeiChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
dataChan <- response
|
||||
if response.Payload.Choices.Status == 2 {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing websocket connection: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &xunfeiResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if xunfeiResponse.Header.Code != 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: xunfeiResponse.Header.Message,
|
||||
Type: "xunfei_error",
|
||||
Param: "",
|
||||
Code: xunfeiResponse.Header.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
@@ -111,10 +111,21 @@ func getZhipuToken(apikey string) string {
|
||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "system",
|
||||
Content: message.Content,
|
||||
})
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "user",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
return &ZhipuRequest{
|
||||
Prompt: messages,
|
||||
@@ -183,8 +194,8 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
@@ -197,14 +208,19 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.Trim(data, "\"")
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] == "data:" {
|
||||
dataChan <- data[5:]
|
||||
} else if data[:5] == "meta:" {
|
||||
metaChan <- data[5:]
|
||||
lines := strings.Split(data, "\n")
|
||||
for i, line := range lines {
|
||||
if len(line) < 5 {
|
||||
continue
|
||||
}
|
||||
if line[:5] == "data:" {
|
||||
dataChan <- line[5:]
|
||||
if i != len(lines)-1 {
|
||||
dataChan <- "\n"
|
||||
}
|
||||
} else if line[:5] == "meta:" {
|
||||
metaChan <- line[5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
|
@@ -81,8 +81,9 @@ type OpenAIErrorWithStatusCode struct {
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAITextResponseChoice struct {
|
||||
@@ -99,6 +100,19 @@ type OpenAITextResponse struct {
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Created int `json:"created"`
|
||||
Data []struct {
|
||||
|
1
go.mod
1
go.mod
@@ -13,6 +13,7 @@ require (
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.1
|
||||
golang.org/x/crypto v0.9.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
|
2
go.sum
2
go.sum
@@ -67,6 +67,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
|
@@ -503,5 +503,12 @@
|
||||
"请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT",
|
||||
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
|
||||
"Homepage URL 填": "Fill in the Homepage URL",
|
||||
"Authorization callback URL 填": "Fill in the Authorization callback URL"
|
||||
"Authorization callback URL 填": "Fill in the Authorization callback URL",
|
||||
"请为通道命名": "Please name the channel",
|
||||
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
|
||||
"模型重定向": "Model redirection",
|
||||
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
|
||||
"注意,": "Note that, ",
|
||||
",图片演示。": "related image demo.",
|
||||
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!"
|
||||
}
|
||||
|
1
main.go
1
main.go
@@ -54,6 +54,7 @@ func main() {
|
||||
if err != nil {
|
||||
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
|
||||
}
|
||||
common.SyncFrequency = frequency
|
||||
go model.SyncOptions(frequency)
|
||||
if common.RedisEnabled {
|
||||
go model.SyncChannelCache(frequency)
|
||||
|
@@ -12,11 +12,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenCacheSeconds = 60 * 60
|
||||
UserId2GroupCacheSeconds = 60 * 60
|
||||
UserId2QuotaCacheSeconds = 10 * 60
|
||||
UserId2StatusCacheSeconds = 60 * 60
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
@@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set token error: " + err.Error())
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user quota error: " + err.Error())
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool {
|
||||
status = common.UserStatusEnabled
|
||||
}
|
||||
enabled = fmt.Sprintf("%d", status)
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
|
@@ -39,6 +39,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
||||
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
|
||||
common.OptionMap["SMTPServer"] = ""
|
||||
common.OptionMap["SMTPFrom"] = ""
|
||||
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
|
||||
@@ -141,6 +143,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
common.RegisterEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
common.AutomaticDisableChannelEnabled = boolValue
|
||||
case "ApproximateTokenEnabled":
|
||||
@@ -154,6 +158,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
case "EmailDomainWhitelist":
|
||||
common.EmailDomainWhitelist = strings.Split(value, ",")
|
||||
case "SMTPServer":
|
||||
common.SMTPServer = value
|
||||
case "SMTPPort":
|
||||
|
@@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
redemption := &Redemption{}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := DB.Where("`key` = ?", key).First(redemption).Error
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
if redemption.Status != common.RedemptionCodeStatusEnabled {
|
||||
return errors.New("该兑换码已被使用")
|
||||
}
|
||||
err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
redemption.RedeemedTime = common.GetTimestamp()
|
||||
redemption.Status = common.RedemptionCodeStatusUsed
|
||||
return redemption.SelectUpdate()
|
||||
err = tx.Save(redemption).Error
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.New("兑换失败," + err.Error())
|
||||
|
@@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
modelsRouter.GET("/", controller.ListModels)
|
||||
modelsRouter.GET("", controller.ListModels)
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
|
@@ -363,9 +363,12 @@ const ChannelsTable = () => {
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<Popup
|
||||
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
|
||||
key={channel.id}
|
||||
trigger={renderBalance(channel.type, channel.balance)}
|
||||
trigger={<span onClick={() => {
|
||||
updateChannelBalance(channel.id, channel.name, idx);
|
||||
}} style={{ cursor: 'pointer' }}>
|
||||
{renderBalance(channel.type, channel.balance)}
|
||||
</span>}
|
||||
content="点击更新"
|
||||
basic
|
||||
/>
|
||||
</Table.Cell>
|
||||
@@ -380,16 +383,16 @@ const ChannelsTable = () => {
|
||||
>
|
||||
测试
|
||||
</Button>
|
||||
<Button
|
||||
size={'small'}
|
||||
positive
|
||||
loading={updatingBalance}
|
||||
onClick={() => {
|
||||
updateChannelBalance(channel.id, channel.name, idx);
|
||||
}}
|
||||
>
|
||||
更新余额
|
||||
</Button>
|
||||
{/*<Button*/}
|
||||
{/* size={'small'}*/}
|
||||
{/* positive*/}
|
||||
{/* loading={updatingBalance}*/}
|
||||
{/* onClick={() => {*/}
|
||||
{/* updateChannelBalance(channel.id, channel.name, idx);*/}
|
||||
{/* }}*/}
|
||||
{/*>*/}
|
||||
{/* 更新余额*/}
|
||||
{/*</Button>*/}
|
||||
<Popup
|
||||
trigger={
|
||||
<Button size='small' negative>
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react';
|
||||
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
|
||||
import { Button, Divider, Form, Grid, Header, Input, Message } from 'semantic-ui-react';
|
||||
import { API, removeTrailingSlash, showError } from '../helpers';
|
||||
|
||||
const SystemSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
@@ -26,9 +26,13 @@ const SystemSetting = () => {
|
||||
TurnstileSiteKey: '',
|
||||
TurnstileSecretKey: '',
|
||||
RegisterEnabled: '',
|
||||
EmailDomainRestrictionEnabled: '',
|
||||
EmailDomainWhitelist: ''
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]);
|
||||
const [restrictedDomainInput, setRestrictedDomainInput] = useState('');
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
@@ -38,8 +42,15 @@ const SystemSetting = () => {
|
||||
data.forEach((item) => {
|
||||
newInputs[item.key] = item.value;
|
||||
});
|
||||
setInputs(newInputs);
|
||||
setInputs({
|
||||
...newInputs,
|
||||
EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',')
|
||||
});
|
||||
setOriginInputs(newInputs);
|
||||
|
||||
setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => {
|
||||
return { key: item, text: item, value: item };
|
||||
}));
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -58,6 +69,7 @@ const SystemSetting = () => {
|
||||
case 'GitHubOAuthEnabled':
|
||||
case 'WeChatAuthEnabled':
|
||||
case 'TurnstileCheckEnabled':
|
||||
case 'EmailDomainRestrictionEnabled':
|
||||
case 'RegisterEnabled':
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
break;
|
||||
@@ -70,7 +82,12 @@ const SystemSetting = () => {
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
setInputs((inputs) => ({ ...inputs, [key]: value }));
|
||||
if (key === 'EmailDomainWhitelist') {
|
||||
value = value.split(',');
|
||||
}
|
||||
setInputs((inputs) => ({
|
||||
...inputs, [key]: value
|
||||
}));
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -88,7 +105,8 @@ const SystemSetting = () => {
|
||||
name === 'WeChatServerToken' ||
|
||||
name === 'WeChatAccountQRCodeImageURL' ||
|
||||
name === 'TurnstileSiteKey' ||
|
||||
name === 'TurnstileSecretKey'
|
||||
name === 'TurnstileSecretKey' ||
|
||||
name === 'EmailDomainWhitelist'
|
||||
) {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
} else {
|
||||
@@ -125,6 +143,16 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const submitEmailDomainWhitelist = async () => {
|
||||
if (
|
||||
originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') &&
|
||||
inputs.SMTPToken !== ''
|
||||
) {
|
||||
await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(','));
|
||||
}
|
||||
};
|
||||
|
||||
const submitWeChat = async () => {
|
||||
if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) {
|
||||
await updateOption(
|
||||
@@ -173,6 +201,22 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitNewRestrictedDomain = () => {
|
||||
const localDomainList = inputs.EmailDomainWhitelist;
|
||||
if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) {
|
||||
setRestrictedDomainInput('');
|
||||
setInputs({
|
||||
...inputs,
|
||||
EmailDomainWhitelist: [...localDomainList, restrictedDomainInput],
|
||||
});
|
||||
setEmailDomainWhitelist([...EmailDomainWhitelist, {
|
||||
key: restrictedDomainInput,
|
||||
text: restrictedDomainInput,
|
||||
value: restrictedDomainInput,
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
@@ -239,6 +283,54 @@ const SystemSetting = () => {
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
配置邮箱域名白名单
|
||||
<Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader>
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Checkbox
|
||||
label='启用邮箱域名白名单'
|
||||
name='EmailDomainRestrictionEnabled'
|
||||
onChange={handleInputChange}
|
||||
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={2}>
|
||||
<Form.Dropdown
|
||||
label='允许的邮箱域名'
|
||||
placeholder='允许的邮箱域名'
|
||||
name='EmailDomainWhitelist'
|
||||
required
|
||||
fluid
|
||||
multiple
|
||||
selection
|
||||
onChange={handleInputChange}
|
||||
value={inputs.EmailDomainWhitelist}
|
||||
autoComplete='new-password'
|
||||
options={EmailDomainWhitelist}
|
||||
/>
|
||||
<Form.Input
|
||||
label='添加新的允许的邮箱域名'
|
||||
action={
|
||||
<Button type='button' onClick={() => {
|
||||
submitNewRestrictedDomain();
|
||||
}}>填入</Button>
|
||||
}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
submitNewRestrictedDomain();
|
||||
}
|
||||
}}
|
||||
autoComplete='new-password'
|
||||
placeholder='输入新的允许的邮箱域名'
|
||||
value={restrictedDomainInput}
|
||||
onChange={(e, { value }) => {
|
||||
setRestrictedDomainInput(value);
|
||||
}}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
配置 SMTP
|
||||
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
|
||||
@@ -284,7 +376,7 @@ const SystemSetting = () => {
|
||||
onChange={handleInputChange}
|
||||
type='password'
|
||||
autoComplete='new-password'
|
||||
value={inputs.SMTPToken}
|
||||
checked={inputs.RegisterEnabled === 'true'}
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
</Form.Group>
|
||||
|
@@ -1,11 +1,17 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Form, Label, Modal, Pagination, Popup, Table } from 'semantic-ui-react';
|
||||
import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers';
|
||||
|
||||
import { ITEMS_PER_PAGE } from '../constants';
|
||||
import { renderQuota } from '../helpers/render';
|
||||
|
||||
const COPY_OPTIONS = [
|
||||
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
|
||||
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
];
|
||||
|
||||
function renderTimestamp(timestamp) {
|
||||
return (
|
||||
<>
|
||||
@@ -68,7 +74,40 @@ const TokensTable = () => {
|
||||
const refresh = async () => {
|
||||
setLoading(true);
|
||||
await loadTokens(activePage - 1);
|
||||
}
|
||||
};
|
||||
|
||||
const onCopy = async (type, key) => {
|
||||
let status = localStorage.getItem('status');
|
||||
let serverAddress = '';
|
||||
if (status) {
|
||||
status = JSON.parse(status);
|
||||
serverAddress = status.server_address;
|
||||
}
|
||||
if (serverAddress === '') {
|
||||
serverAddress = window.location.origin;
|
||||
}
|
||||
let encodedServerAddress = encodeURIComponent(serverAddress);
|
||||
let url;
|
||||
switch (type) {
|
||||
case 'ama':
|
||||
url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
|
||||
break;
|
||||
case 'opencat':
|
||||
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
|
||||
break;
|
||||
case 'next':
|
||||
url = `https://chatgpt1.nextweb.fun/#/?settings=%7B%22key%22:%22sk-${key}%22,%22url%22:%22${serverAddress}%22%7D`;
|
||||
break;
|
||||
default:
|
||||
url = `sk-${key}`;
|
||||
}
|
||||
if (await copy(url)) {
|
||||
showSuccess('已复制到剪贴板!');
|
||||
} else {
|
||||
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
|
||||
setSearchKeyword(url);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
loadTokens(0)
|
||||
@@ -235,21 +274,28 @@ const TokensTable = () => {
|
||||
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
|
||||
<Table.Cell>
|
||||
<div>
|
||||
<Button
|
||||
size={'small'}
|
||||
positive
|
||||
onClick={async () => {
|
||||
let key = "sk-" + token.key;
|
||||
if (await copy(key)) {
|
||||
showSuccess('已复制到剪贴板!');
|
||||
} else {
|
||||
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
|
||||
setSearchKeyword(key);
|
||||
<Button.Group color='green' size={'small'}>
|
||||
<Button
|
||||
size={'small'}
|
||||
positive
|
||||
onClick={async () => {
|
||||
await onCopy('', token.key);
|
||||
}
|
||||
}}
|
||||
>
|
||||
复制
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
复制
|
||||
</Button>
|
||||
<Dropdown
|
||||
className='button icon'
|
||||
floating
|
||||
options={COPY_OPTIONS}
|
||||
onChange={async (e, { value } = {}) => {
|
||||
await onCopy(value, token.key);
|
||||
}}
|
||||
trigger={<></>}
|
||||
/>
|
||||
</Button.Group>
|
||||
{' '}
|
||||
<Popup
|
||||
trigger={
|
||||
<Button size='small' negative>
|
||||
|
@@ -227,7 +227,7 @@ const UsersTable = () => {
|
||||
content={user.email ? user.email : '未绑定邮箱地址'}
|
||||
key={user.username}
|
||||
header={user.display_name ? user.display_name : user.username}
|
||||
trigger={<span>{renderText(user.username, 10)}</span>}
|
||||
trigger={<span>{renderText(user.username, 15)}</span>}
|
||||
hoverable
|
||||
/>
|
||||
</Table.Cell>
|
||||
|
@@ -4,6 +4,8 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
|
||||
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
|
||||
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
|
||||
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
|
||||
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
||||
|
@@ -1,5 +1,5 @@
|
||||
export const toastConstants = {
|
||||
SUCCESS_TIMEOUT: 500,
|
||||
SUCCESS_TIMEOUT: 1500,
|
||||
INFO_TIMEOUT: 3000,
|
||||
ERROR_TIMEOUT: 5000,
|
||||
WARNING_TIMEOUT: 10000,
|
||||
|
@@ -35,6 +35,30 @@ const EditChannel = () => {
|
||||
const [customModel, setCustomModel] = useState('');
|
||||
const handleInputChange = (e, { name, value }) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
if (name === 'type' && inputs.models.length === 0) {
|
||||
let localModels = [];
|
||||
switch (value) {
|
||||
case 14:
|
||||
localModels = ['claude-instant-1', 'claude-2'];
|
||||
break;
|
||||
case 11:
|
||||
localModels = ['PaLM-2'];
|
||||
break;
|
||||
case 15:
|
||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
||||
break;
|
||||
case 17:
|
||||
localModels = ['qwen-v1', 'qwen-plus-v1'];
|
||||
break;
|
||||
case 16:
|
||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||
break;
|
||||
case 18:
|
||||
localModels = ['SparkDesk'];
|
||||
break;
|
||||
}
|
||||
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
||||
}
|
||||
};
|
||||
|
||||
const loadChannel = async () => {
|
||||
@@ -132,7 +156,10 @@ const EditChannel = () => {
|
||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||
}
|
||||
if (localInputs.type === 3 && localInputs.other === '') {
|
||||
localInputs.other = '2023-03-15-preview';
|
||||
localInputs.other = '2023-06-01-preview';
|
||||
}
|
||||
if (localInputs.model_mapping === '') {
|
||||
localInputs.model_mapping = '{}';
|
||||
}
|
||||
let res;
|
||||
localInputs.models = localInputs.models.join(',');
|
||||
@@ -192,7 +219,7 @@ const EditChannel = () => {
|
||||
<Form.Input
|
||||
label='默认 API 版本'
|
||||
name='other'
|
||||
placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'}
|
||||
placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'}
|
||||
onChange={handleInputChange}
|
||||
value={inputs.other}
|
||||
autoComplete='new-password'
|
||||
@@ -270,8 +297,8 @@ const EditChannel = () => {
|
||||
}}>清除所有模型</Button>
|
||||
<Input
|
||||
action={
|
||||
<Button type={'button'} onClick={()=>{
|
||||
if (customModel.trim() === "") return;
|
||||
<Button type={'button'} onClick={() => {
|
||||
if (customModel.trim() === '') return;
|
||||
if (inputs.models.includes(customModel)) return;
|
||||
let localModels = [...inputs.models];
|
||||
localModels.push(customModel);
|
||||
@@ -279,9 +306,9 @@ const EditChannel = () => {
|
||||
localModelOptions.push({
|
||||
key: customModel,
|
||||
text: customModel,
|
||||
value: customModel,
|
||||
value: customModel
|
||||
});
|
||||
setModelOptions(modelOptions=>{
|
||||
setModelOptions(modelOptions => {
|
||||
return [...modelOptions, ...localModelOptions];
|
||||
});
|
||||
setCustomModel('');
|
||||
@@ -297,7 +324,7 @@ const EditChannel = () => {
|
||||
</div>
|
||||
<Form.Field>
|
||||
<Form.TextArea
|
||||
label='模型映射'
|
||||
label='模型重定向'
|
||||
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
|
||||
name='model_mapping'
|
||||
onChange={handleInputChange}
|
||||
@@ -323,7 +350,7 @@ const EditChannel = () => {
|
||||
label='密钥'
|
||||
name='key'
|
||||
required
|
||||
placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入渠道对应的鉴权密钥'}
|
||||
placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
|
||||
onChange={handleInputChange}
|
||||
value={inputs.key}
|
||||
autoComplete='new-password'
|
||||
@@ -354,7 +381,7 @@ const EditChannel = () => {
|
||||
</Form.Field>
|
||||
)
|
||||
}
|
||||
<Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
|
||||
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
|
||||
</Form>
|
||||
</Segment>
|
||||
</>
|
||||
|
@@ -83,7 +83,7 @@ const EditToken = () => {
|
||||
if (isEdit) {
|
||||
showSuccess('令牌更新成功!');
|
||||
} else {
|
||||
showSuccess('令牌创建成功!');
|
||||
showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!');
|
||||
setInputs(originInputs);
|
||||
}
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user