mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
53 Commits
v0.2.9.6
...
v0.3.1.0-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a859ff5985 | ||
|
|
0a80231e18 | ||
|
|
7b1ff41e4c | ||
|
|
4e0c522cd0 | ||
|
|
f08f7ae940 | ||
|
|
be64408a25 | ||
|
|
d596699250 | ||
|
|
f0907bf60a | ||
|
|
e5c05d77b7 | ||
|
|
24b3ed50d7 | ||
|
|
8de79382f0 | ||
|
|
74f9006b40 | ||
|
|
33af069fae | ||
|
|
e3c85572d4 | ||
|
|
4b48e490fa | ||
|
|
3e2ae29ba0 | ||
|
|
fe0ed128c6 | ||
|
|
3785e9d754 | ||
|
|
902a66b60f | ||
|
|
aaf3f09eec | ||
|
|
e523555844 | ||
|
|
139a104b26 | ||
|
|
8b8abfadaf | ||
|
|
65e65097b2 | ||
|
|
62e321fe30 | ||
|
|
312ab44800 | ||
|
|
a2678a256d | ||
|
|
8b67664995 | ||
|
|
ade6d0f56a | ||
|
|
f599c65944 | ||
|
|
40baa636e4 | ||
|
|
d6359ec4ff | ||
|
|
89ddf83b44 | ||
|
|
6a8a4bcf65 | ||
|
|
e298f2e5a4 | ||
|
|
8cea6dff4a | ||
|
|
5035cd054a | ||
|
|
02c0c6501e | ||
|
|
f0b808a41d | ||
|
|
31d84ee32f | ||
|
|
9969ed2d7c | ||
|
|
746311242b | ||
|
|
04a68a85dd | ||
|
|
f9ba10f180 | ||
|
|
334a6f8280 | ||
|
|
0cf53ac5ff | ||
|
|
af02cdc58b | ||
|
|
9a4ca1e210 | ||
|
|
9fe1f35fd1 | ||
|
|
972ac1ee0f | ||
|
|
0f95502b04 | ||
|
|
b58b1dc0ec | ||
|
|
05d9aa61df |
@@ -124,10 +124,11 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
[对接文档](Suno.md)
|
||||
|
||||
## 界面截图
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||

|
||||
夜间模式
|
||||

|
||||

|
||||
|
||||
@@ -222,6 +222,7 @@ const (
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
@@ -270,4 +271,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
}
|
||||
|
||||
@@ -32,25 +32,27 @@ var defaultModelRatio = map[string]float64{
|
||||
"gpt-4-0613": 15,
|
||||
"gpt-4-32k": 30,
|
||||
//"gpt-4-32k-0314": 30, //deprecated
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 1.5,
|
||||
"o1-mini-2024-09-12": 1.5,
|
||||
"gpt-4o-mini": 0.075,
|
||||
"gpt-4o-mini-2024-07-18": 0.075,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 1.25, // $0.01 / 1K tokens
|
||||
"gpt-4o-audio-preview": 1.25, // $0.0015 / 1K tokens
|
||||
"gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0015 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5,
|
||||
"gpt-4o-realtime-preview": 2.5,
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 1.5,
|
||||
"o1-mini-2024-09-12": 1.5,
|
||||
"gpt-4o-mini": 0.075,
|
||||
"gpt-4o-mini-2024-07-18": 0.075,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
//"gpt-3.5-turbo-0301": 0.75, //deprecated
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
@@ -86,8 +88,10 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-2.0": 4, // $8 / 1M tokens
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5,
|
||||
"claude-3-5-sonnet-20241022": 1.5,
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"ERNIE-4.0-8K": 0.120 * RMB,
|
||||
"ERNIE-3.5-8K": 0.012 * RMB,
|
||||
@@ -336,13 +340,13 @@ func GetCompletionRatio(name string) float64 {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
|
||||
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o") {
|
||||
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
|
||||
return 4
|
||||
if name == "gpt-4o-2024-05-13" {
|
||||
return 3
|
||||
}
|
||||
return 4
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
@@ -375,10 +379,7 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "gemini-") {
|
||||
if strings.Contains(name, "flash") {
|
||||
return 4
|
||||
}
|
||||
return 3
|
||||
return 4
|
||||
}
|
||||
if strings.HasPrefix(name, "command") {
|
||||
switch name {
|
||||
@@ -420,6 +421,34 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func GetAudioRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
return 20
|
||||
}
|
||||
return 20
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
return 10
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
//func GetAudioPricePerMinute(name string) float64 {
|
||||
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
// return 0.06
|
||||
// }
|
||||
// return 0.06
|
||||
//}
|
||||
//
|
||||
//func GetAudioCompletionPricePerMinute(name string) float64 {
|
||||
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
// return 0.24
|
||||
// }
|
||||
// return 0.24
|
||||
//}
|
||||
|
||||
func GetCompletionRatioMap() map[string]float64 {
|
||||
if CompletionRatio == nil {
|
||||
CompletionRatio = defaultCompletionRatio
|
||||
|
||||
@@ -21,3 +21,26 @@ func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
|
||||
UserUsableGroups = make(map[string]string)
|
||||
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
|
||||
}
|
||||
|
||||
func GetUserUsableGroups(userGroup string) map[string]string {
|
||||
if userGroup == "" {
|
||||
// 如果userGroup为空,返回UserUsableGroups
|
||||
return UserUsableGroups
|
||||
}
|
||||
// 如果userGroup不在UserUsableGroups中,返回UserUsableGroups + userGroup
|
||||
if _, ok := UserUsableGroups[userGroup]; !ok {
|
||||
appendUserUsableGroups := make(map[string]string)
|
||||
for k, v := range UserUsableGroups {
|
||||
appendUserUsableGroups[k] = v
|
||||
}
|
||||
appendUserUsableGroups[userGroup] = "用户分组"
|
||||
return appendUserUsableGroups
|
||||
}
|
||||
// 如果userGroup在UserUsableGroups中,返回UserUsableGroups
|
||||
return UserUsableGroups
|
||||
}
|
||||
|
||||
func GroupInUserUsableGroups(groupName string) bool {
|
||||
_, ok := UserUsableGroups[groupName]
|
||||
return ok
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os/exec"
|
||||
@@ -142,24 +145,35 @@ func GetUUID() string {
|
||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
}
|
||||
|
||||
func GenerateKey() string {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
key := make([]byte, 48)
|
||||
for i := 0; i < 16; i++ {
|
||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||
}
|
||||
uuid_ := GetUUID()
|
||||
for i := 0; i < 32; i++ {
|
||||
c := uuid_[i]
|
||||
if i%2 == 0 && c >= 'a' && c <= 'z' {
|
||||
c = c - 'a' + 'A'
|
||||
func GenerateRandomCharsKey(length int) (string, error) {
|
||||
b := make([]byte, length)
|
||||
maxI := big.NewInt(int64(len(keyChars)))
|
||||
|
||||
for i := range b {
|
||||
n, err := crand.Int(crand.Reader, maxI)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
key[i+16] = c
|
||||
b[i] = keyChars[n.Int64()]
|
||||
}
|
||||
return string(key)
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func GenerateRandomKey(length int) (string, error) {
|
||||
bytes := make([]byte, length*3/4) // 对于48位的输出,这里应该是36
|
||||
if _, err := crand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
return GenerateRandomCharsKey(48)
|
||||
}
|
||||
|
||||
func GetRandomInt(max int) int {
|
||||
|
||||
35
constant/chat.go
Normal file
35
constant/chat.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var Chats = []map[string]string{
|
||||
{
|
||||
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
|
||||
},
|
||||
{
|
||||
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
|
||||
},
|
||||
{
|
||||
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
|
||||
},
|
||||
{
|
||||
"OpenCat": "opencat://team/join?domain={address}&token={key}",
|
||||
},
|
||||
}
|
||||
|
||||
func UpdateChatsByJsonString(jsonString string) error {
|
||||
Chats = make([]map[string]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &Chats)
|
||||
}
|
||||
|
||||
func Chats2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(Chats)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling chats: " + err.Error())
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
@@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
}
|
||||
if usage == nil {
|
||||
if usageA == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
|
||||
@@ -112,7 +112,9 @@ func GitHubOAuth(c *gin.Context) {
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
// IsGitHubIdAlreadyTaken is unscoped
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
// FillUserByGitHubId is scoped
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -121,6 +123,14 @@ func GitHubOAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
@@ -20,10 +21,14 @@ func GetGroups(c *gin.Context) {
|
||||
|
||||
func GetUserGroups(c *gin.Context) {
|
||||
usableGroups := make(map[string]string)
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.CacheGetUserGroup(userId)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
if _, ok := common.UserUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = common.UserUsableGroups[groupName]
|
||||
userUsableGroups := common.GetUserUsableGroups(userGroup)
|
||||
if _, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = userUsableGroups[groupName]
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -63,6 +63,7 @@ func GetStatus(c *gin.Context) {
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
"chats": constant.Chats,
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -38,6 +39,67 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
return err
|
||||
}
|
||||
|
||||
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
group := playgroundRequest.Group
|
||||
userGroup := c.GetString("group")
|
||||
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !common.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
@@ -82,6 +144,67 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许跨域
|
||||
},
|
||||
}
|
||||
|
||||
func WssRelay(c *gin.Context) {
|
||||
// 将 HTTP 连接升级为 WebSocket 连接
|
||||
|
||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
service.WssError(c, ws, openaiErr.Error)
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
service.WssError(c, ws, openaiErr.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
@@ -89,6 +212,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func addUsedChannel(c *gin.Context, channelId int) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
@@ -48,6 +49,13 @@ func TelegramBind(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.TelegramId = telegramId
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
|
||||
@@ -123,10 +123,19 @@ func AddToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成令牌失败",
|
||||
})
|
||||
common.SysError("failed to generate token key: " + err.Error())
|
||||
return
|
||||
}
|
||||
cleanToken := model.Token{
|
||||
UserId: c.GetInt("id"),
|
||||
Name: token.Name,
|
||||
Key: common.GenerateKey(),
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
AccessedTime: common.GetTimestamp(),
|
||||
ExpiredTime: token.ExpiredTime,
|
||||
|
||||
@@ -68,6 +68,7 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
session.Set("username", user.Username)
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
session.Set("group", user.Group)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -159,8 +160,9 @@ func Register(c *gin.Context) {
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": "数据库错误,请稍后重试",
|
||||
})
|
||||
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||
return
|
||||
}
|
||||
if exist {
|
||||
@@ -200,11 +202,20 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
// 生成默认令牌
|
||||
if constant.GenerateDefaultToken {
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成默认令牌失败",
|
||||
})
|
||||
common.SysError("failed to generate token key: " + err.Error())
|
||||
return
|
||||
}
|
||||
// 生成默认令牌
|
||||
token := model.Token{
|
||||
UserId: insertedUser.Id, // 使用插入后的用户ID
|
||||
Name: cleanUser.Username + "的初始令牌",
|
||||
Key: common.GenerateKey(),
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
AccessedTime: common.GetTimestamp(),
|
||||
ExpiredTime: -1, // 永不过期
|
||||
@@ -311,7 +322,18 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
user.AccessToken = common.GetUUID()
|
||||
// get rand int 28-32
|
||||
randI := common.GetRandomInt(4)
|
||||
key, err := common.GenerateRandomKey(29 + randI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成失败",
|
||||
})
|
||||
common.SysError("failed to generate key: " + err.Error())
|
||||
return
|
||||
}
|
||||
user.SetAccessToken(key)
|
||||
|
||||
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -78,6 +78,13 @@ func WeChatAuth(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
|
||||
@@ -26,6 +26,7 @@ type GeneralOpenAIRequest struct {
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
@@ -33,6 +34,8 @@ type GeneralOpenAIRequest struct {
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
@@ -82,9 +85,10 @@ type Message struct {
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
InputAudio any `json:"input_audio,omitempty"`
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
@@ -92,9 +96,15 @@ type MessageImageUrl struct {
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
type MessageInputAudio struct {
|
||||
Data string `json:"data"` //base64
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
)
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
@@ -167,11 +177,19 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case ContentTypeInputAudio:
|
||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: subObj["data"].(string),
|
||||
Format: subObj["format"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
6
dto/playground.go
Normal file
6
dto/playground.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package dto
|
||||
|
||||
type PlayGroundRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
97
dto/realtime.go
Normal file
97
dto/realtime.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package dto
|
||||
|
||||
const (
|
||||
RealtimeEventTypeError = "error"
|
||||
RealtimeEventTypeSessionUpdate = "session.update"
|
||||
RealtimeEventTypeConversationCreate = "conversation.item.create"
|
||||
RealtimeEventTypeResponseCreate = "response.create"
|
||||
RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
|
||||
)
|
||||
|
||||
const (
|
||||
RealtimeEventTypeResponseDone = "response.done"
|
||||
RealtimeEventTypeSessionUpdated = "session.updated"
|
||||
RealtimeEventTypeSessionCreated = "session.created"
|
||||
RealtimeEventResponseAudioDelta = "response.audio.delta"
|
||||
RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
|
||||
RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
|
||||
RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
|
||||
RealtimeEventConversationItemCreated = "conversation.item.created"
|
||||
)
|
||||
|
||||
type RealtimeEvent struct {
|
||||
EventId string `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
//PreviousItemId string `json:"previous_item_id"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type RealtimeResponse struct {
|
||||
Usage *RealtimeUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type RealtimeUsage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails InputTokenDetails `json:"input_token_details"`
|
||||
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
}
|
||||
|
||||
type OutputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
}
|
||||
|
||||
type RealtimeSession struct {
|
||||
Modalities []string `json:"modalities"`
|
||||
Instructions string `json:"instructions"`
|
||||
Voice string `json:"voice"`
|
||||
InputAudioFormat string `json:"input_audio_format"`
|
||||
OutputAudioFormat string `json:"output_audio_format"`
|
||||
InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
|
||||
TurnDetection interface{} `json:"turn_detection"`
|
||||
Tools []RealTimeTool `json:"tools"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
//MaxResponseOutputTokens int `json:"max_response_output_tokens"`
|
||||
}
|
||||
|
||||
type InputAudioTranscription struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type RealTimeTool struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters any `json:"parameters"`
|
||||
}
|
||||
|
||||
type RealtimeItem struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Content []RealtimeContent `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
CallId string `json:"call_id,omitempty"`
|
||||
}
|
||||
type RealtimeContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
}
|
||||
@@ -121,6 +121,8 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Set("username", username)
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Set("group", session.Get("group"))
|
||||
c.Set("use_access_token", useAccessToken)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -153,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func WssAuth(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
// 先检测是否为ws
|
||||
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
|
||||
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
|
||||
// read sk from Sec-WebSocket-Protocol
|
||||
key := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||
parts := strings.Split(key, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "openai-insecure-api-key") {
|
||||
key = strings.TrimPrefix(part, "openai-insecure-api-key.")
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
|
||||
@@ -42,7 +42,7 @@ func Distribute() func(c *gin.Context) {
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := common.UserUsableGroups[tokenGroup]; !ok {
|
||||
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||
return
|
||||
}
|
||||
@@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
modelRequest.Model = c.Query("model")
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
|
||||
@@ -32,7 +32,7 @@ func createRootAccountIfNeed() error {
|
||||
Role: common.RoleRootUser,
|
||||
Status: common.UserStatusEnabled,
|
||||
DisplayName: "Root User",
|
||||
AccessToken: common.GetUUID(),
|
||||
AccessToken: nil,
|
||||
Quota: 100000000,
|
||||
}
|
||||
DB.Create(&rootUser)
|
||||
|
||||
@@ -69,6 +69,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = constant.Chats2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -248,6 +249,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
constant.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
constant.PayAddress = value
|
||||
case "Chats":
|
||||
err = constant.UpdateChatsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
constant.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -257,51 +258,56 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
|
||||
if quota < 0 {
|
||||
return 0, errors.New("quota 不能为负数!")
|
||||
}
|
||||
token, err := GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
if !relayInfo.IsPlayground {
|
||||
token, err := GetTokenById(relayInfo.TokenId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return 0, errors.New("令牌额度不足")
|
||||
}
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return 0, errors.New("令牌额度不足")
|
||||
}
|
||||
userQuota, err = GetUserQuota(token.UserId)
|
||||
userQuota, err = GetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if userQuota < quota {
|
||||
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
}
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
if !relayInfo.IsPlayground {
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
err = DecreaseUserQuota(token.UserId, quota)
|
||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
||||
return userQuota - quota, err
|
||||
}
|
||||
|
||||
func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||
token, err := GetTokenById(tokenId)
|
||||
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||
|
||||
if quota > 0 {
|
||||
err = DecreaseUserQuota(token.UserId, quota)
|
||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
||||
} else {
|
||||
err = IncreaseUserQuota(token.UserId, -quota)
|
||||
err = IncreaseUserQuota(relayInfo.UserId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(tokenId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
if !relayInfo.IsPlayground {
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if sendEmail {
|
||||
@@ -310,7 +316,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
|
||||
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
|
||||
if quotaTooLow || noMoreQuota {
|
||||
go func() {
|
||||
email, err := GetUserEmail(token.UserId)
|
||||
email, err := GetUserEmail(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.SysError("failed to fetch user email: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ type User struct {
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||
@@ -38,6 +38,17 @@ type User struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (user *User) GetAccessToken() string {
|
||||
if user.AccessToken == nil {
|
||||
return ""
|
||||
}
|
||||
return *user.AccessToken
|
||||
}
|
||||
|
||||
func (user *User) SetAccessToken(token string) {
|
||||
user.AccessToken = &token
|
||||
}
|
||||
|
||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
var user User
|
||||
@@ -201,7 +212,7 @@ func (user *User) Insert(inviterId int) error {
|
||||
}
|
||||
}
|
||||
user.Quota = common.QuotaForNewUser
|
||||
user.AccessToken = common.GetUUID()
|
||||
//user.SetAccessToken(common.GetUUID())
|
||||
user.AffCode = common.GetRandomString(4)
|
||||
result := DB.Create(user)
|
||||
if result.Error != nil {
|
||||
@@ -340,14 +351,6 @@ func (user *User) FillUserByWeChatId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByUsername() error {
|
||||
if user.Username == "" {
|
||||
return errors.New("username 为空!")
|
||||
}
|
||||
DB.Where(User{Username: user.Username}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByTelegramId() error {
|
||||
if user.TelegramId == "" {
|
||||
return errors.New("Telegram id 为空!")
|
||||
@@ -360,23 +363,19 @@ func (user *User) FillUserByTelegramId() error {
|
||||
}
|
||||
|
||||
func IsEmailAlreadyTaken(email string) bool {
|
||||
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
|
||||
return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsWeChatIdAlreadyTaken(wechatId string) bool {
|
||||
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
|
||||
return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsTelegramIdAlreadyTaken(telegramId string) bool {
|
||||
return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
|
||||
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func ResetUserPasswordByEmail(email string, password string) error {
|
||||
|
||||
@@ -12,13 +12,13 @@ type Adaptor interface {
|
||||
// Init IsStream bool
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
|
||||
@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
if info.IsStream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
req.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/relay/common"
|
||||
@@ -11,14 +12,16 @@ import (
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
// multipart/form-data
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
// websocket
|
||||
} else {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
|
||||
fullRequestURL, err := a.GetRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
targetHeader := http.Header{}
|
||||
err = a.SetupRequestHeader(c, &targetHeader, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
|
||||
}
|
||||
// send request body
|
||||
//all, err := io.ReadAll(requestBody)
|
||||
//err = service.WssString(c, targetConn, string(all))
|
||||
return targetConn, nil
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,18 +30,14 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
a.RequestMode = RequestModeCompletion
|
||||
}
|
||||
a.RequestMode = RequestModeMessage
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -53,11 +48,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
|
||||
var claudeReq *claude.ClaudeRequest
|
||||
var err error
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
|
||||
} else {
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
}
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
return claudeReq, err
|
||||
@@ -67,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package aws
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package aws
|
||||
|
||||
import "one-api/relay/channel/claude"
|
||||
import (
|
||||
"one-api/relay/channel/claude"
|
||||
)
|
||||
|
||||
type AwsClaudeRequest struct {
|
||||
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||
@@ -12,4 +14,6 @@ type AwsClaudeRequest struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []claude.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func awsModelID(requestModel string) (string, error) {
|
||||
return awsModelID, nil
|
||||
}
|
||||
|
||||
return "", errors.Errorf("model %s not found", requestModel)
|
||||
return requestModel, nil
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
|
||||
@@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-api-key", info.ApiKey)
|
||||
req.Set("x-api-key", info.ApiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
req.Set("anthropic-version", anthropicVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -8,7 +8,9 @@ var ModelList = []string{
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
@@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fallthrough
|
||||
|
||||
@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
|
||||
@@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return requestOpenAI2Cohere(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return requestConvertRerank2Cohere(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = difyStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
return nil, usage
|
||||
|
||||
@@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
req.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = jinaRerankHandler(c, resp)
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
|
||||
72
relay/channel/mistral/adaptor.go
Normal file
72
relay/channel/mistral/adaptor.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
mistralReq := requestOpenAI2Mistral(*request)
|
||||
//common.LogJson(c, "body", mistralReq)
|
||||
return mistralReq, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
12
relay/channel/mistral/constants.go
Normal file
12
relay/channel/mistral/constants.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package mistral
|
||||
|
||||
var ModelList = []string{
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
"mistral-embed",
|
||||
}
|
||||
|
||||
var ChannelName = "mistral"
|
||||
40
relay/channel/mistral/text.go
Normal file
40
relay/channel/mistral/text.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
mediaMessage.ImageUrl = imageUrl.Url
|
||||
mediaMessages[j] = mediaMessage
|
||||
}
|
||||
}
|
||||
messageRaw, _ := json.Marshal(mediaMessages)
|
||||
message.Content = messageRaw
|
||||
}
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
})
|
||||
}
|
||||
return &dto.GeneralOpenAIRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
}
|
||||
}
|
||||
@@ -31,13 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embeddings", nil
|
||||
return info.BaseUrl + "/api/embed", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
@@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
// trim https
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
baseUrl = strings.TrimPrefix(baseUrl, "http://")
|
||||
baseUrl = "wss://" + baseUrl
|
||||
info.BaseUrl = baseUrl
|
||||
}
|
||||
switch info.ChannelType {
|
||||
case common.ChannelTypeAzure:
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
@@ -40,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
model_ := info.UpstreamModelName
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
case common.ChannelTypeMiniMax:
|
||||
return minimax.GetRequestURL(info)
|
||||
@@ -54,16 +63,34 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, header)
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", info.ApiKey)
|
||||
header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||
req.Header.Set("OpenAI-Organization", info.Organization)
|
||||
header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||
if swp != "" {
|
||||
items := []string{
|
||||
"realtime",
|
||||
"openai-insecure-api-key." + info.ApiKey,
|
||||
"openai-beta.realtime-v1",
|
||||
}
|
||||
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
|
||||
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
|
||||
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
|
||||
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
|
||||
} else {
|
||||
header.Set("openai-beta", "realtime=v1")
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
} else {
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
// req.Header.Set("X-Title", "One API")
|
||||
@@ -131,16 +158,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
return channel.DoWssRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRealtime:
|
||||
err, usage = OpenaiRealtimeHandler(c, info)
|
||||
case constant.RelayModeAudioSpeech:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeAudioTranslation:
|
||||
|
||||
@@ -13,6 +13,8 @@ var ModelList = []string{
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -231,7 +232,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
|
||||
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
@@ -324,7 +325,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
|
||||
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
}
|
||||
@@ -373,3 +374,210 @@ func getTextFromJSON(body []byte) (string, error) {
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
||||
info.IsStream = true
|
||||
clientConn := info.ClientWs
|
||||
targetConn := info.TargetWs
|
||||
|
||||
clientClosed := make(chan struct{})
|
||||
targetClosed := make(chan struct{})
|
||||
sendChan := make(chan []byte, 100)
|
||||
receiveChan := make(chan []byte, 100)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
usage := &dto.RealtimeUsage{}
|
||||
localUsage := &dto.RealtimeUsage{}
|
||||
sumUsage := &dto.RealtimeUsage{}
|
||||
|
||||
gopool.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||
}
|
||||
close(clientClosed)
|
||||
return
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = json.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||
if realtimeEvent.Session != nil {
|
||||
if realtimeEvent.Session.Tools != nil {
|
||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = service.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sendChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
gopool.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||
}
|
||||
close(targetClosed)
|
||||
return
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = json.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||
realtimeUsage := realtimeEvent.Response.Usage
|
||||
if realtimeUsage != nil {
|
||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||
usage.InputTokens += realtimeUsage.InputTokens
|
||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
usage = &dto.RealtimeUsage{}
|
||||
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
if realtimeSession != nil {
|
||||
// update audio format
|
||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||
}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = service.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case receiveChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
case <-clientClosed:
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
common.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||
}
|
||||
|
||||
if localUsage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
}
|
||||
|
||||
// check usage total tokens, if 0, use local usage
|
||||
|
||||
return nil, sumUsage
|
||||
}
|
||||
|
||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||
totalUsage.TotalTokens += usage.TotalTokens
|
||||
totalUsage.InputTokens += usage.InputTokens
|
||||
totalUsage.OutputTokens += usage.OutputTokens
|
||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||
// clear usage
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
req.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
|
||||
@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
|
||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = siliconflowRerankHandler(c, resp)
|
||||
|
||||
@@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", a.Sign)
|
||||
req.Header.Set("X-TC-Action", a.Action)
|
||||
req.Header.Set("X-TC-Version", a.Version)
|
||||
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
||||
req.Set("Authorization", a.Sign)
|
||||
req.Set("X-TC-Action", a.Action)
|
||||
req.Set("X-TC-Version", a.Version)
|
||||
req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
|
||||
@@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
|
||||
@@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
@@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
dummyResp.StatusCode = http.StatusOK
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
splits := strings.Split(info.ApiKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
|
||||
@@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
req.Set("Authorization", token)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = zhipuStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
req.Set("Authorization", token)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -2,7 +2,9 @@ package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,6 +22,8 @@ type RelayInfo struct {
|
||||
setFirstResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
@@ -31,6 +35,21 @@ type RelayInfo struct {
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
OutputAudioFormat string
|
||||
RealtimeTools []dto.RealTimeTool
|
||||
IsFirstRequest bool
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.ClientWs = ws
|
||||
info.InputAudioFormat = "pcm16"
|
||||
info.OutputAudioFormat = "pcm16"
|
||||
info.IsFirstRequest = true
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
@@ -65,6 +84,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
|
||||
info.RequestURLPath = "/v1" + info.RequestURLPath
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
}
|
||||
@@ -146,3 +170,20 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
|
||||
return &RelayInfo{
|
||||
ChannelType: info.ChannelType,
|
||||
ChannelId: info.ChannelId,
|
||||
TokenId: info.TokenId,
|
||||
UserId: info.UserId,
|
||||
Group: info.Group,
|
||||
StartTime: info.StartTime,
|
||||
ApiType: info.ApiType,
|
||||
RelayMode: info.RelayMode,
|
||||
UpstreamModelName: info.UpstreamModelName,
|
||||
RequestURLPath: info.RequestURLPath,
|
||||
ApiKey: info.ApiKey,
|
||||
BaseUrl: info.BaseUrl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ const (
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -72,6 +73,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeSiliconFlow
|
||||
case common.ChannelTypeVertexAi:
|
||||
apiType = APITypeVertexAi
|
||||
case common.ChannelTypeMistral:
|
||||
apiType = APITypeMistral
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -38,11 +38,13 @@ const (
|
||||
RelayModeSunoSubmit
|
||||
|
||||
RelayModeRerank
|
||||
|
||||
RelayModeRealtime
|
||||
)
|
||||
|
||||
func Path2RelayMode(path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(path, "/v1/chat/completions") {
|
||||
if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") {
|
||||
relayMode = RelayModeChatCompletions
|
||||
} else if strings.HasPrefix(path, "/v1/completions") {
|
||||
relayMode = RelayModeCompletions
|
||||
@@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
|
||||
relayMode = RelayModeAudioTranslation
|
||||
} else if strings.HasPrefix(path, "/v1/rerank") {
|
||||
relayMode = RelayModeRerank
|
||||
} else if strings.HasPrefix(path, "/v1/realtime") {
|
||||
relayMode = RelayModeRealtime
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
|
||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -87,11 +87,16 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
@@ -122,27 +127,28 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr := service.RelayErrorHandler(httpResp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
@@ -146,6 +147,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
var swapFaceRequest dto.SwapFaceRequest
|
||||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||
if err != nil {
|
||||
@@ -191,7 +193,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
@@ -356,6 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
consumeQuota := true
|
||||
var midjRequest dto.MidjourneyRequest
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
@@ -495,7 +498,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
@@ -76,6 +76,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
|
||||
// map model name
|
||||
isModelMapped := false
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -85,6 +86,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
isModelMapped = true
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
@@ -129,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
includeUsage := false
|
||||
// 判断用户是否需要返回使用情况
|
||||
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
|
||||
@@ -159,41 +165,49 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
if relayInfo.ChannelType == common.ChannelTypeOpenAI && !isModelMapped {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -266,7 +280,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
}
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
@@ -274,11 +288,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
return preConsumedQuota, userQuota, nil
|
||||
}
|
||||
|
||||
func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
|
||||
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
@@ -336,7 +350,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/relay/channel/dify"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/mistral"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
@@ -68,6 +69,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &siliconflow.Adaptor{}
|
||||
case constant.APITypeVertexAi:
|
||||
return &vertex.Adaptor{}
|
||||
case constant.APITypeMistral:
|
||||
return &mistral.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
var rerankRequest *dto.RerankRequest
|
||||
@@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
@@ -99,23 +105,24 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -111,7 +111,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
defer func(ctx context.Context) {
|
||||
// release quota
|
||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
|
||||
|
||||
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
|
||||
159
relay/websocket.go
Normal file
159
relay/websocket.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
||||
// _, p, err := ws.ReadMessage()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// realtimeEvent := &dto.RealtimeEvent{}
|
||||
// err = json.Unmarshal(p, realtimeEvent)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// // save the original request
|
||||
// if realtimeEvent.Session == nil {
|
||||
// return nil, errors.New("session object is nil")
|
||||
// }
|
||||
// c.Set("first_wss_request", p)
|
||||
// return realtimeEvent, nil
|
||||
//}
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
|
||||
//if err != nil {
|
||||
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
|
||||
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
//}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[relayInfo.OriginModelName] != "" {
|
||||
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
//var requestBody io.Reader
|
||||
//firstWssRequest, _ := c.Get("first_wss_request")
|
||||
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
relayInfo.TargetWs = resp.(*websocket.Conn)
|
||||
defer relayInfo.TargetWs.Close()
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
|
||||
if openaiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
|
||||
// var promptTokens int
|
||||
// var err error
|
||||
// switch info.RelayMode {
|
||||
// default:
|
||||
// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
|
||||
// }
|
||||
// info.PromptTokens = promptTokens
|
||||
// return promptTokens, err
|
||||
//}
|
||||
|
||||
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
|
||||
// var err error
|
||||
// switch info.RelayMode {
|
||||
// case relayconstant.RelayModeChatCompletions:
|
||||
// err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||
// case relayconstant.RelayModeCompletions:
|
||||
// err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||
// case relayconstant.RelayModeModerations:
|
||||
// err = service.CheckSensitiveInput(textRequest.Input)
|
||||
// case relayconstant.RelayModeEmbeddings:
|
||||
// err = service.CheckSensitiveInput(textRequest.Input)
|
||||
// }
|
||||
// return err
|
||||
//}
|
||||
@@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute := userRoute.Group("/")
|
||||
selfRoute.Use(middleware.UserAuth())
|
||||
{
|
||||
selfRoute.GET("/self/groups", controller.GetUserGroups)
|
||||
selfRoute.GET("/self", controller.GetSelf)
|
||||
selfRoute.GET("/models", controller.GetUserModels)
|
||||
selfRoute.PUT("/self", controller.UpdateSelf)
|
||||
|
||||
@@ -16,33 +16,47 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter.GET("", controller.ListModels)
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
playgroundRouter := router.Group("/pg")
|
||||
playgroundRouter.Use(middleware.UserAuth())
|
||||
{
|
||||
relayV1Router.POST("/completions", controller.Relay)
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
relayV1Router.POST("/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/generations", controller.Relay)
|
||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||
relayV1Router.POST("/audio/translations", controller.Relay)
|
||||
relayV1Router.POST("/audio/speech", controller.Relay)
|
||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
relayV1Router.POST("/rerank", controller.Relay)
|
||||
playgroundRouter.POST("/chat/completions", controller.Playground)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.TokenAuth())
|
||||
{
|
||||
// WebSocket 路由
|
||||
wsRouter := relayV1Router.Group("")
|
||||
wsRouter.Use(middleware.Distribute())
|
||||
wsRouter.GET("/realtime", controller.WssRelay)
|
||||
}
|
||||
{
|
||||
//http router
|
||||
httpRouter := relayV1Router.Group("")
|
||||
httpRouter.Use(middleware.Distribute())
|
||||
httpRouter.POST("/completions", controller.Relay)
|
||||
httpRouter.POST("/chat/completions", controller.Relay)
|
||||
httpRouter.POST("/edits", controller.Relay)
|
||||
httpRouter.POST("/images/generations", controller.Relay)
|
||||
httpRouter.POST("/images/edits", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/embeddings", controller.Relay)
|
||||
httpRouter.POST("/engines/:model/embeddings", controller.Relay)
|
||||
httpRouter.POST("/audio/transcriptions", controller.Relay)
|
||||
httpRouter.POST("/audio/translations", controller.Relay)
|
||||
httpRouter.POST("/audio/speech", controller.Relay)
|
||||
httpRouter.GET("/files", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/files", controller.RelayNotImplemented)
|
||||
httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/files/:id", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/moderations", controller.Relay)
|
||||
httpRouter.POST("/rerank", controller.Relay)
|
||||
}
|
||||
|
||||
relayMjRouter := router.Group("/mj")
|
||||
|
||||
31
service/audio.go
Normal file
31
service/audio.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
|
||||
audioData, err := base64.StdEncoding.DecodeString(audioBase64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("base64 decode error: %v", err)
|
||||
}
|
||||
|
||||
var samplesCount int
|
||||
var sampleRate int
|
||||
|
||||
switch format {
|
||||
case "pcm16":
|
||||
samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
|
||||
sampleRate = 24000 // 24kHz
|
||||
case "g711_ulaw", "g711_alaw":
|
||||
samplesCount = len(audioData) // 8位 = 1字节每样本
|
||||
sampleRate = 8000 // 8kHz
|
||||
default:
|
||||
samplesCount = len(audioData) // 8位 = 1字节每样本
|
||||
sampleRate = 8000 // 8kHz
|
||||
}
|
||||
|
||||
duration = float64(samplesCount) / float64(sampleRate)
|
||||
return duration, nil
|
||||
}
|
||||
@@ -73,6 +73,15 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
|
||||
} else if strings.HasPrefix(err.Error.Message, "Permission denied") {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
|
||||
return true
|
||||
} else if strings.Contains(err.Error.Message, "Operation not allowed") {
|
||||
return true
|
||||
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
@@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["admin_info"] = adminInfo
|
||||
return other
|
||||
}
|
||||
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
info["ws"] = true
|
||||
info["audio_input"] = usage.InputTokenDetails.AudioTokens
|
||||
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
|
||||
info["text_input"] = usage.InputTokenDetails.TextTokens
|
||||
info["text_output"] = usage.OutputTokenDetails.TextTokens
|
||||
return info
|
||||
}
|
||||
|
||||
140
service/quota.go
Normal file
140
service/quota.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
|
||||
if relayInfo.UsePrice {
|
||||
return nil
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelName := relayInfo.UpstreamModelName
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
modelRatio := common.GetModelRatio(modelName)
|
||||
|
||||
ratio := groupRatio * modelRatio
|
||||
|
||||
quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
|
||||
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
|
||||
|
||||
quota = int(math.Round(float64(quota) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
|
||||
if userQuota < quota {
|
||||
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
}
|
||||
|
||||
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
|
||||
groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
|
||||
quota := 0
|
||||
if !usePrice {
|
||||
quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio))
|
||||
quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
totalTokens := usage.TotalTokens
|
||||
var logContent string
|
||||
if !usePrice {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
} else {
|
||||
//if sensitiveResp != nil {
|
||||
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
//}
|
||||
//quotaDelta := quota - preConsumedQuota
|
||||
//if quotaDelta != 0 {
|
||||
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
// if err != nil {
|
||||
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
// }
|
||||
//}
|
||||
|
||||
//err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
//if err != nil {
|
||||
// common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
//}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -42,11 +43,47 @@ func Done(c *gin.Context) {
|
||||
_ = StringData(c, "[DONE]")
|
||||
}
|
||||
|
||||
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
|
||||
if ws == nil {
|
||||
common.LogError(c, "websocket connection is nil")
|
||||
return errors.New("websocket connection is nil")
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
|
||||
return ws.WriteMessage(1, []byte(str))
|
||||
}
|
||||
|
||||
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||
jsonData, err := json.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
if ws == nil {
|
||||
common.LogError(c, "websocket connection is nil")
|
||||
return errors.New("websocket connection is nil")
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
|
||||
return ws.WriteMessage(1, jsonData)
|
||||
}
|
||||
|
||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
|
||||
errorObj := &dto.RealtimeEvent{
|
||||
Type: "error",
|
||||
EventId: GetLocalRealtimeID(c),
|
||||
Error: &openaiError,
|
||||
}
|
||||
_ = WssObject(c, ws, errorObj)
|
||||
}
|
||||
|
||||
func GetResponseID(c *gin.Context) string {
|
||||
logID := c.GetString("X-Oneapi-Request-Id")
|
||||
logID := c.GetString(common.RequestIdKey)
|
||||
return fmt.Sprintf("chatcmpl-%s", logID)
|
||||
}
|
||||
|
||||
func GetLocalRealtimeID(c *gin.Context) string {
|
||||
logID := c.GetString(common.RequestIdKey)
|
||||
return fmt.Sprintf("evt_%s", logID)
|
||||
}
|
||||
|
||||
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@@ -191,6 +192,72 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||||
audioToken := 0
|
||||
textToken := 0
|
||||
switch request.Type {
|
||||
case dto.RealtimeEventTypeSessionUpdate:
|
||||
if request.Session != nil {
|
||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
textToken += msgTokens
|
||||
}
|
||||
case dto.RealtimeEventResponseAudioDelta:
|
||||
// count audio token
|
||||
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||||
}
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||
// count text token
|
||||
tkm, err := CountTextToken(request.Delta, model)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
||||
}
|
||||
textToken += tkm
|
||||
case dto.RealtimeEventInputAudioBufferAppend:
|
||||
// count audio token
|
||||
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||||
}
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventConversationItemCreated:
|
||||
if request.Item != nil {
|
||||
switch request.Item.Type {
|
||||
case "message":
|
||||
for _, content := range request.Item.Content {
|
||||
if content.Type == "input_text" {
|
||||
tokens, err := CountTextToken(content.Text, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
textToken += tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case dto.RealtimeEventTypeResponseDone:
|
||||
// count tools token
|
||||
if !info.IsFirstRequest {
|
||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||
for _, tool := range info.RealtimeTools {
|
||||
toolTokens, err := CountTokenInput(tool, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
textToken += 8
|
||||
textToken += toolTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
@@ -223,7 +290,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
} else {
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == "image_url" {
|
||||
if m.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
||||
if err != nil {
|
||||
@@ -231,6 +298,9 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
}
|
||||
tokenNum += imageTokenNum
|
||||
log.Printf("image token num: %d", imageTokenNum)
|
||||
} else if m.Type == dto.ContentTypeInputAudio {
|
||||
// TODO: 音频token数量计算
|
||||
tokenNum += 100
|
||||
} else {
|
||||
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
}
|
||||
@@ -245,13 +315,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
func CountTokenInput(input any, model string) (int, error) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTokenText(v, model)
|
||||
return CountTextToken(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return CountTokenText(text, model)
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||
}
|
||||
@@ -273,16 +343,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountAudioToken(text string, model string) (int, error) {
|
||||
func CountTTSToken(text string, model string) (int, error) {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text), nil
|
||||
} else {
|
||||
return CountTokenText(text, model)
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
}
|
||||
|
||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTokenText(text string, model string) (int, error) {
|
||||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
}
|
||||
duration, err := parseAudio(audioBase64, audioFormat)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(duration / 60 * 100 / 0.06), nil
|
||||
}
|
||||
|
||||
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
}
|
||||
duration, err := parseAudio(audioBase64, audioFormat)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(duration / 60 * 200 / 0.24), nil
|
||||
}
|
||||
|
||||
//func CountAudioToken(sec float64, audioType string) {
|
||||
// if audioType == "input" {
|
||||
//
|
||||
// }
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTextToken(text string, model string) (int, error) {
|
||||
var err error
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm, err := CountTokenText(responseText, modeName)
|
||||
ctkm, err := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"@douyinfe/semi-icons": "^2.46.1",
|
||||
"@douyinfe/semi-ui": "^2.55.3",
|
||||
"@douyinfe/semi-icons": "^2.63.1",
|
||||
"@douyinfe/semi-ui": "^2.63.1",
|
||||
"@visactor/react-vchart": "~1.8.8",
|
||||
"@visactor/vchart": "~1.8.8",
|
||||
"@visactor/vchart-semi-theme": "~1.8.8",
|
||||
@@ -22,7 +22,8 @@
|
||||
"react-toastify": "^9.0.8",
|
||||
"react-turnstile": "^1.0.5",
|
||||
"semantic-ui-offline": "^2.5.0",
|
||||
"semantic-ui-react": "^2.1.3"
|
||||
"semantic-ui-react": "^2.1.3",
|
||||
"sse": "github:mpetazzoni/sse.js"
|
||||
},
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
|
||||
1316
web/pnpm-lock.yaml
generated
1316
web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,7 @@ import { Layout } from '@douyinfe/semi-ui';
|
||||
import Midjourney from './pages/Midjourney';
|
||||
import Pricing from './pages/Pricing/index.js';
|
||||
import Task from "./pages/Task/index.js";
|
||||
import Playground from './components/Playground.js';
|
||||
|
||||
const Home = lazy(() => import('./pages/Home'));
|
||||
const Detail = lazy(() => import('./pages/Detail'));
|
||||
@@ -100,6 +101,14 @@ function App() {
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/playground'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Playground />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/redemption'
|
||||
element={
|
||||
@@ -247,7 +256,7 @@ function App() {
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/chat'
|
||||
path='/chat/:id?'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Chat />
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { API, showError, showSuccess } from '../helpers';
|
||||
import { API, showError, showSuccess, updateAPI } from '../helpers';
|
||||
import { UserContext } from '../context/User';
|
||||
import { setUserData } from '../helpers/data.js';
|
||||
|
||||
const GitHubOAuth = () => {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
@@ -23,8 +24,10 @@ const GitHubOAuth = () => {
|
||||
} else {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
setUserData(data);
|
||||
updateAPI()
|
||||
showSuccess('登录成功!');
|
||||
navigate('/');
|
||||
navigate('/token');
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
|
||||
@@ -39,10 +39,10 @@ let buttons = [
|
||||
// icon: <IconHomeStroked />,
|
||||
},
|
||||
// {
|
||||
// text: '模型价格',
|
||||
// itemKey: 'pricing',
|
||||
// to: '/pricing',
|
||||
// icon: <IconNoteMoneyStroked />,
|
||||
// text: 'Playground',
|
||||
// itemKey: 'playground',
|
||||
// to: '/playground',
|
||||
// // icon: <IconNoteMoneyStroked />,
|
||||
// },
|
||||
];
|
||||
|
||||
|
||||
@@ -71,6 +71,8 @@ const LoginForm = () => {
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
setUserData(data);
|
||||
updateAPI()
|
||||
navigate('/');
|
||||
showSuccess('登录成功!');
|
||||
setShowWeChatLoginModal(false);
|
||||
@@ -143,6 +145,8 @@ const LoginForm = () => {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
setUserData(data);
|
||||
updateAPI()
|
||||
navigate('/');
|
||||
} else {
|
||||
showError(message);
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
|
||||
import {
|
||||
Avatar,
|
||||
Button,
|
||||
Button, Descriptions,
|
||||
Form,
|
||||
Layout,
|
||||
Modal,
|
||||
@@ -20,7 +20,7 @@ import {
|
||||
Spin,
|
||||
Table,
|
||||
Tag,
|
||||
Tooltip,
|
||||
Tooltip
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { ITEMS_PER_PAGE } from '../constants';
|
||||
import {
|
||||
@@ -250,7 +250,7 @@ const LogsTable = () => {
|
||||
title: '类型',
|
||||
dataIndex: 'type',
|
||||
render: (text, record, index) => {
|
||||
return <div>{renderType(text)}</div>;
|
||||
return <>{renderType(text)}</>;
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -258,7 +258,7 @@ const LogsTable = () => {
|
||||
dataIndex: 'model_name',
|
||||
render: (text, record, index) => {
|
||||
return record.type === 0 || record.type === 2 ? (
|
||||
<div>
|
||||
<>
|
||||
<Tag
|
||||
color={stringToColor(text)}
|
||||
size='large'
|
||||
@@ -269,7 +269,7 @@ const LogsTable = () => {
|
||||
{' '}
|
||||
{text}{' '}
|
||||
</Tag>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
);
|
||||
@@ -282,22 +282,22 @@ const LogsTable = () => {
|
||||
if (record.is_stream) {
|
||||
let other = getLogOther(record.other);
|
||||
return (
|
||||
<div>
|
||||
<>
|
||||
<Space>
|
||||
{renderUseTime(text)}
|
||||
{renderFirstUseTime(other.frt)}
|
||||
{renderIsStream(record.is_stream)}
|
||||
</Space>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<div>
|
||||
<>
|
||||
<Space>
|
||||
{renderUseTime(text)}
|
||||
{renderIsStream(record.is_stream)}
|
||||
</Space>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
},
|
||||
@@ -307,7 +307,7 @@ const LogsTable = () => {
|
||||
dataIndex: 'prompt_tokens',
|
||||
render: (text, record, index) => {
|
||||
return record.type === 0 || record.type === 2 ? (
|
||||
<div>{<span> {text} </span>}</div>
|
||||
<>{<span> {text} </span>}</>
|
||||
) : (
|
||||
<></>
|
||||
);
|
||||
@@ -319,7 +319,7 @@ const LogsTable = () => {
|
||||
render: (text, record, index) => {
|
||||
return parseInt(text) > 0 &&
|
||||
(record.type === 0 || record.type === 2) ? (
|
||||
<div>{<span> {text} </span>}</div>
|
||||
<>{<span> {text} </span>}</>
|
||||
) : (
|
||||
<></>
|
||||
);
|
||||
@@ -330,39 +330,39 @@ const LogsTable = () => {
|
||||
dataIndex: 'quota',
|
||||
render: (text, record, index) => {
|
||||
return record.type === 0 || record.type === 2 ? (
|
||||
<div>{renderQuota(text, 6)}</div>
|
||||
<>{renderQuota(text, 6)}</>
|
||||
) : (
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '重试',
|
||||
dataIndex: 'retry',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
let content = '渠道:' + record.channel;
|
||||
if (record.other !== '') {
|
||||
let other = JSON.parse(record.other);
|
||||
if (other === null) {
|
||||
return <></>;
|
||||
}
|
||||
if (other.admin_info !== undefined) {
|
||||
if (
|
||||
other.admin_info.use_channel !== null &&
|
||||
other.admin_info.use_channel !== undefined &&
|
||||
other.admin_info.use_channel !== ''
|
||||
) {
|
||||
// channel id array
|
||||
let useChannel = other.admin_info.use_channel;
|
||||
let useChannelStr = useChannel.join('->');
|
||||
content = `渠道:${useChannelStr}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
return isAdminUser ? <div>{content}</div> : <></>;
|
||||
},
|
||||
},
|
||||
// {
|
||||
// title: '重试',
|
||||
// dataIndex: 'retry',
|
||||
// className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
// render: (text, record, index) => {
|
||||
// let content = '渠道:' + record.channel;
|
||||
// if (record.other !== '') {
|
||||
// let other = JSON.parse(record.other);
|
||||
// if (other === null) {
|
||||
// return <></>;
|
||||
// }
|
||||
// if (other.admin_info !== undefined) {
|
||||
// if (
|
||||
// other.admin_info.use_channel !== null &&
|
||||
// other.admin_info.use_channel !== undefined &&
|
||||
// other.admin_info.use_channel !== ''
|
||||
// ) {
|
||||
// // channel id array
|
||||
// let useChannel = other.admin_info.use_channel;
|
||||
// let useChannelStr = useChannel.join('->');
|
||||
// content = `渠道:${useChannelStr}`;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// return isAdminUser ? <div>{content}</div> : <></>;
|
||||
// },
|
||||
// },
|
||||
{
|
||||
title: '详情',
|
||||
dataIndex: 'content',
|
||||
@@ -409,6 +409,7 @@ const LogsTable = () => {
|
||||
];
|
||||
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [expandData, setExpandData] = useState({});
|
||||
const [showStat, setShowStat] = useState(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [loadingStat, setLoadingStat] = useState(false);
|
||||
@@ -512,10 +513,54 @@ const LogsTable = () => {
|
||||
};
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
let expandDatesLocal = {};
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
let other = getLogOther(logs[i].other);
|
||||
let expandDataLocal = [];
|
||||
if (isAdmin()) {
|
||||
let content = '渠道:' + logs[i].channel;
|
||||
if (other.admin_info !== undefined) {
|
||||
if (
|
||||
other.admin_info.use_channel !== null &&
|
||||
other.admin_info.use_channel !== undefined &&
|
||||
other.admin_info.use_channel !== ''
|
||||
) {
|
||||
// channel id array
|
||||
let useChannel = other.admin_info.use_channel;
|
||||
let useChannelStr = useChannel.join('->');
|
||||
content = `渠道:${useChannelStr}`;
|
||||
}
|
||||
}
|
||||
expandDataLocal.push({
|
||||
key: '重试',
|
||||
value: content,
|
||||
})
|
||||
}
|
||||
if (other.ws) {
|
||||
expandDataLocal.push({
|
||||
key: '语音输入',
|
||||
value: other.audio_input,
|
||||
});
|
||||
expandDataLocal.push({
|
||||
key: '语音输出',
|
||||
value: other.audio_output,
|
||||
});
|
||||
expandDataLocal.push({
|
||||
key: '文字输入',
|
||||
value: other.text_input,
|
||||
});
|
||||
expandDataLocal.push({
|
||||
key: '文字输出',
|
||||
value: other.text_output,
|
||||
});
|
||||
}
|
||||
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||
}
|
||||
console.log(expandDatesLocal);
|
||||
setExpandData(expandDatesLocal);
|
||||
|
||||
setLogs(logs);
|
||||
};
|
||||
|
||||
@@ -588,6 +633,10 @@ const LogsTable = () => {
|
||||
handleEyeClick();
|
||||
}, []);
|
||||
|
||||
const expandRowRender = (record, index) => {
|
||||
return <Descriptions align="justify" data={expandData[record.key]} />;
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Layout>
|
||||
@@ -686,7 +735,9 @@ const LogsTable = () => {
|
||||
<Table
|
||||
style={{ marginTop: 5 }}
|
||||
columns={columns}
|
||||
expandedRowRender={expandRowRender}
|
||||
dataSource={logs}
|
||||
rowKey="key"
|
||||
pagination={{
|
||||
currentPage: activePage,
|
||||
pageSize: pageSize,
|
||||
|
||||
@@ -10,6 +10,7 @@ import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.
|
||||
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
|
||||
|
||||
import { API, showError, showSuccess } from '../helpers';
|
||||
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
|
||||
|
||||
const OperationSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
@@ -50,6 +51,7 @@ const OperationSetting = () => {
|
||||
DataExportInterval: 5,
|
||||
DefaultCollapseSidebar: false, // 默认折叠侧边栏
|
||||
RetryTimes: 0,
|
||||
Chats: "[]",
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
@@ -131,6 +133,10 @@ const OperationSetting = () => {
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsCreditLimit options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 聊天设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsChats options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 倍率设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsMagnification options={inputs} refresh={onRefresh} />
|
||||
|
||||
342
web/src/components/Playground.js
Normal file
342
web/src/components/Playground.js
Normal file
@@ -0,0 +1,342 @@
|
||||
import React, { useCallback, useContext, useEffect, useState } from 'react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { API, getUserIdFromLocalStorage, showError } from '../helpers';
|
||||
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography } from '@douyinfe/semi-ui';
|
||||
import { SSE } from 'sse';
|
||||
|
||||
const defaultMessage = [
|
||||
{
|
||||
role: 'user',
|
||||
id: '2',
|
||||
createAt: 1715676751919,
|
||||
content: "你好",
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
id: '3',
|
||||
createAt: 1715676751919,
|
||||
content: "你好,请问有什么可以帮助您的吗?",
|
||||
}
|
||||
];
|
||||
|
||||
let id = 4;
|
||||
function getId() {
|
||||
return `${id++}`
|
||||
}
|
||||
|
||||
const Playground = () => {
|
||||
const [inputs, setInputs] = useState({
|
||||
model: 'gpt-4o-mini',
|
||||
group: '',
|
||||
max_tokens: 0,
|
||||
temperature: 0,
|
||||
});
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [status, setStatus] = useState({});
|
||||
const [systemPrompt, setSystemPrompt] = useState('You are a helpful assistant. You can help me by answering my questions. You can also ask me questions.');
|
||||
const [message, setMessage] = useState(defaultMessage);
|
||||
const [models, setModels] = useState([]);
|
||||
const [groups, setGroups] = useState([]);
|
||||
|
||||
const handleInputChange = (name, value) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (searchParams.get('expired')) {
|
||||
showError('未登录或登录已过期,请重新登录!');
|
||||
}
|
||||
let status = localStorage.getItem('status');
|
||||
if (status) {
|
||||
status = JSON.parse(status);
|
||||
setStatus(status);
|
||||
}
|
||||
loadModels();
|
||||
loadGroups();
|
||||
}, []);
|
||||
|
||||
const loadModels = async () => {
|
||||
let res = await API.get(`/api/user/models`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let localModelOptions = data.map((model) => ({
|
||||
label: model,
|
||||
value: model,
|
||||
}));
|
||||
setModels(localModelOptions);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const loadGroups = async () => {
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
// return data is a map, key is group name, value is group description
|
||||
// label is group description, value is group name
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
value: group,
|
||||
}));
|
||||
// handleInputChange('group', localGroupOptions[0].value);
|
||||
|
||||
if (localGroupOptions.length > 0) {
|
||||
} else {
|
||||
localGroupOptions = [{
|
||||
label: '用户分组',
|
||||
value: '',
|
||||
}];
|
||||
setGroups(localGroupOptions);
|
||||
}
|
||||
setGroups(localGroupOptions);
|
||||
handleInputChange('group', localGroupOptions[0].value);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const commonOuterStyle = {
|
||||
border: '1px solid var(--semi-color-border)',
|
||||
borderRadius: '16px',
|
||||
margin: '0px 8px',
|
||||
}
|
||||
|
||||
const getSystemMessage = () => {
|
||||
if (systemPrompt !== '') {
|
||||
return {
|
||||
role: 'system',
|
||||
id: '1',
|
||||
createAt: 1715676751919,
|
||||
content: systemPrompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let handleSSE = (payload) => {
|
||||
let source = new SSE('/pg/chat/completions', {
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"New-Api-User": getUserIdFromLocalStorage(),
|
||||
},
|
||||
method: "POST",
|
||||
payload: JSON.stringify(payload),
|
||||
});
|
||||
source.addEventListener("message", (e) => {
|
||||
if (e.data !== "[DONE]") {
|
||||
let payload = JSON.parse(e.data);
|
||||
// console.log("Payload: ", payload);
|
||||
if (payload.choices.length === 0) {
|
||||
source.close();
|
||||
completeMessage();
|
||||
} else {
|
||||
let text = payload.choices[0].delta.content;
|
||||
if (text) {
|
||||
generateMockResponse(text);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
completeMessage();
|
||||
}
|
||||
});
|
||||
|
||||
source.addEventListener("error", (e) => {
|
||||
generateMockResponse(e.data)
|
||||
completeMessage('error')
|
||||
});
|
||||
|
||||
source.addEventListener("readystatechange", (e) => {
|
||||
if (e.readyState >= 2) {
|
||||
if (source.status === undefined) {
|
||||
source.close();
|
||||
completeMessage();
|
||||
}
|
||||
}
|
||||
});
|
||||
source.stream();
|
||||
}
|
||||
|
||||
const onMessageSend = useCallback((content, attachment) => {
|
||||
console.log("attachment: ", attachment);
|
||||
setMessage((prevMessage) => {
|
||||
const newMessage = [
|
||||
...prevMessage,
|
||||
{
|
||||
role: 'user',
|
||||
content: content,
|
||||
createAt: Date.now(),
|
||||
id: getId()
|
||||
}
|
||||
];
|
||||
|
||||
// 将 getPayload 移到这里
|
||||
const getPayload = () => {
|
||||
let systemMessage = getSystemMessage();
|
||||
let messages = newMessage.map((item) => {
|
||||
return {
|
||||
role: item.role,
|
||||
content: item.content,
|
||||
}
|
||||
});
|
||||
if (systemMessage) {
|
||||
messages.unshift(systemMessage);
|
||||
}
|
||||
return {
|
||||
messages: messages,
|
||||
stream: true,
|
||||
model: inputs.model,
|
||||
group: inputs.group,
|
||||
max_tokens: parseInt(inputs.max_tokens),
|
||||
temperature: inputs.temperature,
|
||||
};
|
||||
};
|
||||
|
||||
// 使用更新后的消息状态调用 handleSSE
|
||||
handleSSE(getPayload());
|
||||
newMessage.push({
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
createAt: Date.now(),
|
||||
id: getId(),
|
||||
status: 'loading'
|
||||
});
|
||||
return newMessage;
|
||||
});
|
||||
}, [getSystemMessage]);
|
||||
|
||||
const completeMessage = useCallback((status = 'complete') => {
|
||||
// console.log("Complete Message: ", status)
|
||||
setMessage((prevMessage) => {
|
||||
const lastMessage = prevMessage[prevMessage.length - 1];
|
||||
// only change the status if the last message is not complete and not error
|
||||
if (lastMessage.status === 'complete' || lastMessage.status === 'error') {
|
||||
return prevMessage;
|
||||
}
|
||||
return [
|
||||
...prevMessage.slice(0, -1),
|
||||
{ ...lastMessage, status: status }
|
||||
];
|
||||
});
|
||||
}, [])
|
||||
|
||||
const generateMockResponse = useCallback((content) => {
|
||||
// console.log("Generate Mock Response: ", content);
|
||||
setMessage((message) => {
|
||||
const lastMessage = message[message.length - 1];
|
||||
let newMessage = {...lastMessage};
|
||||
if (lastMessage.status === 'loading' || lastMessage.status === 'incomplete') {
|
||||
newMessage = {
|
||||
...newMessage,
|
||||
content: (lastMessage.content || '') + content,
|
||||
status: 'incomplete'
|
||||
}
|
||||
}
|
||||
return [ ...message.slice(0, -1), newMessage ]
|
||||
})
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Layout style={{height: '100%'}}>
|
||||
<Layout.Sider>
|
||||
<Card style={commonOuterStyle}>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>分组:</Typography.Text>
|
||||
</div>
|
||||
<Select
|
||||
placeholder={'请选择分组'}
|
||||
name='group'
|
||||
required
|
||||
selection
|
||||
onChange={(value) => {
|
||||
handleInputChange('group', value);
|
||||
}}
|
||||
value={inputs.group}
|
||||
autoComplete='new-password'
|
||||
optionList={groups}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>模型:</Typography.Text>
|
||||
</div>
|
||||
<Select
|
||||
placeholder={'请选择模型'}
|
||||
name='model'
|
||||
required
|
||||
selection
|
||||
filter
|
||||
onChange={(value) => {
|
||||
handleInputChange('model', value);
|
||||
}}
|
||||
value={inputs.model}
|
||||
autoComplete='new-password'
|
||||
optionList={models}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>Temperature:</Typography.Text>
|
||||
</div>
|
||||
<Slider
|
||||
step={0.1}
|
||||
min={0.1}
|
||||
max={1}
|
||||
value={inputs.temperature}
|
||||
onChange={(value) => {
|
||||
handleInputChange('temperature', value);
|
||||
}}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>MaxTokens:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
placeholder='MaxTokens'
|
||||
name='max_tokens'
|
||||
required
|
||||
autoComplete='new-password'
|
||||
defaultValue={0}
|
||||
value={inputs.max_tokens}
|
||||
onChange={(value) => {
|
||||
handleInputChange('max_tokens', value);
|
||||
}}
|
||||
/>
|
||||
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>System:</Typography.Text>
|
||||
</div>
|
||||
<TextArea
|
||||
placeholder='System Prompt'
|
||||
name='system'
|
||||
required
|
||||
autoComplete='new-password'
|
||||
autosize
|
||||
defaultValue={systemPrompt}
|
||||
// value={systemPrompt}
|
||||
onChange={(value) => {
|
||||
setSystemPrompt(value);
|
||||
}}
|
||||
/>
|
||||
|
||||
</Card>
|
||||
</Layout.Sider>
|
||||
<Layout.Content>
|
||||
<div style={{height: '100%'}}>
|
||||
<Chat
|
||||
chatBoxRenderConfig={{
|
||||
renderChatBoxAction: () => {
|
||||
return <div></div>
|
||||
}
|
||||
}}
|
||||
style={commonOuterStyle}
|
||||
chats={message}
|
||||
onMessageSend={onMessageSend}
|
||||
showClearContext
|
||||
onClear={() => {
|
||||
setMessage([]);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
);
|
||||
};
|
||||
|
||||
export default Playground;
|
||||
790
web/src/components/SafetySetting.js
Normal file
790
web/src/components/SafetySetting.js
Normal file
@@ -0,0 +1,790 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
Button,
|
||||
Divider,
|
||||
Form,
|
||||
Grid,
|
||||
Header,
|
||||
Message,
|
||||
Modal,
|
||||
} from 'semantic-ui-react';
|
||||
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
|
||||
|
||||
import { useTheme } from '../context/Theme';
|
||||
|
||||
const SafetySetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
PasswordLoginEnabled: '',
|
||||
PasswordRegisterEnabled: '',
|
||||
EmailVerificationEnabled: '',
|
||||
GitHubOAuthEnabled: '',
|
||||
GitHubClientId: '',
|
||||
GitHubClientSecret: '',
|
||||
Notice: '',
|
||||
SMTPServer: '',
|
||||
SMTPPort: '',
|
||||
SMTPAccount: '',
|
||||
SMTPFrom: '',
|
||||
SMTPToken: '',
|
||||
ServerAddress: '',
|
||||
WorkerUrl: '',
|
||||
WorkerValidKey: '',
|
||||
EpayId: '',
|
||||
EpayKey: '',
|
||||
Price: 7.3,
|
||||
MinTopUp: 1,
|
||||
TopupGroupRatio: '',
|
||||
PayAddress: '',
|
||||
CustomCallbackAddress: '',
|
||||
Footer: '',
|
||||
WeChatAuthEnabled: '',
|
||||
WeChatServerAddress: '',
|
||||
WeChatServerToken: '',
|
||||
WeChatAccountQRCodeImageURL: '',
|
||||
TurnstileCheckEnabled: '',
|
||||
TurnstileSiteKey: '',
|
||||
TurnstileSecretKey: '',
|
||||
RegisterEnabled: '',
|
||||
EmailDomainRestrictionEnabled: '',
|
||||
EmailAliasRestrictionEnabled: '',
|
||||
SMTPSSLEnabled: '',
|
||||
EmailDomainWhitelist: [],
|
||||
// telegram login
|
||||
TelegramOAuthEnabled: '',
|
||||
TelegramBotToken: '',
|
||||
TelegramBotName: '',
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]);
|
||||
const [restrictedDomainInput, setRestrictedDomainInput] = useState('');
|
||||
const [showPasswordWarningModal, setShowPasswordWarningModal] =
|
||||
useState(false);
|
||||
|
||||
const theme = useTheme();
|
||||
const isDark = theme === 'dark';
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key === 'TopupGroupRatio') {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
newInputs[item.key] = item.value;
|
||||
});
|
||||
setInputs({
|
||||
...newInputs,
|
||||
EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(','),
|
||||
});
|
||||
setOriginInputs(newInputs);
|
||||
|
||||
setEmailDomainWhitelist(
|
||||
newInputs.EmailDomainWhitelist.split(',').map((item) => {
|
||||
return { key: item, text: item, value: item };
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions().then();
|
||||
}, []);
|
||||
useEffect(() => {}, [inputs.EmailDomainWhitelist]);
|
||||
|
||||
const updateOption = async (key, value) => {
|
||||
setLoading(true);
|
||||
switch (key) {
|
||||
case 'PasswordLoginEnabled':
|
||||
case 'PasswordRegisterEnabled':
|
||||
case 'EmailVerificationEnabled':
|
||||
case 'GitHubOAuthEnabled':
|
||||
case 'WeChatAuthEnabled':
|
||||
case 'TelegramOAuthEnabled':
|
||||
case 'TurnstileCheckEnabled':
|
||||
case 'EmailDomainRestrictionEnabled':
|
||||
case 'EmailAliasRestrictionEnabled':
|
||||
case 'SMTPSSLEnabled':
|
||||
case 'RegisterEnabled':
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
const res = await API.put('/api/option/', {
|
||||
key,
|
||||
value,
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
if (key === 'EmailDomainWhitelist') {
|
||||
value = value.split(',');
|
||||
}
|
||||
if (key === 'Price') {
|
||||
value = parseFloat(value);
|
||||
}
|
||||
setInputs((inputs) => ({
|
||||
...inputs,
|
||||
[key]: value,
|
||||
}));
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const handleInputChange = async (e, { name, value }) => {
|
||||
if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') {
|
||||
// block disabling password login
|
||||
setShowPasswordWarningModal(true);
|
||||
return;
|
||||
}
|
||||
if (
|
||||
name === 'Notice' ||
|
||||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
|
||||
name === 'ServerAddress' ||
|
||||
name === 'WorkerUrl' ||
|
||||
name === 'WorkerValidKey' ||
|
||||
name === 'EpayId' ||
|
||||
name === 'EpayKey' ||
|
||||
name === 'Price' ||
|
||||
name === 'PayAddress' ||
|
||||
name === 'GitHubClientId' ||
|
||||
name === 'GitHubClientSecret' ||
|
||||
name === 'WeChatServerAddress' ||
|
||||
name === 'WeChatServerToken' ||
|
||||
name === 'WeChatAccountQRCodeImageURL' ||
|
||||
name === 'TurnstileSiteKey' ||
|
||||
name === 'TurnstileSecretKey' ||
|
||||
name === 'EmailDomainWhitelist' ||
|
||||
name === 'TopupGroupRatio' ||
|
||||
name === 'TelegramBotToken' ||
|
||||
name === 'TelegramBotName'
|
||||
) {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
} else {
|
||||
await updateOption(name, value);
|
||||
}
|
||||
};
|
||||
|
||||
const submitServerAddress = async () => {
|
||||
let ServerAddress = removeTrailingSlash(inputs.ServerAddress);
|
||||
await updateOption('ServerAddress', ServerAddress);
|
||||
};
|
||||
|
||||
const submitWorker = async () => {
|
||||
let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
|
||||
await updateOption('WorkerUrl', WorkerUrl);
|
||||
if (inputs.WorkerValidKey !== '') {
|
||||
await updateOption('WorkerValidKey', inputs.WorkerValidKey);
|
||||
}
|
||||
}
|
||||
|
||||
const submitPayAddress = async () => {
|
||||
if (inputs.ServerAddress === '') {
|
||||
showError('请先填写服务器地址');
|
||||
return;
|
||||
}
|
||||
if (originInputs['TopupGroupRatio'] !== inputs.TopupGroupRatio) {
|
||||
if (!verifyJSON(inputs.TopupGroupRatio)) {
|
||||
showError('充值分组倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('TopupGroupRatio', inputs.TopupGroupRatio);
|
||||
}
|
||||
let PayAddress = removeTrailingSlash(inputs.PayAddress);
|
||||
await updateOption('PayAddress', PayAddress);
|
||||
if (inputs.EpayId !== '') {
|
||||
await updateOption('EpayId', inputs.EpayId);
|
||||
}
|
||||
if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') {
|
||||
await updateOption('EpayKey', inputs.EpayKey);
|
||||
}
|
||||
await updateOption('Price', '' + inputs.Price);
|
||||
};
|
||||
|
||||
const submitSMTP = async () => {
|
||||
if (originInputs['SMTPServer'] !== inputs.SMTPServer) {
|
||||
await updateOption('SMTPServer', inputs.SMTPServer);
|
||||
}
|
||||
if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) {
|
||||
await updateOption('SMTPAccount', inputs.SMTPAccount);
|
||||
}
|
||||
if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) {
|
||||
await updateOption('SMTPFrom', inputs.SMTPFrom);
|
||||
}
|
||||
if (
|
||||
originInputs['SMTPPort'] !== inputs.SMTPPort &&
|
||||
inputs.SMTPPort !== ''
|
||||
) {
|
||||
await updateOption('SMTPPort', inputs.SMTPPort);
|
||||
}
|
||||
if (
|
||||
originInputs['SMTPToken'] !== inputs.SMTPToken &&
|
||||
inputs.SMTPToken !== ''
|
||||
) {
|
||||
await updateOption('SMTPToken', inputs.SMTPToken);
|
||||
}
|
||||
};
|
||||
|
||||
const submitEmailDomainWhitelist = async () => {
|
||||
if (
|
||||
originInputs['EmailDomainWhitelist'] !==
|
||||
inputs.EmailDomainWhitelist.join(',') &&
|
||||
inputs.SMTPToken !== ''
|
||||
) {
|
||||
await updateOption(
|
||||
'EmailDomainWhitelist',
|
||||
inputs.EmailDomainWhitelist.join(','),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const submitWeChat = async () => {
|
||||
if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) {
|
||||
await updateOption(
|
||||
'WeChatServerAddress',
|
||||
removeTrailingSlash(inputs.WeChatServerAddress),
|
||||
);
|
||||
}
|
||||
if (
|
||||
originInputs['WeChatAccountQRCodeImageURL'] !==
|
||||
inputs.WeChatAccountQRCodeImageURL
|
||||
) {
|
||||
await updateOption(
|
||||
'WeChatAccountQRCodeImageURL',
|
||||
inputs.WeChatAccountQRCodeImageURL,
|
||||
);
|
||||
}
|
||||
if (
|
||||
originInputs['WeChatServerToken'] !== inputs.WeChatServerToken &&
|
||||
inputs.WeChatServerToken !== ''
|
||||
) {
|
||||
await updateOption('WeChatServerToken', inputs.WeChatServerToken);
|
||||
}
|
||||
};
|
||||
|
||||
const submitGitHubOAuth = async () => {
|
||||
if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) {
|
||||
await updateOption('GitHubClientId', inputs.GitHubClientId);
|
||||
}
|
||||
if (
|
||||
originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret &&
|
||||
inputs.GitHubClientSecret !== ''
|
||||
) {
|
||||
await updateOption('GitHubClientSecret', inputs.GitHubClientSecret);
|
||||
}
|
||||
};
|
||||
|
||||
const submitTelegramSettings = async () => {
|
||||
// await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled);
|
||||
await updateOption('TelegramBotToken', inputs.TelegramBotToken);
|
||||
await updateOption('TelegramBotName', inputs.TelegramBotName);
|
||||
};
|
||||
|
||||
const submitTurnstile = async () => {
|
||||
if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) {
|
||||
await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey);
|
||||
}
|
||||
if (
|
||||
originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey &&
|
||||
inputs.TurnstileSecretKey !== ''
|
||||
) {
|
||||
await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey);
|
||||
}
|
||||
};
|
||||
|
||||
const submitNewRestrictedDomain = () => {
|
||||
const localDomainList = inputs.EmailDomainWhitelist;
|
||||
if (
|
||||
restrictedDomainInput !== '' &&
|
||||
!localDomainList.includes(restrictedDomainInput)
|
||||
) {
|
||||
setRestrictedDomainInput('');
|
||||
setInputs({
|
||||
...inputs,
|
||||
EmailDomainWhitelist: [...localDomainList, restrictedDomainInput],
|
||||
});
|
||||
setEmailDomainWhitelist([
|
||||
...EmailDomainWhitelist,
|
||||
{
|
||||
key: restrictedDomainInput,
|
||||
text: restrictedDomainInput,
|
||||
value: restrictedDomainInput,
|
||||
},
|
||||
]);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading} inverted={isDark}>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
通用设置
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='服务器地址'
|
||||
placeholder='例如:https://yourdomain.com'
|
||||
value={inputs.ServerAddress}
|
||||
name='ServerAddress'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitServerAddress}>
|
||||
更新服务器地址
|
||||
</Form.Button>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
代理设置(支持 <a href='https://github.com/Calcium-Ion/new-api-worker' target='_blank' rel='noreferrer'>new-api-worker</a>)
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='Worker地址,不填写则不启用代理'
|
||||
placeholder='例如:https://workername.yourdomain.workers.dev'
|
||||
value={inputs.WorkerUrl}
|
||||
name='WorkerUrl'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='Worker密钥,根据你部署的 Worker 填写'
|
||||
placeholder='例如:your_secret_key'
|
||||
value={inputs.WorkerValidKey}
|
||||
name='WorkerValidKey'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitWorker}>
|
||||
更新Worker设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
支付设置(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='支付地址,不填写则不启用在线支付'
|
||||
placeholder='例如:https://yourdomain.com'
|
||||
value={inputs.PayAddress}
|
||||
name='PayAddress'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='易支付商户ID'
|
||||
placeholder='例如:0001'
|
||||
value={inputs.EpayId}
|
||||
name='EpayId'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='易支付商户密钥'
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
value={inputs.EpayKey}
|
||||
name='EpayKey'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='回调地址,不填写则使用上方服务器地址作为回调地址'
|
||||
placeholder='例如:https://yourdomain.com'
|
||||
value={inputs.CustomCallbackAddress}
|
||||
name='CustomCallbackAddress'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='充值价格(x元/美金)'
|
||||
placeholder='例如:7,就是7元/美金'
|
||||
value={inputs.Price}
|
||||
name='Price'
|
||||
min={0}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='最低充值美元数量(以美金为单位,如果使用额度请自行换算!)'
|
||||
placeholder='例如:2,就是最低充值2$'
|
||||
value={inputs.MinTopUp}
|
||||
name='MinTopUp'
|
||||
min={1}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='充值分组倍率'
|
||||
name='TopupGroupRatio'
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autoComplete='new-password'
|
||||
value={inputs.TopupGroupRatio}
|
||||
placeholder='为一个 JSON 文本,键为组名称,值为倍率'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitPayAddress}>更新支付设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置登录注册
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.PasswordLoginEnabled === 'true'}
|
||||
label='允许通过密码进行登录'
|
||||
name='PasswordLoginEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
{showPasswordWarningModal && (
|
||||
<Modal
|
||||
open={showPasswordWarningModal}
|
||||
onClose={() => setShowPasswordWarningModal(false)}
|
||||
size={'tiny'}
|
||||
style={{ maxWidth: '450px' }}
|
||||
>
|
||||
<Modal.Header>警告</Modal.Header>
|
||||
<Modal.Content>
|
||||
<p>
|
||||
取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?
|
||||
</p>
|
||||
</Modal.Content>
|
||||
<Modal.Actions>
|
||||
<Button onClick={() => setShowPasswordWarningModal(false)}>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
color='yellow'
|
||||
onClick={async () => {
|
||||
setShowPasswordWarningModal(false);
|
||||
await updateOption('PasswordLoginEnabled', 'false');
|
||||
}}
|
||||
>
|
||||
确定
|
||||
</Button>
|
||||
</Modal.Actions>
|
||||
</Modal>
|
||||
)}
|
||||
<Form.Checkbox
|
||||
checked={inputs.PasswordRegisterEnabled === 'true'}
|
||||
label='允许通过密码进行注册'
|
||||
name='PasswordRegisterEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.EmailVerificationEnabled === 'true'}
|
||||
label='通过密码注册时需要进行邮箱验证'
|
||||
name='EmailVerificationEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.GitHubOAuthEnabled === 'true'}
|
||||
label='允许通过 GitHub 账户登录 & 注册'
|
||||
name='GitHubOAuthEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.WeChatAuthEnabled === 'true'}
|
||||
label='允许通过微信登录 & 注册'
|
||||
name='WeChatAuthEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.TelegramOAuthEnabled === 'true'}
|
||||
label='允许通过 Telegram 进行登录'
|
||||
name='TelegramOAuthEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.RegisterEnabled === 'true'}
|
||||
label='允许新用户注册(此项为否时,新用户将无法以任何方式进行注册)'
|
||||
name='RegisterEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.TurnstileCheckEnabled === 'true'}
|
||||
label='启用 Turnstile 用户校验'
|
||||
name='TurnstileCheckEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置邮箱域名白名单
|
||||
<Header.Subheader>
|
||||
用以防止恶意用户利用临时邮箱批量注册
|
||||
</Header.Subheader>
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Checkbox
|
||||
label='启用邮箱域名白名单'
|
||||
name='EmailDomainRestrictionEnabled'
|
||||
onChange={handleInputChange}
|
||||
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Checkbox
|
||||
label='启用邮箱别名限制(例如:ab.cd@gmail.com)'
|
||||
name='EmailAliasRestrictionEnabled'
|
||||
onChange={handleInputChange}
|
||||
checked={inputs.EmailAliasRestrictionEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={2}>
|
||||
<Form.Dropdown
|
||||
label='允许的邮箱域名'
|
||||
placeholder='允许的邮箱域名'
|
||||
name='EmailDomainWhitelist'
|
||||
required
|
||||
fluid
|
||||
multiple
|
||||
selection
|
||||
onChange={handleInputChange}
|
||||
value={inputs.EmailDomainWhitelist}
|
||||
autoComplete='new-password'
|
||||
options={EmailDomainWhitelist}
|
||||
/>
|
||||
<Form.Input
|
||||
label='添加新的允许的邮箱域名'
|
||||
action={
|
||||
<Button
|
||||
type='button'
|
||||
onClick={() => {
|
||||
submitNewRestrictedDomain();
|
||||
}}
|
||||
>
|
||||
填入
|
||||
</Button>
|
||||
}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
submitNewRestrictedDomain();
|
||||
}
|
||||
}}
|
||||
autoComplete='new-password'
|
||||
placeholder='输入新的允许的邮箱域名'
|
||||
value={restrictedDomainInput}
|
||||
onChange={(e, { value }) => {
|
||||
setRestrictedDomainInput(value);
|
||||
}}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitEmailDomainWhitelist}>
|
||||
保存邮箱域名白名单设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 SMTP
|
||||
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='SMTP 服务器地址'
|
||||
name='SMTPServer'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.SMTPServer}
|
||||
placeholder='例如:smtp.qq.com'
|
||||
/>
|
||||
<Form.Input
|
||||
label='SMTP 端口'
|
||||
name='SMTPPort'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.SMTPPort}
|
||||
placeholder='默认: 587'
|
||||
/>
|
||||
<Form.Input
|
||||
label='SMTP 账户'
|
||||
name='SMTPAccount'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.SMTPAccount}
|
||||
placeholder='通常是邮箱地址'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='SMTP 发送者邮箱'
|
||||
name='SMTPFrom'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.SMTPFrom}
|
||||
placeholder='通常和邮箱地址保持一致'
|
||||
/>
|
||||
<Form.Input
|
||||
label='SMTP 访问凭证'
|
||||
name='SMTPToken'
|
||||
onChange={handleInputChange}
|
||||
type='password'
|
||||
autoComplete='new-password'
|
||||
checked={inputs.RegisterEnabled === 'true'}
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Checkbox
|
||||
label='启用SMTP SSL(465端口强制开启)'
|
||||
name='SMTPSSLEnabled'
|
||||
onChange={handleInputChange}
|
||||
checked={inputs.SMTPSSLEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 GitHub OAuth App
|
||||
<Header.Subheader>
|
||||
用以支持通过 GitHub 进行登录注册,
|
||||
<a
|
||||
href='https://github.com/settings/developers'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
点击此处
|
||||
</a>
|
||||
管理你的 GitHub OAuth App
|
||||
</Header.Subheader>
|
||||
</Header>
|
||||
<Message>
|
||||
Homepage URL 填 <code>{inputs.ServerAddress}</code>
|
||||
,Authorization callback URL 填{' '}
|
||||
<code>{`${inputs.ServerAddress}/oauth/github`}</code>
|
||||
</Message>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='GitHub Client ID'
|
||||
name='GitHubClientId'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.GitHubClientId}
|
||||
placeholder='输入你注册的 GitHub OAuth APP 的 ID'
|
||||
/>
|
||||
<Form.Input
|
||||
label='GitHub Client Secret'
|
||||
name='GitHubClientSecret'
|
||||
onChange={handleInputChange}
|
||||
type='password'
|
||||
autoComplete='new-password'
|
||||
value={inputs.GitHubClientSecret}
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitGitHubOAuth}>
|
||||
保存 GitHub OAuth 设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 WeChat Server
|
||||
<Header.Subheader>
|
||||
用以支持通过微信进行登录注册,
|
||||
<a
|
||||
href='https://github.com/songquanpeng/wechat-server'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
点击此处
|
||||
</a>
|
||||
了解 WeChat Server
|
||||
</Header.Subheader>
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='WeChat Server 服务器地址'
|
||||
name='WeChatServerAddress'
|
||||
placeholder='例如:https://yourdomain.com'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.WeChatServerAddress}
|
||||
/>
|
||||
<Form.Input
|
||||
label='WeChat Server 访问凭证'
|
||||
name='WeChatServerToken'
|
||||
type='password'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.WeChatServerToken}
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
<Form.Input
|
||||
label='微信公众号二维码图片链接'
|
||||
name='WeChatAccountQRCodeImageURL'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.WeChatAccountQRCodeImageURL}
|
||||
placeholder='输入一个图片链接'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitWeChat}>
|
||||
保存 WeChat Server 设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 Telegram 登录
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Input
|
||||
label='Telegram Bot Token'
|
||||
name='TelegramBotToken'
|
||||
onChange={handleInputChange}
|
||||
value={inputs.TelegramBotToken}
|
||||
placeholder='输入你的 Telegram Bot Token'
|
||||
/>
|
||||
<Form.Input
|
||||
label='Telegram Bot 名称'
|
||||
name='TelegramBotName'
|
||||
onChange={handleInputChange}
|
||||
value={inputs.TelegramBotName}
|
||||
placeholder='输入你的 Telegram Bot 名称'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitTelegramSettings}>
|
||||
保存 Telegram 登录设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 Turnstile
|
||||
<Header.Subheader>
|
||||
用以支持用户校验,
|
||||
<a
|
||||
href='https://dash.cloudflare.com/'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
点击此处
|
||||
</a>
|
||||
管理你的 Turnstile Sites,推荐选择 Invisible Widget Type
|
||||
</Header.Subheader>
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='Turnstile Site Key'
|
||||
name='TurnstileSiteKey'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.TurnstileSiteKey}
|
||||
placeholder='输入你注册的 Turnstile Site Key'
|
||||
/>
|
||||
<Form.Input
|
||||
label='Turnstile Secret Key'
|
||||
name='TurnstileSecretKey'
|
||||
onChange={handleInputChange}
|
||||
type='password'
|
||||
autoComplete='new-password'
|
||||
value={inputs.TurnstileSecretKey}
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitTurnstile}>
|
||||
保存 Turnstile 设置
|
||||
</Form.Button>
|
||||
</Form>
|
||||
</Grid.Column>
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
export default SystemSetting;
|
||||
@@ -15,7 +15,7 @@ import '../index.css';
|
||||
|
||||
import {
|
||||
IconCalendarClock, IconChecklistStroked,
|
||||
IconComment,
|
||||
IconComment, IconCommentStroked,
|
||||
IconCreditCard,
|
||||
IconGift, IconHelpCircle,
|
||||
IconHistogram,
|
||||
@@ -40,11 +40,9 @@ const SiderBar = () => {
|
||||
const defaultIsCollapsed =
|
||||
isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true';
|
||||
|
||||
let navigate = useNavigate();
|
||||
const [selectedKeys, setSelectedKeys] = useState(['home']);
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed);
|
||||
const [chatItems, setChatItems] = useState([]);
|
||||
const theme = useTheme();
|
||||
const setTheme = useSetTheme();
|
||||
|
||||
@@ -63,16 +61,17 @@ const SiderBar = () => {
|
||||
detail: '/detail',
|
||||
pricing: '/pricing',
|
||||
task: '/task',
|
||||
playground: '/playground',
|
||||
};
|
||||
|
||||
const headerButtons = useMemo(
|
||||
() => [
|
||||
// {
|
||||
// text: '首页',
|
||||
// itemKey: 'home',
|
||||
// to: '/',
|
||||
// icon: <IconHome />,
|
||||
// },
|
||||
{
|
||||
text: 'Playground',
|
||||
itemKey: 'playground',
|
||||
to: '/playground',
|
||||
icon: <IconCommentStroked />,
|
||||
},
|
||||
{
|
||||
text: '模型价格',
|
||||
itemKey: 'pricing',
|
||||
@@ -89,11 +88,12 @@ const SiderBar = () => {
|
||||
{
|
||||
text: '聊天',
|
||||
itemKey: 'chat',
|
||||
to: '/chat',
|
||||
// to: '/chat',
|
||||
items: chatItems,
|
||||
icon: <IconComment />,
|
||||
className: localStorage.getItem('chat_link')
|
||||
? 'semi-navigation-item-normal'
|
||||
: 'tableHiddle',
|
||||
// className: localStorage.getItem('chat_link')
|
||||
// ? 'semi-navigation-item-normal'
|
||||
// : 'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '令牌',
|
||||
@@ -174,7 +174,7 @@ const SiderBar = () => {
|
||||
localStorage.getItem('enable_data_export'),
|
||||
localStorage.getItem('enable_drawing'),
|
||||
localStorage.getItem('enable_task'),
|
||||
localStorage.getItem('chat_link'),
|
||||
localStorage.getItem('chat_link'), chatItems,
|
||||
isAdmin(),
|
||||
],
|
||||
);
|
||||
@@ -205,6 +205,33 @@ const SiderBar = () => {
|
||||
localKey = 'home';
|
||||
}
|
||||
setSelectedKeys([localKey]);
|
||||
let chatLink = localStorage.getItem('chat_link');
|
||||
if (!chatLink) {
|
||||
let chats = localStorage.getItem('chats');
|
||||
if (chats) {
|
||||
// console.log(chats);
|
||||
try {
|
||||
chats = JSON.parse(chats);
|
||||
if (Array.isArray(chats)) {
|
||||
let chatItems = [];
|
||||
for (let i = 0; i < chats.length; i++) {
|
||||
let chat = {};
|
||||
for (let key in chats[i]) {
|
||||
chat.text = key;
|
||||
chat.itemKey = 'chat' + i;
|
||||
chat.to = '/chat/' + i;
|
||||
}
|
||||
// setRouterMap({ ...routerMap, chat: '/chat/' + i })
|
||||
chatItems.push(chat);
|
||||
}
|
||||
setChatItems(chatItems);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
showError('聊天数据解析失败')
|
||||
}
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
@@ -221,6 +248,27 @@ const SiderBar = () => {
|
||||
}}
|
||||
selectedKeys={selectedKeys}
|
||||
renderWrapper={({ itemElement, isSubNav, isInSubNav, props }) => {
|
||||
let chatLink = localStorage.getItem('chat_link');
|
||||
if (!chatLink) {
|
||||
let chats = localStorage.getItem('chats');
|
||||
if (chats) {
|
||||
chats = JSON.parse(chats);
|
||||
if (Array.isArray(chats) && chats.length > 0) {
|
||||
for (let i = 0; i < chats.length; i++) {
|
||||
routerMap['chat' + i] = '/chat/' + i;
|
||||
}
|
||||
if (chats.length > 1) {
|
||||
// delete /chat
|
||||
if (routerMap['chat']) {
|
||||
delete routerMap['chat'];
|
||||
}
|
||||
} else {
|
||||
// rename /chat to /chat/0
|
||||
routerMap['chat'] = '/chat/0';
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return (
|
||||
<Link
|
||||
style={{ textDecoration: 'none' }}
|
||||
@@ -234,15 +282,6 @@ const SiderBar = () => {
|
||||
onSelect={(key) => {
|
||||
setSelectedKeys([key.itemKey]);
|
||||
}}
|
||||
// header={{
|
||||
// logo: (
|
||||
// <img src={logo} alt='logo' style={{ marginRight: '0.75em' }} />
|
||||
// ),
|
||||
// text: systemName,
|
||||
// }}
|
||||
// footer={{
|
||||
// text: '© 2021 NekoAPI',
|
||||
// }}
|
||||
footer={
|
||||
<>
|
||||
{isMobile() && (
|
||||
|
||||
@@ -24,17 +24,6 @@ import {
|
||||
import { IconTreeTriangleDown } from '@douyinfe/semi-icons';
|
||||
import EditToken from '../pages/Token/EditToken';
|
||||
|
||||
const COPY_OPTIONS = [
|
||||
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
|
||||
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
];
|
||||
|
||||
const OPEN_LINK_OPTIONS = [
|
||||
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
];
|
||||
|
||||
function renderTimestamp(timestamp) {
|
||||
return <>{timestamp2string(timestamp)}</>;
|
||||
}
|
||||
@@ -87,27 +76,6 @@ function renderStatus(status, model_limits_enabled = false) {
|
||||
}
|
||||
|
||||
const TokensTable = () => {
|
||||
const link_menu = [
|
||||
{
|
||||
node: 'item',
|
||||
key: 'next',
|
||||
name: 'ChatGPT Next Web',
|
||||
onClick: () => {
|
||||
onOpenLink('next');
|
||||
},
|
||||
},
|
||||
{ node: 'item', key: 'ama', name: 'AMA 问天', value: 'ama' },
|
||||
{
|
||||
node: 'item',
|
||||
key: 'next-mj',
|
||||
name: 'ChatGPT Web & Midjourney',
|
||||
value: 'next-mj',
|
||||
onClick: () => {
|
||||
onOpenLink('next-mj');
|
||||
},
|
||||
},
|
||||
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
|
||||
];
|
||||
|
||||
const columns = [
|
||||
{
|
||||
@@ -174,149 +142,171 @@ const TokensTable = () => {
|
||||
{
|
||||
title: '',
|
||||
dataIndex: 'operate',
|
||||
render: (text, record, index) => (
|
||||
<div>
|
||||
<Popover
|
||||
content={'sk-' + record.key}
|
||||
style={{ padding: 20 }}
|
||||
position='top'
|
||||
>
|
||||
<Button theme='light' type='tertiary' style={{ marginRight: 1 }}>
|
||||
查看
|
||||
</Button>
|
||||
</Popover>
|
||||
<Button
|
||||
theme='light'
|
||||
type='secondary'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={async (text) => {
|
||||
await copyText('sk-' + record.key);
|
||||
}}
|
||||
>
|
||||
复制
|
||||
</Button>
|
||||
<SplitButtonGroup
|
||||
style={{ marginRight: 1 }}
|
||||
aria-label='项目操作按钮组'
|
||||
>
|
||||
<Button
|
||||
theme='light'
|
||||
style={{ color: 'rgba(var(--semi-teal-7), 1)' }}
|
||||
onClick={() => {
|
||||
onOpenLink('next', record.key);
|
||||
}}
|
||||
render: (text, record, index) => {
|
||||
let chats = localStorage.getItem('chats');
|
||||
let chatsArray = []
|
||||
let chatLink = localStorage.getItem('chat_link');
|
||||
let mjLink = localStorage.getItem('chat_link2');
|
||||
let shouldUseCustom = true;
|
||||
if (chatLink) {
|
||||
shouldUseCustom = false;
|
||||
chatLink += `/#/?settings={"key":"{key}","url":"{address}"}`;
|
||||
chatsArray.push({
|
||||
node: 'item',
|
||||
key: 'default',
|
||||
name: 'ChatGPT Next Web',
|
||||
onClick: () => {
|
||||
onOpenLink('default', chatLink, record);
|
||||
},
|
||||
});
|
||||
}
|
||||
if (mjLink) {
|
||||
shouldUseCustom = false;
|
||||
mjLink += `/#/?settings={"key":"{key}","url":"{address}"}`;
|
||||
chatsArray.push({
|
||||
node: 'item',
|
||||
key: 'mj',
|
||||
name: 'ChatGPT Next Midjourney',
|
||||
onClick: () => {
|
||||
onOpenLink('mj', mjLink, record);
|
||||
},
|
||||
});
|
||||
}
|
||||
if (shouldUseCustom) {
|
||||
try {
|
||||
// console.log(chats);
|
||||
chats = JSON.parse(chats);
|
||||
// check chats is array
|
||||
if (Array.isArray(chats)) {
|
||||
for (let i = 0; i < chats.length; i++) {
|
||||
let chat = {}
|
||||
chat.node = 'item';
|
||||
// c is a map
|
||||
// chat.key = chats[i].name;
|
||||
// console.log(chats[i])
|
||||
for (let key in chats[i]) {
|
||||
if (chats[i].hasOwnProperty(key)) {
|
||||
chat.key = i;
|
||||
chat.name = key;
|
||||
chat.onClick = () => {
|
||||
onOpenLink(key, chats[i][key], record);
|
||||
}
|
||||
}
|
||||
}
|
||||
chatsArray.push(chat);
|
||||
}
|
||||
}
|
||||
|
||||
} catch (e) {
|
||||
console.log(e);
|
||||
showError('聊天链接配置错误,请联系管理员');
|
||||
}
|
||||
}
|
||||
return (
|
||||
<div>
|
||||
<Popover
|
||||
content={'sk-' + record.key}
|
||||
style={{ padding: 20 }}
|
||||
position='top'
|
||||
>
|
||||
聊天
|
||||
</Button>
|
||||
<Dropdown
|
||||
trigger='click'
|
||||
position='bottomRight'
|
||||
menu={[
|
||||
{
|
||||
node: 'item',
|
||||
key: 'next',
|
||||
disabled: !localStorage.getItem('chat_link'),
|
||||
name: 'ChatGPT Next Web',
|
||||
onClick: () => {
|
||||
onOpenLink('next', record.key);
|
||||
},
|
||||
},
|
||||
{
|
||||
node: 'item',
|
||||
key: 'next-mj',
|
||||
disabled: !localStorage.getItem('chat_link2'),
|
||||
name: 'ChatGPT Web & Midjourney',
|
||||
onClick: () => {
|
||||
onOpenLink('next-mj', record.key);
|
||||
},
|
||||
},
|
||||
// {
|
||||
// node: 'item',
|
||||
// key: 'lobe',
|
||||
// name: 'Lobe Chat',
|
||||
// onClick: () => {
|
||||
// onOpenLink('lobe', record.key);
|
||||
// },
|
||||
// },
|
||||
{
|
||||
node: 'item',
|
||||
key: 'ama',
|
||||
name: 'AMA 问天(BotGem)',
|
||||
onClick: () => {
|
||||
onOpenLink('ama', record.key);
|
||||
},
|
||||
},
|
||||
{
|
||||
node: 'item',
|
||||
key: 'opencat',
|
||||
name: 'OpenCat',
|
||||
onClick: () => {
|
||||
onOpenLink('opencat', record.key);
|
||||
},
|
||||
},
|
||||
]}
|
||||
>
|
||||
<Button
|
||||
style={{
|
||||
padding: '8px 4px',
|
||||
color: 'rgba(var(--semi-teal-7), 1)',
|
||||
}}
|
||||
type='primary'
|
||||
icon={<IconTreeTriangleDown />}
|
||||
></Button>
|
||||
</Dropdown>
|
||||
</SplitButtonGroup>
|
||||
<Popconfirm
|
||||
title='确定是否要删除此令牌?'
|
||||
content='此修改将不可逆'
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
manageToken(record.id, 'delete', record).then(() => {
|
||||
removeRecord(record.key);
|
||||
});
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
{record.status === 1 ? (
|
||||
<Button
|
||||
theme='light'
|
||||
type='warning'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={async () => {
|
||||
manageToken(record.id, 'disable', record);
|
||||
}}
|
||||
>
|
||||
禁用
|
||||
</Button>
|
||||
) : (
|
||||
<Button theme='light' type='tertiary' style={{ marginRight: 1 }}>
|
||||
查看
|
||||
</Button>
|
||||
</Popover>
|
||||
<Button
|
||||
theme='light'
|
||||
type='secondary'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={async () => {
|
||||
manageToken(record.id, 'enable', record);
|
||||
onClick={async (text) => {
|
||||
await copyText('sk-' + record.key);
|
||||
}}
|
||||
>
|
||||
启用
|
||||
复制
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
theme='light'
|
||||
type='tertiary'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={() => {
|
||||
setEditingToken(record);
|
||||
setShowEdit(true);
|
||||
}}
|
||||
>
|
||||
编辑
|
||||
</Button>
|
||||
</div>
|
||||
),
|
||||
<SplitButtonGroup
|
||||
style={{ marginRight: 1 }}
|
||||
aria-label='项目操作按钮组'
|
||||
>
|
||||
<Button
|
||||
theme='light'
|
||||
style={{ color: 'rgba(var(--semi-teal-7), 1)' }}
|
||||
onClick={() => {
|
||||
if (chatsArray.length === 0) {
|
||||
showError('请联系管理员配置聊天链接');
|
||||
} else {
|
||||
onOpenLink('default', chats[0][Object.keys(chats[0])[0]], record);
|
||||
}
|
||||
}}
|
||||
>
|
||||
聊天
|
||||
</Button>
|
||||
<Dropdown
|
||||
trigger='click'
|
||||
position='bottomRight'
|
||||
menu={chatsArray}
|
||||
>
|
||||
<Button
|
||||
style={{
|
||||
padding: '8px 4px',
|
||||
color: 'rgba(var(--semi-teal-7), 1)',
|
||||
}}
|
||||
type='primary'
|
||||
icon={<IconTreeTriangleDown />}
|
||||
></Button>
|
||||
</Dropdown>
|
||||
</SplitButtonGroup>
|
||||
<Popconfirm
|
||||
title='确定是否要删除此令牌?'
|
||||
content='此修改将不可逆'
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
manageToken(record.id, 'delete', record).then(() => {
|
||||
removeRecord(record.key);
|
||||
});
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
{record.status === 1 ? (
|
||||
<Button
|
||||
theme='light'
|
||||
type='warning'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={async () => {
|
||||
manageToken(record.id, 'disable', record);
|
||||
}}
|
||||
>
|
||||
禁用
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
theme='light'
|
||||
type='secondary'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={async () => {
|
||||
manageToken(record.id, 'enable', record);
|
||||
}}
|
||||
>
|
||||
启用
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
theme='light'
|
||||
type='tertiary'
|
||||
style={{ marginRight: 1 }}
|
||||
onClick={() => {
|
||||
setEditingToken(record);
|
||||
setShowEdit(true);
|
||||
}}
|
||||
>
|
||||
编辑
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
@@ -330,8 +320,7 @@ const TokensTable = () => {
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [searchToken, setSearchToken] = useState('');
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [showTopUpModal, setShowTopUpModal] = useState(false);
|
||||
const [targetTokenIdx, setTargetTokenIdx] = useState(0);
|
||||
const [chats, setChats] = useState([]);
|
||||
const [editingToken, setEditingToken] = useState({
|
||||
id: undefined,
|
||||
});
|
||||
@@ -376,16 +365,6 @@ const TokensTable = () => {
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const onPaginationChange = (e, { activePage }) => {
|
||||
(async () => {
|
||||
if (activePage === Math.ceil(tokens.length / pageSize) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
await loadTokens(activePage - 1);
|
||||
}
|
||||
setActivePage(activePage);
|
||||
})();
|
||||
};
|
||||
|
||||
const refresh = async () => {
|
||||
await loadTokens(activePage - 1);
|
||||
};
|
||||
@@ -402,7 +381,8 @@ const TokensTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const onOpenLink = async (type, key) => {
|
||||
const onOpenLink = async (type, url, record) => {
|
||||
// console.log(type, url, key);
|
||||
let status = localStorage.getItem('status');
|
||||
let serverAddress = '';
|
||||
if (status) {
|
||||
@@ -413,36 +393,39 @@ const TokensTable = () => {
|
||||
serverAddress = window.location.origin;
|
||||
}
|
||||
let encodedServerAddress = encodeURIComponent(serverAddress);
|
||||
const chatLink = localStorage.getItem('chat_link');
|
||||
const mjLink = localStorage.getItem('chat_link2');
|
||||
let defaultUrl;
|
||||
|
||||
if (chatLink) {
|
||||
defaultUrl =
|
||||
chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
}
|
||||
let url;
|
||||
switch (type) {
|
||||
case 'ama':
|
||||
url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
|
||||
break;
|
||||
case 'opencat':
|
||||
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
|
||||
break;
|
||||
case 'lobe':
|
||||
url = `https://chat-preview.lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${encodedServerAddress}/v1"}}}`;
|
||||
break;
|
||||
case 'next-mj':
|
||||
url =
|
||||
mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
break;
|
||||
default:
|
||||
if (!chatLink) {
|
||||
showError('管理员未设置聊天链接');
|
||||
return;
|
||||
}
|
||||
url = defaultUrl;
|
||||
}
|
||||
url = url.replaceAll('{address}', encodedServerAddress);
|
||||
url = url.replaceAll('{key}', 'sk-' + record.key);
|
||||
// console.log(url);
|
||||
// const chatLink = localStorage.getItem('chat_link');
|
||||
// const mjLink = localStorage.getItem('chat_link2');
|
||||
// let defaultUrl;
|
||||
//
|
||||
// if (chatLink) {
|
||||
// defaultUrl =
|
||||
// chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
// }
|
||||
// let url;
|
||||
// switch (type) {
|
||||
// case 'ama':
|
||||
// url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
|
||||
// break;
|
||||
// case 'opencat':
|
||||
// url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
|
||||
// break;
|
||||
// case 'lobe':
|
||||
// url = `https://chat-preview.lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${encodedServerAddress}/v1"}}}`;
|
||||
// break;
|
||||
// case 'next-mj':
|
||||
// url =
|
||||
// mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
// break;
|
||||
// default:
|
||||
// if (!chatLink) {
|
||||
// showError('管理员未设置聊天链接');
|
||||
// return;
|
||||
// }
|
||||
// url = defaultUrl;
|
||||
// }
|
||||
|
||||
window.open(url, '_blank');
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@ import { API, showError } from '../helpers';
|
||||
|
||||
async function fetchTokenKeys() {
|
||||
try {
|
||||
const response = await API.get('/api/token/?p=0&size=999');
|
||||
const response = await API.get('/api/token/?p=0&size=100');
|
||||
const { success, data } = response.data;
|
||||
if (success) {
|
||||
const activeTokens = data.filter((token) => token.status === 1);
|
||||
@@ -38,9 +38,9 @@ function getServerAddress() {
|
||||
return serverAddress;
|
||||
}
|
||||
|
||||
export function useTokenKeys() {
|
||||
export function useTokenKeys(id) {
|
||||
const [keys, setKeys] = useState([]);
|
||||
const [chatLink, setChatLink] = useState('');
|
||||
// const [chatLink, setChatLink] = useState('');
|
||||
const [serverAddress, setServerAddress] = useState('');
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
@@ -55,9 +55,7 @@ export function useTokenKeys() {
|
||||
}
|
||||
setKeys(fetchedKeys);
|
||||
setIsLoading(false);
|
||||
|
||||
const link = localStorage.getItem('chat_link');
|
||||
setChatLink(link);
|
||||
// setChatLink(link);
|
||||
|
||||
const address = getServerAddress();
|
||||
setServerAddress(address);
|
||||
@@ -66,5 +64,5 @@ export function useTokenKeys() {
|
||||
loadAllData();
|
||||
}, []);
|
||||
|
||||
return { keys, chatLink, serverAddress, isLoading };
|
||||
return { keys, serverAddress, isLoading };
|
||||
}
|
||||
@@ -109,6 +109,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
|
||||
{ key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
|
||||
{ key: 40, text: 'SiliconCloud', value: 40, color: 'purple', label: 'SiliconCloud' },
|
||||
{ key: 42, text: 'Mistral AI', value: 42, color: 'blue', label: 'Mistral AI' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
|
||||
{
|
||||
key: 22,
|
||||
|
||||
@@ -8,6 +8,7 @@ export function setStatusData(data) {
|
||||
localStorage.setItem('enable_drawing', data.enable_drawing);
|
||||
localStorage.setItem('enable_task', data.enable_task);
|
||||
localStorage.setItem('enable_data_export', data.enable_data_export);
|
||||
localStorage.setItem('chats', JSON.stringify(data.chats));
|
||||
localStorage.setItem(
|
||||
'data_export_default_time',
|
||||
data.data_export_default_time,
|
||||
|
||||
@@ -59,6 +59,17 @@ body {
|
||||
display: revert;
|
||||
}
|
||||
|
||||
.semi-chat {
|
||||
padding-top: 0 !important;
|
||||
padding-bottom: 0 !important;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.semi-chat-chatBox-content {
|
||||
min-width: auto;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tableHiddle {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
@@ -1,13 +1,32 @@
|
||||
import React from 'react';
|
||||
import React, {useEffect} from 'react';
|
||||
import { useTokenKeys } from '../../components/fetchTokenKeys';
|
||||
import { Layout } from '@douyinfe/semi-ui';
|
||||
import {Banner, Layout} from '@douyinfe/semi-ui';
|
||||
import { useParams } from 'react-router-dom';
|
||||
|
||||
const ChatPage = () => {
|
||||
const { keys, chatLink, serverAddress, isLoading } = useTokenKeys();
|
||||
const { id } = useParams();
|
||||
const { keys, serverAddress, isLoading } = useTokenKeys(id);
|
||||
|
||||
const comLink = (key) => {
|
||||
if (!chatLink || !serverAddress || !key) return '';
|
||||
return `${chatLink}/#/?settings={"key":"sk-${key}","url":"${encodeURIComponent(serverAddress)}"}`;
|
||||
// console.log('chatLink:', chatLink);
|
||||
if (!serverAddress || !key) return '';
|
||||
let link = localStorage.getItem('chat_link');
|
||||
if (link) {
|
||||
link = `${link}/#/?settings={"key":"sk-${key}","url":"${encodeURIComponent(serverAddress)}"}`;
|
||||
} else if (id) {
|
||||
let chats = localStorage.getItem('chats');
|
||||
if (chats) {
|
||||
chats = JSON.parse(chats);
|
||||
if (Array.isArray(chats) && chats.length > 0) {
|
||||
for (let k in chats[id]) {
|
||||
link = chats[id][k];
|
||||
link = link.replaceAll('{address}', encodeURIComponent(serverAddress));
|
||||
link = link.replaceAll('{key}', 'sk-' + key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return link;
|
||||
};
|
||||
|
||||
const iframeSrc = keys.length > 0 ? comLink(keys[0]) : '';
|
||||
@@ -22,10 +41,10 @@ const ChatPage = () => {
|
||||
<div>
|
||||
<Layout>
|
||||
<Layout.Header>
|
||||
<h3 style={{ color: 'red'}}>
|
||||
当前没有可用的已启用令牌,请确认是否有令牌处于启用状态!<br />
|
||||
正在跳转......
|
||||
</h3>
|
||||
<Banner
|
||||
description={"正在跳转......"}
|
||||
type={"warning"}
|
||||
/>
|
||||
</Layout.Header>
|
||||
</Layout>
|
||||
</div>
|
||||
|
||||
148
web/src/pages/Setting/Operation/SettingsChats.js
Normal file
148
web/src/pages/Setting/Operation/SettingsChats.js
Normal file
@@ -0,0 +1,148 @@
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import { Banner, Button, Col, Form, Popconfirm, Row, Space, Spin } from '@douyinfe/semi-ui';
|
||||
import {
|
||||
compareObjects,
|
||||
API,
|
||||
showError,
|
||||
showSuccess,
|
||||
showWarning,
|
||||
verifyJSON,
|
||||
verifyJSONPromise
|
||||
} from '../../../helpers';
|
||||
|
||||
export default function SettingsChats(props) {
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [inputs, setInputs] = useState({
|
||||
Chats: "[]",
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
|
||||
async function onSubmit() {
|
||||
try {
|
||||
console.log('Starting validation...');
|
||||
await refForm.current.validate().then(() => {
|
||||
console.log('Validation passed');
|
||||
const updateArray = compareObjects(inputs, inputsRow);
|
||||
if (!updateArray.length) return showWarning('你似乎并没有修改什么');
|
||||
const requestQueue = updateArray.map((item) => {
|
||||
let value = '';
|
||||
if (typeof inputs[item.key] === 'boolean') {
|
||||
value = String(inputs[item.key]);
|
||||
} else {
|
||||
value = inputs[item.key];
|
||||
}
|
||||
return API.put('/api/option/', {
|
||||
key: item.key,
|
||||
value
|
||||
});
|
||||
});
|
||||
setLoading(true);
|
||||
Promise.all(requestQueue)
|
||||
.then((res) => {
|
||||
if (requestQueue.length === 1) {
|
||||
if (res.includes(undefined)) return;
|
||||
} else if (requestQueue.length > 1) {
|
||||
if (res.includes(undefined))
|
||||
return showError('部分保存失败,请重试');
|
||||
}
|
||||
showSuccess('保存成功');
|
||||
props.refresh();
|
||||
})
|
||||
.catch(() => {
|
||||
showError('保存失败,请重试');
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false);
|
||||
});
|
||||
}).catch((error) => {
|
||||
console.error('Validation failed:', error);
|
||||
showError('请检查输入');
|
||||
});
|
||||
} catch (error) {
|
||||
showError('请检查输入');
|
||||
console.error(error);
|
||||
}
|
||||
}
|
||||
|
||||
async function resetModelRatio() {
|
||||
try {
|
||||
let res = await API.post(`/api/option/rest_model_ratio`);
|
||||
// return {success, message}
|
||||
if (res.data.success) {
|
||||
showSuccess(res.data.message);
|
||||
props.refresh();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(error);
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const currentInputs = {};
|
||||
for (let key in props.options) {
|
||||
if (Object.keys(inputs).includes(key)) {
|
||||
if (key === 'Chats') {
|
||||
const obj = JSON.parse(props.options[key]);
|
||||
currentInputs[key] = JSON.stringify(obj, null, 2);
|
||||
} else {
|
||||
currentInputs[key] = props.options[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
setInputs(currentInputs);
|
||||
setInputsRow(structuredClone(currentInputs));
|
||||
refForm.current.setValues(currentInputs);
|
||||
}, [props.options]);
|
||||
|
||||
return (
|
||||
<Spin spinning={loading}>
|
||||
<Form
|
||||
values={inputs}
|
||||
getFormApi={(formAPI) => (refForm.current = formAPI)}
|
||||
style={{ marginBottom: 15 }}
|
||||
>
|
||||
<Form.Section text={'令牌聊天设置'}>
|
||||
<Banner
|
||||
type='warning'
|
||||
description={'必须将上方聊天链接全部设置为空,才能使用下方聊天设置功能'}
|
||||
/>
|
||||
<Banner
|
||||
type='info'
|
||||
description={'链接中的{key}将自动替换为sk-xxxx,{address}将自动替换为系统设置的服务器地址,末尾不带/和/v1'}
|
||||
/>
|
||||
<Form.TextArea
|
||||
label={'聊天配置'}
|
||||
extraText={''}
|
||||
placeholder={'为一个 JSON 文本'}
|
||||
field={'Chats'}
|
||||
autosize={{ minRows: 6, maxRows: 12 }}
|
||||
trigger='blur'
|
||||
stopValidateWithError
|
||||
rules={[
|
||||
{
|
||||
validator: (rule, value) => {
|
||||
return verifyJSON(value);
|
||||
},
|
||||
message: '不是合法的 JSON 字符串'
|
||||
}
|
||||
]}
|
||||
onChange={(value) =>
|
||||
setInputs({
|
||||
...inputs,
|
||||
Chats: value
|
||||
})
|
||||
}
|
||||
/>
|
||||
</Form.Section>
|
||||
</Form>
|
||||
<Space>
|
||||
<Button onClick={onSubmit}>
|
||||
保存聊天设置
|
||||
</Button>
|
||||
</Space>
|
||||
</Spin>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
|
||||
import { Banner, Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
|
||||
import {
|
||||
compareObjects,
|
||||
API,
|
||||
@@ -74,6 +74,10 @@ export default function GeneralSettings(props) {
|
||||
return (
|
||||
<>
|
||||
<Spin spinning={loading}>
|
||||
<Banner
|
||||
type='warning'
|
||||
description={'聊天链接功能已经弃用,请使用下方聊天设置功能'}
|
||||
/>
|
||||
<Form
|
||||
values={inputs}
|
||||
getFormApi={(formAPI) => (refForm.current = formAPI)}
|
||||
|
||||
@@ -49,7 +49,7 @@ const EditToken = (props) => {
|
||||
group
|
||||
} = inputs;
|
||||
// const [visible, setVisible] = useState(false);
|
||||
const [models, setModels] = useState({});
|
||||
const [models, setModels] = useState([]);
|
||||
const [groups, setGroups] = useState([]);
|
||||
const navigate = useNavigate();
|
||||
const handleInputChange = (name, value) => {
|
||||
@@ -92,7 +92,7 @@ const EditToken = (props) => {
|
||||
};
|
||||
|
||||
const loadGroups = async () => {
|
||||
let res = await API.get(`/api/user/groups`);
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
// return data is a map, key is group name, value is group description
|
||||
|
||||
@@ -55,6 +55,10 @@ export default defineConfig({
|
||||
target: 'http://localhost:3000',
|
||||
changeOrigin: true,
|
||||
},
|
||||
'/pg': {
|
||||
target: 'http://localhost:3000',
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user