Mirror of https://github.com/linux-do/new-api.git
Synced 2025-11-17 19:13:42 +08:00

Compare commits: v0.2.2.0-a ... v0.2.2.0-a (48 commits)
| Author | SHA1 | Date |
|---|---|---|
| | acbc3649d6 | |
| | 5715fcf8fb | |
| | 21839ed13b | |
| | 71547849bc | |
| | 39f6812a2b | |
| | 5ac3d25f54 | |
| | fd19798c92 | |
| | 12667ad17d | |
| | e8800415b8 | |
| | ecd06cf2f8 | |
| | db575a1c25 | |
| | 2dbf50dc07 | |
| | d8c006046f | |
| | b427f0278f | |
| | 6fb1fbfe96 | |
| | 4641d44615 | |
| | 1cff3c100a | |
| | d7a343e2f6 | |
| | 637801fba5 | |
| | 2bf404507f | |
| | 675de89c69 | |
| | 16b9aacb06 | |
| | cad380eb16 | |
| | 234e39ddeb | |
| | 7fb6420e66 | |
| | 5425b5bfc3 | |
| | 21f32605c8 | |
| | 1c6fd87909 | |
| | d1c8947851 | |
| | 7d2d525051 | |
| | be4809b95a | |
| | e2edd5e7e5 | |
| | a14fa1adb1 | |
| | 2cb10b003a | |
| | 86b17fcce8 | |
| | 08b5336431 | |
| | 20aaf30785 | |
| | bfcaccc2e3 | |
| | 3f448ba4fc | |
| | 408c2bdd9b | |
| | b1b38a6bd4 | |
| | 608ec28761 | |
| | a3ccc92f55 | |
| | 77e7d11151 | |
| | 783e8fd74a | |
| | 2841669246 | |
| | 89ebd85503 | |
| | 79cf70683f | |
@@ -207,6 +207,9 @@ const (
 	ChannelTypePerplexity  = 27
 	ChannelTypeLingYiWanWu = 31
+	ChannelTypeAws         = 33
+	ChannelTypeCohere      = 34
+
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )

 var ChannelBaseURLs = []string{
@@ -244,5 +247,5 @@ var ChannelBaseURLs = []string{
 	"https://api.lingyiwanwu.com", //31
 	"",                            //32
 	"",                            //33
+	"https://api.cohere.ai",       //34
 }
@@ -16,7 +16,22 @@ func SafeGoroutine(f func()) {
 	}()
 }

-func SafeSend(ch chan bool, value bool) (closed bool) {
+func SafeSendBool(ch chan bool, value bool) (closed bool) {
 	defer func() {
 		// Recover from panic if one occurred. A panic would mean the channel was closed.
 		if recover() != nil {
 			closed = true
 		}
 	}()

 	// This will panic if the channel is closed.
 	ch <- value

 	// If the code reaches here, then the channel was not closed.
 	return false
 }

+func SafeSendString(ch chan string, value string) (closed bool) {
+	defer func() {
+		// Recover from panic if one occurred. A panic would mean the channel was closed.
+		if recover() != nil {
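The rename from `SafeSend` to `SafeSendBool`, together with the new `SafeSendString`, exists because each channel element type needs its own copy of this recover-based helper in pre-generics style. A standalone sketch of the same pattern written once with type parameters (Go 1.18+; illustrative, not code from this repository):

```go
package main

import "fmt"

// SafeSend reports whether ch was already closed instead of panicking.
func SafeSend[T any](ch chan T, value T) (closed bool) {
	defer func() {
		// Sending on a closed channel panics; recover converts that
		// panic into a boolean result.
		if recover() != nil {
			closed = true
		}
	}()
	ch <- value
	return false
}

func main() {
	ch := make(chan bool, 1)
	fmt.Println(SafeSend(ch, true)) // false: send succeeded
	close(ch)
	fmt.Println(SafeSend(ch, true)) // true: channel was closed
}
```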
@@ -2,6 +2,7 @@ package common

 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -98,3 +99,13 @@ func LogQuota(quota int) string {
 		return fmt.Sprintf("%d 点额度", quota)
 	}
 }
+
+// LogJson is for test use only
+func LogJson(ctx context.Context, msg string, obj any) {
+	jsonStr, err := json.Marshal(obj)
+	if err != nil {
+		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
+		return
+	}
+	LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
+}
@@ -28,6 +28,7 @@ var DefaultModelRatio = map[string]float64{
 	"gpt-4-vision-preview":      5, // $0.01 / 1K tokens
 	"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
+	"gpt-4-turbo":               5, // $0.01 / 1K tokens
 	"gpt-4-turbo-2024-04-09":    5,    // $0.01 / 1K tokens
 	"gpt-3.5-turbo":             0.25, // $0.0015 / 1K tokens
 	//"gpt-3.5-turbo-0301": 0.75, //deprecated
 	"gpt-3.5-turbo-0613": 0.75,
@@ -60,8 +61,6 @@ var DefaultModelRatio = map[string]float64{
 	"text-search-ada-doc-001": 10,
 	"text-moderation-stable":  0.1,
 	"text-moderation-latest":  0.1,
-	"dall-e-2":                8,
-	"dall-e-3":                16,
 	"claude-instant-1":        0.4, // $0.8 / 1M tokens
 	"claude-2.0":              4,   // $8 / 1M tokens
 	"claude-2.1":              4,   // $8 / 1M tokens
@@ -102,12 +101,22 @@ var DefaultModelRatio = map[string]float64{
 	"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 	// https://platform.lingyiwanwu.com/docs#-计费单元
 	// USD prices below are already converted at an exchange rate of 7.2
-	"yi-34b-chat-0205": 0.018,
-	"yi-34b-chat-200k": 0.0864,
-	"yi-vl-plus":       0.0432,
+	"yi-34b-chat-0205":      0.018,
+	"yi-34b-chat-200k":      0.0864,
+	"yi-vl-plus":            0.0432,
+	"command":               0.5,
+	"command-nightly":       0.5,
+	"command-light":         0.5,
+	"command-light-nightly": 0.5,
+	"command-r":             0.25,
+	"command-r-plus ":       1.5,
+	"deepseek-chat":         0.07,
+	"deepseek-coder":        0.07,
 }

 var DefaultModelPrice = map[string]float64{
+	"dall-e-2":      0.02,
+	"dall-e-3":      0.04,
 	"gpt-4-gizmo-*": 0.1,
 	"mj_imagine":    0.1,
 	"mj_variation":  0.1,
@@ -129,6 +138,12 @@ var DefaultModelPrice = map[string]float64{
 var modelPrice map[string]float64 = nil
 var modelRatio map[string]float64 = nil

+var CompletionRatio map[string]float64 = nil
+var DefaultCompletionRatio = map[string]float64{
+	"gpt-4-gizmo-*": 2,
+	"gpt-4-all":     2,
+}
+
 func ModelPrice2JSONString() string {
 	if modelPrice == nil {
 		modelPrice = DefaultModelPrice
@@ -145,7 +160,8 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
 	return json.Unmarshal([]byte(jsonStr), &modelPrice)
 }

-func GetModelPrice(name string, printErr bool) float64 {
+// GetModelPrice returns the model's price; if the model does not exist it returns -1, false
+func GetModelPrice(name string, printErr bool) (float64, bool) {
 	if modelPrice == nil {
 		modelPrice = DefaultModelPrice
 	}
@@ -157,9 +173,16 @@ func GetModelPrice(name string, printErr bool) float64 {
 		if printErr {
 			SysError("model price not found: " + name)
 		}
-		return -1
+		return -1, false
 	}
-	return price
+	return price, true
 }

+func GetModelPrices() map[string]float64 {
+	if modelPrice == nil {
+		modelPrice = DefaultModelPrice
+	}
+	return modelPrice
+}
+
 func ModelRatio2JSONString() string {
@@ -193,6 +216,29 @@ func GetModelRatio(name string) float64 {
 	return ratio
 }

+func GetModelRatios() map[string]float64 {
+	if modelRatio == nil {
+		modelRatio = DefaultModelRatio
+	}
+	return modelRatio
+}
+
+func CompletionRatio2JSONString() string {
+	if CompletionRatio == nil {
+		CompletionRatio = DefaultCompletionRatio
+	}
+	jsonBytes, err := json.Marshal(CompletionRatio)
+	if err != nil {
+		SysError("error marshalling completion ratio: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+func UpdateCompletionRatioByJSONString(jsonStr string) error {
+	CompletionRatio = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
+}
+
 func GetCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "gpt-3.5") {
 		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
@@ -205,7 +251,7 @@ func GetCompletionRatio(name string) float64 {
 		}
 		return 4.0 / 3.0
 	}
-	if strings.HasPrefix(name, "gpt-4") {
+	if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && !strings.HasPrefix(name, "gpt-4-gizmo") {
 		if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
 			return 3
 		}
@@ -224,9 +270,36 @@ func GetCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "gemini-") {
 		return 3
 	}
+	if strings.HasPrefix(name, "command") {
+		switch name {
+		case "command-r":
+			return 3
+		case "command-r-plus":
+			return 5
+		default:
+			return 2
+		}
+	}
+	if strings.HasPrefix(name, "deepseek") {
+		return 2
+	}
 	switch name {
 	case "llama2-70b-4096":
-		return 0.8 / 0.7
+		return 0.8 / 0.64
+	case "llama3-8b-8192":
+		return 2
+	case "llama3-70b-8192":
+		return 0.79 / 0.59
 	}
+	if ratio, ok := CompletionRatio[name]; ok {
+		return ratio
+	}
 	return 1
 }
+
+func GetCompletionRatios() map[string]float64 {
+	if CompletionRatio == nil {
+		CompletionRatio = DefaultCompletionRatio
+	}
+	return CompletionRatio
+}
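The signature change from `float64` to `(float64, bool)` lets callers tell "this model has a fixed per-call price" apart from "bill by token ratio" without overloading -1 as a sentinel. A standalone sketch of the resulting call-site pattern (the maps and values are illustrative; in the repository the lookups go through `GetModelPrice` and `GetModelRatio`):

```go
package main

import "fmt"

// Illustrative stand-ins for DefaultModelPrice and DefaultModelRatio above.
var modelPrice = map[string]float64{"dall-e-3": 0.04}
var modelRatio = map[string]float64{"gpt-3.5-turbo": 0.25}

// Mirrors the new (value, found) contract of GetModelPrice.
func getModelPrice(name string) (float64, bool) {
	price, ok := modelPrice[name]
	if !ok {
		return -1, false
	}
	return price, true
}

func main() {
	for _, name := range []string{"dall-e-3", "gpt-3.5-turbo"} {
		if price, found := getModelPrice(name); found {
			fmt.Printf("%s: fixed price %.2f per call\n", name, price)
		} else {
			fmt.Printf("%s: token billing, ratio %.2f\n", name, modelRatio[name])
		}
	}
}
```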
@@ -1,6 +1,7 @@
 package common

 import (
+	"encoding/json"
 	"fmt"
 	"github.com/google/uuid"
 	"html/template"
@@ -241,3 +242,19 @@ func RandomSleep() {
 	// Sleep for 0-3000 ms
 	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
 }
+
+func MapToJsonStr(m map[string]interface{}) string {
+	bytes, err := json.Marshal(m)
+	if err != nil {
+		return ""
+	}
+	return string(bytes)
+}
+
+func MapToJsonStrFloat(m map[string]float64) string {
+	bytes, err := json.Marshal(m)
+	if err != nil {
+		return ""
+	}
+	return string(bytes)
+}
@@ -1,6 +1,7 @@
 package constant

 var MjNotifyEnabled = false
 var MjAccountFilterEnabled = false
 var MjModeClearEnabled = false
+var MjForwardUrlEnabled = true
@@ -53,7 +53,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 	}

 	meta := relaycommon.GenRelayInfo(c)
-	apiType := constant.ChannelType2APIType(channel.Type)
+	apiType, _ := constant.ChannelType2APIType(channel.Type)
 	adaptor := relay.GetAdaptor(apiType)
 	if adaptor == nil {
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
@@ -208,7 +208,7 @@ func testAllChannels(notify bool) error {
 			if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
 				service.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
-			if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
+			if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
 				service.EnableChannel(channel.Id, channel.Name)
 			}
 			channel.UpdateResponseTime(milliseconds)
@@ -86,7 +86,7 @@ func UpdateMidjourneyTaskBulk() {
 				continue
 			}
 			// set the request timeout
-			timeout := time.Second * 5
+			timeout := time.Second * 15
 			ctx, cancel := context.WithTimeout(context.Background(), timeout)
 			// create the new request with the timeout-bound context
 			req = req.WithContext(ctx)
@@ -147,7 +147,7 @@ func SendEmailVerification(c *gin.Context) {
 		}
 	}
 	if common.EmailAliasRestrictionEnabled {
-		containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
+		containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
 		if containsSpecialSymbols {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
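The old rule only rejected a local part with more than one dot, so single-dot aliases such as `first.last` still passed; the new rule rejects any dot or plus sign. A standalone sketch of the new predicate with illustrative inputs:

```go
package main

import (
	"fmt"
	"strings"
)

// New rule from the diff: reject "+" or any "." in the local part.
func containsSpecialSymbols(localPart string) bool {
	return strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
}

func main() {
	for _, lp := range []string{"alice", "alice+spam", "first.last"} {
		// "first.last" was allowed under the old rule; it is blocked now.
		fmt.Printf("%-10s blocked=%v\n", lp, containsSpecialSymbols(lp))
	}
}
```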
@@ -4,49 +4,27 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay/channel/ai360"
-	"one-api/relay/channel/moonshot"
 	"one-api/relay/channel/lingyiwanwu"
+	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 )

 // https://platform.openai.com/docs/api-reference/models/list

-type OpenAIModelPermission struct {
-	Id                 string  `json:"id"`
-	Object             string  `json:"object"`
-	Created            int     `json:"created"`
-	AllowCreateEngine  bool    `json:"allow_create_engine"`
-	AllowSampling      bool    `json:"allow_sampling"`
-	AllowLogprobs      bool    `json:"allow_logprobs"`
-	AllowSearchIndices bool    `json:"allow_search_indices"`
-	AllowView          bool    `json:"allow_view"`
-	AllowFineTuning    bool    `json:"allow_fine_tuning"`
-	Organization       string  `json:"organization"`
-	Group              *string `json:"group"`
-	IsBlocking         bool    `json:"is_blocking"`
-}
-
-type OpenAIModels struct {
-	Id         string                  `json:"id"`
-	Object     string                  `json:"object"`
-	Created    int                     `json:"created"`
-	OwnedBy    string                  `json:"owned_by"`
-	Permission []OpenAIModelPermission `json:"permission"`
-	Root       string                  `json:"root"`
-	Parent     *string                 `json:"parent"`
-}
-
-var openAIModels []OpenAIModels
-var openAIModelsMap map[string]OpenAIModels
+var openAIModels []dto.OpenAIModels
+var openAIModelsMap map[string]dto.OpenAIModels
+var channelId2Models map[int][]string

-func init() {
-	var permission []OpenAIModelPermission
-	permission = append(permission, OpenAIModelPermission{
+func getPermission() []dto.OpenAIModelPermission {
+	var permission []dto.OpenAIModelPermission
+	permission = append(permission, dto.OpenAIModelPermission{
 		Id:      "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
 		Object:  "model_permission",
 		Created: 1626777600,
@@ -60,7 +38,12 @@ func init() {
 		Group:              nil,
 		IsBlocking:         false,
 	})
+	return permission
+}

+func init() {
 	// https://platform.openai.com/docs/models/model-endpoint-compatibility
+	permission := getPermission()
 	for i := 0; i < relayconstant.APITypeDummy; i++ {
 		if i == relayconstant.APITypeAIProxyLibrary {
 			continue
@@ -69,7 +52,7 @@ func init() {
 		channelName := adaptor.GetChannelName()
 		modelNames := adaptor.GetModelList()
 		for _, modelName := range modelNames {
-			openAIModels = append(openAIModels, OpenAIModels{
+			openAIModels = append(openAIModels, dto.OpenAIModels{
 				Id:      modelName,
 				Object:  "model",
 				Created: 1626777600,
@@ -81,18 +64,18 @@ func init() {
 		}
 	}
 	for _, modelName := range ai360.ModelList {
-		openAIModels = append(openAIModels, OpenAIModels{
+		openAIModels = append(openAIModels, dto.OpenAIModels{
 			Id:         modelName,
 			Object:     "model",
 			Created:    1626777600,
-			OwnedBy:    "360",
+			OwnedBy:    ai360.ChannelName,
 			Permission: permission,
 			Root:       modelName,
 			Parent:     nil,
 		})
 	}
 	for _, modelName := range moonshot.ModelList {
-		openAIModels = append(openAIModels, OpenAIModels{
+		openAIModels = append(openAIModels, dto.OpenAIModels{
 			Id:      modelName,
 			Object:  "model",
 			Created: 1626777600,
@@ -103,7 +86,7 @@ func init() {
 		})
 	}
 	for _, modelName := range lingyiwanwu.ModelList {
-		openAIModels = append(openAIModels, OpenAIModels{
+		openAIModels = append(openAIModels, dto.OpenAIModels{
 			Id:      modelName,
 			Object:  "model",
 			Created: 1626777600,
@@ -114,7 +97,7 @@ func init() {
 		})
 	}
 	for modelName, _ := range constant.MidjourneyModel2Action {
-		openAIModels = append(openAIModels, OpenAIModels{
+		openAIModels = append(openAIModels, dto.OpenAIModels{
 			Id:      modelName,
 			Object:  "model",
 			Created: 1626777600,
@@ -124,10 +107,21 @@ func init() {
 			Parent: nil,
 		})
 	}
-	openAIModelsMap = make(map[string]OpenAIModels)
+	openAIModelsMap = make(map[string]dto.OpenAIModels)
 	for _, model := range openAIModels {
 		openAIModelsMap[model.Id] = model
 	}
+	channelId2Models = make(map[int][]string)
+	for i := 1; i <= common.ChannelTypeDummy; i++ {
+		apiType, success := relayconstant.ChannelType2APIType(i)
+		if !success || apiType == relayconstant.APITypeAIProxyLibrary {
+			continue
+		}
+		meta := &relaycommon.RelayInfo{ChannelType: i}
+		adaptor := relay.GetAdaptor(apiType)
+		adaptor.Init(meta, dto.GeneralOpenAIRequest{})
+		channelId2Models[i] = adaptor.GetModelList()
+	}
 }

 func ListModels(c *gin.Context) {
@@ -141,22 +135,40 @@ func ListModels(c *gin.Context) {
 		return
 	}
 	models := model.GetGroupModels(user.Group)
-	userOpenAiModels := make([]OpenAIModels, 0)
+	userOpenAiModels := make([]dto.OpenAIModels, 0)
+	permission := getPermission()
 	for _, s := range models {
 		if _, ok := openAIModelsMap[s]; ok {
 			userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
+		} else {
+			userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
+				Id:         s,
+				Object:     "model",
+				Created:    1626777600,
+				OwnedBy:    "custom",
+				Permission: permission,
+				Root:       s,
+				Parent:     nil,
+			})
 		}
 	}
 	c.JSON(200, gin.H{
-		"object": "list",
-		"data":   userOpenAiModels,
+		"success": true,
+		"data":    userOpenAiModels,
 	})
 }

 func ChannelListModels(c *gin.Context) {
 	c.JSON(200, gin.H{
-		"object": "list",
-		"data":   openAIModels,
+		"success": true,
+		"data":    openAIModels,
 	})
 }

+func DashboardListModels(c *gin.Context) {
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    channelId2Models,
+	})
+}
+
@@ -176,3 +188,18 @@ func RetrieveModel(c *gin.Context) {
 		})
 	}
 }
+
+func GetPricing(c *gin.Context) {
+	userId := c.GetInt("id")
+	user, _ := model.GetUserById(userId, true)
+	groupRatio := common.GetGroupRatio("default")
+	if user != nil {
+		groupRatio = common.GetGroupRatio(user.Group)
+	}
+	pricing := model.GetPricing(user, openAIModels)
+	c.JSON(200, gin.H{
+		"success":     true,
+		"data":        pricing,
+		"group_ratio": groupRatio,
+	})
+}
@@ -109,6 +109,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
 	if openaiErr.StatusCode == http.StatusBadRequest {
 		return false
 	}
+	if openaiErr.StatusCode == 408 {
+		// do not retry Azure timeouts
+		return false
+	}
 	if openaiErr.LocalError {
 		return false
 	}
@@ -120,7 +124,7 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt

 func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
 	autoBan := c.GetBool("auto_ban")
-	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
+	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
 	if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
 		channelName := c.GetString("channel_name")
 		service.DisableChannel(channelId, channelName, err.Error.Message)
@@ -156,7 +160,7 @@ func RelayMidjourney(c *gin.Context) {
 			"code": err.Code,
 		})
 		channelId := c.GetInt("channel_id")
-		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
+		common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
 	}
 }
@@ -216,7 +216,8 @@ func GetAllUsers(c *gin.Context) {

 func SearchUsers(c *gin.Context) {
 	keyword := c.Query("keyword")
-	users, err := model.SearchUsers(keyword)
+	group := c.Query("group")
+	users, err := model.SearchUsers(keyword, group)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -452,7 +453,7 @@ func UpdateUser(c *gin.Context) {
 		updatedUser.Password = "" // rollback to what it should be
 	}
 	updatePassword := updatedUser.Password != ""
-	if err := updatedUser.Update(updatePassword); err != nil {
+	if err := updatedUser.Edit(updatePassword); err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
@@ -725,7 +726,7 @@ func ManageUser(c *gin.Context) {
 		user.Role = common.RoleCommonUser
 	}

-	if err := user.UpdateAll(false); err != nil {
+	if err := user.Update(false); err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
dto/pricing.go (new file, 37 lines)
@@ -0,0 +1,37 @@
+package dto
+
+type OpenAIModelPermission struct {
+	Id                 string  `json:"id"`
+	Object             string  `json:"object"`
+	Created            int     `json:"created"`
+	AllowCreateEngine  bool    `json:"allow_create_engine"`
+	AllowSampling      bool    `json:"allow_sampling"`
+	AllowLogprobs      bool    `json:"allow_logprobs"`
+	AllowSearchIndices bool    `json:"allow_search_indices"`
+	AllowView          bool    `json:"allow_view"`
+	AllowFineTuning    bool    `json:"allow_fine_tuning"`
+	Organization       string  `json:"organization"`
+	Group              *string `json:"group"`
+	IsBlocking         bool    `json:"is_blocking"`
+}
+
+type OpenAIModels struct {
+	Id         string                  `json:"id"`
+	Object     string                  `json:"object"`
+	Created    int                     `json:"created"`
+	OwnedBy    string                  `json:"owned_by"`
+	Permission []OpenAIModelPermission `json:"permission"`
+	Root       string                  `json:"root"`
+	Parent     *string                 `json:"parent"`
+}
+
+type ModelPricing struct {
+	Available       bool     `json:"available"`
+	ModelName       string   `json:"model_name"`
+	QuotaType       int      `json:"quota_type"`
+	ModelRatio      float64  `json:"model_ratio"`
+	ModelPrice      float64  `json:"model_price"`
+	OwnerBy         string   `json:"owner_by"`
+	CompletionRatio float64  `json:"completion_ratio"`
+	EnableGroup     []string `json:"enable_group,omitempty"`
+}
@@ -32,6 +32,21 @@ type GeneralOpenAIRequest struct {
 	TopLogProbs int `json:"top_logprobs,omitempty"`
 }

+type OpenAITools struct {
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+type OpenAIFunction struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	Parameters  any    `json:"parameters,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
+	return int64(r.MaxTokens)
+}
+
 func (r GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 		return nil
@@ -54,17 +54,33 @@ type OpenAIEmbeddingResponse struct {
 }

 type ChatCompletionsStreamResponseChoice struct {
-	Delta        ChatCompletionsStreamResponseChoiceDelta `json:"delta"`
-	FinishReason *string                                  `json:"finish_reason,omitempty"`
-	Index        int                                      `json:"index,omitempty"`
+	Delta        ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
+	Logprobs     *any                                     `json:"logprobs"`
+	FinishReason *string                                  `json:"finish_reason"`
+	Index        int                                      `json:"index"`
 }

 type ChatCompletionsStreamResponseChoiceDelta struct {
-	Content   string     `json:"content"`
+	Content   *string    `json:"content,omitempty"`
 	Role      string     `json:"role,omitempty"`
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }

+func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
+	return c.Content == nil && len(c.ToolCalls) == 0
+}
+
+func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
+	c.Content = &s
+}
+
+func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
+	if c.Content == nil {
+		return ""
+	}
+	return *c.Content
+}
+
 type ToolCall struct {
 	// Index is not nil only in chat completion chunk object
 	Index *int `json:"index,omitempty"`
@@ -80,11 +96,12 @@ type FunctionCall struct {
 }

 type ChatCompletionsStreamResponse struct {
-	Id      string                                `json:"id"`
-	Object  string                                `json:"object"`
-	Created int64                                 `json:"created"`
-	Model   string                                `json:"model"`
-	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Id                string                                `json:"id"`
+	Object            string                                `json:"object"`
+	Created           int64                                 `json:"created"`
+	Model             string                                `json:"model"`
+	SystemFingerprint *string                               `json:"system_fingerprint"`
+	Choices           []ChatCompletionsStreamResponseChoice `json:"choices"`
 }

 type ChatCompletionsStreamResponseSimple struct {
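Changing `Delta.Content` from `string` to `*string` lets the relay distinguish an explicit empty-string delta (which OpenAI emits in the first chunk of a stream) from a field that should be omitted entirely; that is what `IsEmpty`, `SetContentString`, and `GetContentString` manage. A minimal standalone sketch of the serialization difference:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copy of the delta struct above.
type Delta struct {
	Content *string `json:"content,omitempty"`
	Role    string  `json:"role,omitempty"`
}

func main() {
	empty := ""
	withEmpty, _ := json.Marshal(Delta{Content: &empty, Role: "assistant"})
	without, _ := json.Marshal(Delta{Role: "assistant"})
	fmt.Println(string(withEmpty)) // {"content":"","role":"assistant"}
	fmt.Println(string(without))   // {"role":"assistant"}
}
```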
@@ -64,6 +64,17 @@ func authHelper(c *gin.Context, minRole int) {
 	c.Next()
 }

+func TryUserAuth() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		session := sessions.Default(c)
+		id := session.Get("id")
+		if id != nil {
+			c.Set("id", id)
+		}
+		c.Next()
+	}
+}
+
 func UserAuth() func(c *gin.Context) {
 	return func(c *gin.Context) {
 		authHelper(c, common.RoleCommonUser)
@@ -29,6 +29,13 @@ func GetGroupModels(group string) []string {
 	return models
 }

+func GetEnabledModels() []string {
+	var models []string
+	// Find distinct models
+	DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
+	return models
+}
+
 func getPriority(group string, model string, retry int) (int, error) {
 	groupCol := "`group`"
 	trueVal := "1"
@@ -24,6 +24,7 @@ type Log struct {
 	IsStream  bool   `json:"is_stream" gorm:"default:false"`
 	ChannelId int    `json:"channel" gorm:"index"`
 	TokenId   int    `json:"token_id" gorm:"default:0;index"`
+	Other     string `json:"other"`
 }

 const (
@@ -57,12 +58,13 @@ func RecordLog(userId int, logType int, content string) {
 	}
 }

-func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool) {
+func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 	if !common.LogConsumeEnabled {
 		return
 	}
 	username, _ := CacheGetUsername(userId)
+	otherStr := common.MapToJsonStr(other)
 	log := &Log{
 		UserId:   userId,
 		Username: username,
@@ -78,6 +80,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
 		TokenId:  tokenId,
 		UseTime:  useTimeSeconds,
 		IsStream: isStream,
+		Other:    otherStr,
 	}
 	err := DB.Create(log).Error
 	if err != nil {
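The new `Other` column stores a JSON-encoded map serialized with the `MapToJsonStr` helper added earlier in this diff. A standalone sketch (the keys are assumptions for illustration; the diff does not show what callers actually put in the map):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Same shape as common.MapToJsonStr from this diff.
func mapToJsonStr(m map[string]interface{}) string {
	bytes, err := json.Marshal(m)
	if err != nil {
		return ""
	}
	return string(bytes)
}

func main() {
	// Hypothetical extra billing details for one consume-log row.
	other := map[string]interface{}{"model_ratio": 0.25, "completion_ratio": 3}
	fmt.Println(mapToJsonStr(other)) // value stored in the new Log.Other column
}
```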
@@ -83,6 +83,7 @@ func InitOptionMap() {
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
 	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
+	common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMap["ChatLink"] = common.ChatLink
 	common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -92,6 +93,7 @@ func InitOptionMap() {
 	common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
 	common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
 	common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
 	common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
 	common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
+	common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
 	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
@@ -197,6 +199,8 @@ func updateOptionMap(key string, value string) (err error) {
 		common.DefaultCollapseSidebar = boolValue
 	case "MjNotifyEnabled":
 		constant.MjNotifyEnabled = boolValue
 	case "MjAccountFilterEnabled":
 		constant.MjAccountFilterEnabled = boolValue
 	case "MjModeClearEnabled":
 		constant.MjModeClearEnabled = boolValue
+	case "MjForwardUrlEnabled":
@@ -287,6 +291,8 @@ func updateOptionMap(key string, value string) (err error) {
 		err = common.UpdateModelRatioByJSONString(value)
 	case "GroupRatio":
 		err = common.UpdateGroupRatioByJSONString(value)
+	case "CompletionRatio":
+		err = common.UpdateCompletionRatioByJSONString(value)
 	case "ModelPrice":
 		err = common.UpdateModelPriceByJSONString(value)
 	case "TopUpLink":
model/pricing.go (new file, 72 lines)
@@ -0,0 +1,72 @@
+package model
+
+import (
+	"one-api/common"
+	"one-api/dto"
+	"sync"
+	"time"
+)
+
+var (
+	pricingMap         []dto.ModelPricing
+	lastGetPricingTime time.Time
+	updatePricingLock  sync.Mutex
+)
+
+func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing {
+	updatePricingLock.Lock()
+	defer updatePricingLock.Unlock()
+
+	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
+		updatePricing(openAIModels)
+	}
+	if user != nil {
+		userPricingMap := make([]dto.ModelPricing, 0)
+		models := GetGroupModels(user.Group)
+		for _, pricing := range pricingMap {
+			if !common.StringsContains(models, pricing.ModelName) {
+				pricing.Available = false
+			}
+			userPricingMap = append(userPricingMap, pricing)
+		}
+		return userPricingMap
+	}
+	return pricingMap
+}
+
+func updatePricing(openAIModels []dto.OpenAIModels) {
+	modelRatios := common.GetModelRatios()
+	enabledModels := GetEnabledModels()
+	allModels := make(map[string]string)
+	for _, openAIModel := range openAIModels {
+		if common.StringsContains(enabledModels, openAIModel.Id) {
+			allModels[openAIModel.Id] = openAIModel.OwnedBy
+		}
+	}
+	for model, _ := range modelRatios {
+		if common.StringsContains(enabledModels, model) {
+			if _, ok := allModels[model]; !ok {
+				allModels[model] = "custom"
+			}
+		}
+	}
+	pricingMap = make([]dto.ModelPricing, 0)
+	for model, ownerBy := range allModels {
+		pricing := dto.ModelPricing{
+			Available: true,
+			ModelName: model,
+			OwnerBy:   ownerBy,
+		}
+		modelPrice, findPrice := common.GetModelPrice(model, false)
+		if findPrice {
+			pricing.ModelPrice = modelPrice
+			pricing.QuotaType = 1
+		} else {
+			pricing.ModelRatio = common.GetModelRatio(model)
+			pricing.CompletionRatio = common.GetCompletionRatio(model)
+			pricing.QuotaType = 0
+		}
+		pricingMap = append(pricingMap, pricing)
+	}
+	lastGetPricingTime = time.Now()
+}
@@ -45,6 +45,7 @@ func logQuotaDataCache(userId int, username string, modelName string, quota int,
 	if ok {
 		quotaData.Count += 1
 		quotaData.Quota += quota
+		quotaData.TokenUsed += tokenUsed
 	} else {
 		quotaData = &QuotaData{
 			UserID: userId,
@@ -73,25 +73,34 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
 	return users, err
 }

-func SearchUsers(keyword string) ([]*User, error) {
+func SearchUsers(keyword string, group string) ([]*User, error) {
 	var users []*User
 	var err error

 	// try to parse the keyword as an integer ID
 	keywordInt, err := strconv.Atoi(keyword)
 	if err == nil {
-		// on success, search users by ID
-		err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
+		// on success, search users by ID and the optional group
+		query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
+		if group != "" {
+			query = query.Where("`group` = ?", group) // backquote group: it is a reserved word
+		}
+		err = query.Find(&users).Error
 		if err != nil || len(users) > 0 {
 			// return immediately if a user was found by ID or an error occurred
 			return users, err
 		}
 	}

 	// if ID parsing failed or nothing was found, fall back to a fuzzy search over the other fields
-	err = DB.Unscoped().Omit("password").
-		Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
-		Find(&users).Error
+	err = nil
+
+	query := DB.Unscoped().Omit("password")
+	likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
+	if group != "" {
+		query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
+	} else {
+		query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
+	}
+	err = query.Find(&users).Error

 	return users, err
 }
@@ -235,7 +244,7 @@ func (user *User) Update(updatePassword bool) error {
 	return err
 }

-func (user *User) UpdateAll(updatePassword bool) error {
+func (user *User) Edit(updatePassword bool) error {
 	var err error
 	if updatePassword {
 		user.Password, err = common.Password2Hash(user.Password)
@@ -244,8 +253,17 @@ func (user *User) UpdateAll(updatePassword bool) error {
 		}
 	}
 	newUser := *user
+	updates := map[string]interface{}{
+		"username":     newUser.Username,
+		"display_name": newUser.DisplayName,
+		"group":        newUser.Group,
+		"quota":        newUser.Quota,
+	}
+	if updatePassword {
+		updates["password"] = newUser.Password
+	}
 	DB.First(&user, user.Id)
-	err = DB.Model(user).Select("*").Updates(newUser).Error
+	err = DB.Model(user).Updates(updates).Error
 	if err == nil {
 		if common.RedisEnabled {
 			_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
@@ -6,3 +6,5 @@ var ModelList = []string{
 	"embedding_s1_v1",
 	"semantic_similarity_s1_v1",
 }
+
+var ChannelName = "ai360"
@@ -136,7 +136,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {

 func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = aliResponse.Output.Text
+	choice.Delta.SetContentString(aliResponse.Output.Text)
 	if aliResponse.Output.FinishReason != "null" {
 		finishReason := aliResponse.Output.FinishReason
 		choice.FinishReason = &finishReason
@@ -199,7 +199,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
 			usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 		}
 		response := streamResponseAli2OpenAI(&aliResponse)
-		response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
+		response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
 		lastResponseText = aliResponse.Output.Text
 		jsonResponse, err := json.Marshal(response)
 		if err != nil {
@@ -5,6 +5,7 @@ import "one-api/relay/channel/claude"

 type AwsClaudeRequest struct {
 	// AnthropicVersion should be "bedrock-2023-05-31"
 	AnthropicVersion string                 `json:"anthropic_version"`
 	System           string                 `json:"system"`
 	Messages         []claude.ClaudeMessage `json:"messages"`
 	MaxTokens        int                    `json:"max_tokens,omitempty"`
 	Temperature      float64                `json:"temperature,omitempty"`
@@ -156,6 +156,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
 	var usage relaymodel.Usage
 	var id string
 	var model string
+	createdTime := common.GetTimestamp()
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
@@ -188,6 +189,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
 			if response.Model != "" {
 				model = response.Model
 			}
+			response.Created = createdTime
 			response.Id = id
 			response.Model = model
@@ -57,7 +57,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {

 func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = baiduResponse.Result
+	choice.Delta.SetContentString(baiduResponse.Result)
 	if baiduResponse.IsEnd {
 		choice.FinishReason = &relaycommon.StopFinishReason
 	}
@@ -24,15 +24,16 @@ type ClaudeMessage struct {
 }

 type ClaudeRequest struct {
-	Model         string          `json:"model"`
-	Prompt        string          `json:"prompt,omitempty"`
-	System        string          `json:"system,omitempty"`
-	Messages      []ClaudeMessage `json:"messages,omitempty"`
-	MaxTokens     uint            `json:"max_tokens,omitempty"`
-	StopSequences []string        `json:"stop_sequences,omitempty"`
-	Temperature   float64         `json:"temperature,omitempty"`
-	TopP          float64         `json:"top_p,omitempty"`
-	TopK          int             `json:"top_k,omitempty"`
+	Model             string          `json:"model"`
+	Prompt            string          `json:"prompt,omitempty"`
+	System            string          `json:"system,omitempty"`
+	Messages          []ClaudeMessage `json:"messages,omitempty"`
+	MaxTokens         uint            `json:"max_tokens,omitempty"`
+	MaxTokensToSample uint            `json:"max_tokens_to_sample,omitempty"`
+	StopSequences     []string        `json:"stop_sequences,omitempty"`
+	Temperature       float64         `json:"temperature,omitempty"`
+	TopP              float64         `json:"top_p,omitempty"`
+	TopK              int             `json:"top_k,omitempty"`
 	//ClaudeMetadata    `json:"metadata,omitempty"`
 	Stream bool `json:"stream,omitempty"`
 }
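Anthropic's legacy text-completions endpoint expects `max_tokens_to_sample`, while the newer messages endpoint uses `max_tokens`; keeping both fields with `omitempty` lets one struct serve both request shapes. A standalone sketch showing how only the populated field is serialized:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copy of the ClaudeRequest fields relevant here.
type ClaudeRequest struct {
	Model             string `json:"model"`
	Prompt            string `json:"prompt,omitempty"`
	MaxTokens         uint   `json:"max_tokens,omitempty"`
	MaxTokensToSample uint   `json:"max_tokens_to_sample,omitempty"`
}

func main() {
	// Legacy completion request: only max_tokens_to_sample appears in the JSON.
	legacy, _ := json.Marshal(ClaudeRequest{
		Model:             "claude-2.1",
		Prompt:            "\n\nHuman: hi\n\nAssistant:",
		MaxTokensToSample: 4096,
	})
	fmt.Println(string(legacy))
}
```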
@@ -20,7 +20,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 	case "end_turn":
 		return "stop"
 	case "max_tokens":
-		return "length"
+		return "max_tokens"
 	default:
 		return reason
 	}
@@ -30,15 +30,14 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
 	claudeRequest := ClaudeRequest{
 		Model:         textRequest.Model,
 		Prompt:        "",
-		MaxTokens:     textRequest.MaxTokens,
 		StopSequences: nil,
 		Temperature:   textRequest.Temperature,
 		TopP:          textRequest.TopP,
 		TopK:          textRequest.TopK,
 		Stream:        textRequest.Stream,
 	}
-	if claudeRequest.MaxTokens == 0 {
-		claudeRequest.MaxTokens = 4096
+	if claudeRequest.MaxTokensToSample == 0 {
+		claudeRequest.MaxTokensToSample = 4096
 	}
 	prompt := ""
 	for _, message := range textRequest.Messages {
@@ -73,13 +72,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 	formatMessages := make([]dto.Message, 0)
 	var lastMessage *dto.Message
 	for i, message := range textRequest.Messages {
-		if message.Role == "system" {
-			if i != 0 {
-				message.Role = "user"
-			}
-		}
+		//if message.Role == "system" {
+		//	if i != 0 {
+		//		message.Role = "user"
+		//	}
+		//}
 		if message.Role == "" {
-			message.Role = "user"
+			textRequest.Messages[i].Role = "user"
 		}
 		fmtMessage := dto.Message{
 			Role: message.Role,
@@ -98,13 +97,24 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			fmtMessage.Content = content
 		}
 		formatMessages = append(formatMessages, fmtMessage)
-		lastMessage = &message
+		lastMessage = &textRequest.Messages[i]
 	}

 	claudeMessages := make([]ClaudeMessage, 0)
 	for _, message := range formatMessages {
 		if message.Role == "system" {
-			claudeRequest.System = message.StringContent()
+			if message.IsStringContent() {
+				claudeRequest.System = message.StringContent()
+			} else {
+				contents := message.ParseContent()
+				content := ""
+				for _, ctx := range contents {
+					if ctx.Type == "text" {
+						content += ctx.Text
+					}
+				}
+				claudeRequest.System = content
+			}
 		} else {
 			claudeMessage := ClaudeMessage{
 				Role: message.Role,
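The `lastMessage = &textRequest.Messages[i]` change above fixes a loop-variable aliasing bug: taking the address of the `range` variable makes the pointer track whatever element was last copied into that single variable, so any later `lastMessage.Role == message.Role` style comparison always sees the current message. A standalone sketch of the difference (pre-Go 1.22 loop semantics):

```go
package main

import "fmt"

type Message struct{ Role string }

func main() {
	msgs := []Message{{"user"}, {"assistant"}}
	var lastMessage *Message
	for i, message := range msgs {
		if lastMessage != nil {
			// With the buggy form (lastMessage = &message), this comparison
			// always reports true, because lastMessage points at the loop
			// variable that was just overwritten with the current element.
			fmt.Println("same role as previous:", lastMessage.Role == message.Role)
		}
		lastMessage = &msgs[i] // the fix: address the slice element itself
	}
}
```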
@@ -149,7 +159,6 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 	}
-	claudeRequest.Prompt = ""
 	claudeRequest.Messages = claudeMessages

 	return &claudeRequest, nil
 }
@@ -161,7 +170,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
 	var choice dto.ChatCompletionsStreamResponseChoice
 	if reqMode == RequestModeCompletion {
-		choice.Delta.Content = claudeResponse.Completion
+		choice.Delta.SetContentString(claudeResponse.Completion)
 		finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 		if finishReason != "null" {
 			choice.FinishReason = &finishReason
@@ -171,9 +180,13 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 			response.Id = claudeResponse.Message.Id
 			response.Model = claudeResponse.Message.Model
 			claudeUsage = &claudeResponse.Message.Usage
+			choice.Delta.SetContentString("")
+			choice.Delta.Role = "assistant"
+		} else if claudeResponse.Type == "content_block_start" {
+			return nil, nil
 		} else if claudeResponse.Type == "content_block_delta" {
 			choice.Index = claudeResponse.Index
-			choice.Delta.Content = claudeResponse.Delta.Text
+			choice.Delta.SetContentString(claudeResponse.Delta.Text)
 		} else if claudeResponse.Type == "message_delta" {
 			finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
 			if finishReason != "null" {
@@ -182,12 +195,15 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 			claudeUsage = &claudeResponse.Usage
 		} else if claudeResponse.Type == "message_stop" {
 			return nil, nil
 		} else {
 			return nil, nil
 		}
 	}
+	if claudeUsage == nil {
+		claudeUsage = &ClaudeUsage{}
+	}
 	response.Choices = append(response.Choices, choice)

 	return &response, claudeUsage
 }
relay/channel/cohere/adaptor.go (new file, 52 lines)
@@ -0,0 +1,52 @@
+package cohere
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	return requestOpenAI2Cohere(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
relay/channel/cohere/constant.go (new file, 7 lines)
@@ -0,0 +1,7 @@
+package cohere
+
+var ModelList = []string{
+	"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
+}
+
+var ChannelName = "cohere"
relay/channel/cohere/dto.go (new file, 44 lines)
@@ -0,0 +1,44 @@
+package cohere
+
+type CohereRequest struct {
+	Model       string        `json:"model"`
+	ChatHistory []ChatHistory `json:"chat_history"`
+	Message     string        `json:"message"`
+	Stream      bool          `json:"stream"`
+	MaxTokens   int64         `json:"max_tokens"`
+}
+
+type ChatHistory struct {
+	Role    string `json:"role"`
+	Message string `json:"message"`
+}
+
+type CohereResponse struct {
+	IsFinished   bool                  `json:"is_finished"`
+	EventType    string                `json:"event_type"`
+	Text         string                `json:"text,omitempty"`
+	FinishReason string                `json:"finish_reason,omitempty"`
+	Response     *CohereResponseResult `json:"response"`
+}
+
+type CohereResponseResult struct {
+	ResponseId   string     `json:"response_id"`
+	FinishReason string     `json:"finish_reason,omitempty"`
+	Text         string     `json:"text"`
+	Meta         CohereMeta `json:"meta"`
+}
+
+type CohereMeta struct {
+	//Tokens CohereTokens `json:"tokens"`
+	BilledUnits CohereBilledUnits `json:"billed_units"`
+}
+
+type CohereBilledUnits struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+}
+
+type CohereTokens struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+}
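The conversion in relay-cohere.go below maps OpenAI chat messages onto these structs: earlier turns become `chat_history` entries with roles `USER`/`CHATBOT`/`SYSTEM`, and the latest user message becomes the top-level `message`. A standalone sketch of the resulting request body (model name and texts are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type ChatHistory struct {
	Role    string `json:"role"`
	Message string `json:"message"`
}

type CohereRequest struct {
	Model       string        `json:"model"`
	ChatHistory []ChatHistory `json:"chat_history"`
	Message     string        `json:"message"`
	Stream      bool          `json:"stream"`
	MaxTokens   int64         `json:"max_tokens"`
}

func main() {
	req := CohereRequest{
		Model: "command-r",
		ChatHistory: []ChatHistory{
			{Role: "SYSTEM", Message: "You are terse."},
			{Role: "USER", Message: "Hi"},
			{Role: "CHATBOT", Message: "Hello!"},
		},
		Message:   "What is 2+2?", // the latest user turn
		Stream:    false,
		MaxTokens: 4000, // default the adaptor applies when the client sends none
	}
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out))
}
```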
relay/channel/cohere/relay-cohere.go (new file, 189 lines)
@@ -0,0 +1,189 @@
+package cohere
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/service"
+	"strings"
+)
+
+func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
+	cohereReq := CohereRequest{
+		Model:       textRequest.Model,
+		ChatHistory: []ChatHistory{},
+		Message:     "",
+		Stream:      textRequest.Stream,
+		MaxTokens:   textRequest.GetMaxTokens(),
+	}
+	if cohereReq.MaxTokens == 0 {
+		cohereReq.MaxTokens = 4000
+	}
+	for _, msg := range textRequest.Messages {
+		if msg.Role == "user" {
+			cohereReq.Message = msg.StringContent()
+		} else {
+			var role string
+			if msg.Role == "assistant" {
+				role = "CHATBOT"
+			} else if msg.Role == "system" {
+				role = "SYSTEM"
+			} else {
+				role = "USER"
+			}
+			cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
+				Role:    role,
+				Message: msg.StringContent(),
+			})
+		}
+	}
+	return &cohereReq
+}
+
+func stopReasonCohere2OpenAI(reason string) string {
+	switch reason {
+	case "COMPLETE":
+		return "stop"
+	case "MAX_TOKENS":
+		return "max_tokens"
+	default:
+		return reason
+	}
+}
+
+func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createdTime := common.GetTimestamp()
+	usage := &dto.Usage{}
+	responseText := ""
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	service.SetEventStreamHeaders(c)
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			data = strings.TrimSuffix(data, "\r")
+			var cohereResp CohereResponse
+			err := json.Unmarshal([]byte(data), &cohereResp)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			var openaiResp dto.ChatCompletionsStreamResponse
+			openaiResp.Id = responseId
+			openaiResp.Created = createdTime
+			openaiResp.Object = "chat.completion.chunk"
+			openaiResp.Model = modelName
+			if cohereResp.IsFinished {
+				finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
+				openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
+					{
+						Delta:        dto.ChatCompletionsStreamResponseChoiceDelta{},
+						Index:        0,
+						FinishReason: &finishReason,
+					},
+				}
+				if cohereResp.Response != nil {
+					usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
+					usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
+				}
+			} else {
+				openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
+					{
+						Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+							Role:    "assistant",
+							Content: &cohereResp.Text,
+						},
+						Index: 0,
+					},
+				}
+				responseText += cohereResp.Text
+			}
+			jsonStr, err := json.Marshal(openaiResp)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	if usage.PromptTokens == 0 {
+		usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
+	}
+	return nil, usage
+}
+
+func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	createdTime := common.GetTimestamp()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var cohereResp CohereResponseResult
+	err = json.Unmarshal(responseBody, &cohereResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	usage := dto.Usage{}
+	usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+	usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+	usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+
+	var openaiResp dto.TextResponse
+	openaiResp.Id = cohereResp.ResponseId
+	openaiResp.Created = createdTime
+	openaiResp.Object = "chat.completion"
+	openaiResp.Model = modelName
+	openaiResp.Usage = usage
+
+	content, _ := json.Marshal(cohereResp.Text)
+	openaiResp.Choices = []dto.OpenAITextResponseChoice{
+		{
+			Index:        0,
+			Message:      dto.Message{Content: content, Role: "assistant"},
+			FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
+		},
+	}
+
+	jsonResponse, err := json.Marshal(openaiResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
@@ -151,7 +151,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp

 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = geminiResponse.GetResponseText()
+	choice.Delta.SetContentString(geminiResponse.GetResponseText())
 	choice.FinishReason = &relaycommon.StopFinishReason
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
@@ -203,7 +203,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
 			err := json.Unmarshal([]byte(data), &dummy)
 			responseText += dummy.Content
 			var choice dto.ChatCompletionsStreamResponseChoice
-			choice.Delta.Content = dummy.Content
+			choice.Delta.SetContentString(dummy.Content)
 			response := dto.ChatCompletionsStreamResponse{
 				Id:     fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 				Object: "chat.completion.chunk",
@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
@@ -1,5 +1,7 @@
 package ollama

-var ModelList []string
+var ModelList = []string{
+	"llama3-7b",
+}

 var ChannelName = "ollama"
@@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -6,7 +6,7 @@ var ModelList = []string{
 	"gpt-3.5-turbo-instruct",
 	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
-	"gpt-4-turbo-preview",
+	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
 	"gpt-4-vision-preview",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
@@ -16,9 +16,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
var responseTextBuilder strings.Builder
|
||||
toolCount := 0
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -49,7 +50,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
common.SafeSendString(dataChan, data)
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
streamItems = append(streamItems, data)
|
||||
@@ -67,8 +68,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
@@ -80,8 +84,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
@@ -116,7 +123,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
// wait data out
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
common.SafeSend(stopChan, true)
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
@@ -135,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
|
||||
}
|
||||
wg.Wait()
|
||||
return nil, responseTextBuilder.String()
|
||||
return nil, responseTextBuilder.String(), toolCount
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
|
||||
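OpenaiStreamHandler now returns a third value, the number of tool calls observed in the stream, and every streaming adapter pads CompletionTokens by a flat 7 tokens per call. The 7 looks like a fixed envelope estimate for a call's name and framing rather than an exact tokenizer count; a minimal, self-contained sketch of the accounting (names here are illustrative, not the project's):

    package main

    import "fmt"

    // padToolTokens mirrors the adaptor hunks above: completion tokens are
    // estimated from the streamed text, then each tool call adds a flat
    // 7-token envelope on top.
    func padToolTokens(completionTokens, toolCount int) int {
    	return completionTokens + toolCount*7
    }

    func main() {
    	fmt.Println(padToolTokens(120, 2)) // prints 134
    }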
@@ -61,7 +61,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
}
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse

@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

@@ -86,7 +86,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
}
if len(TencentResponse.Choices) > 0 {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &relaycommon.StopFinishReason
}

@@ -138,7 +138,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
responseText += response.Choices[0].Delta.GetContentString()
}
jsonResponse, err := json.Marshal(response)
if err != nil {

@@ -87,7 +87,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo
}
}
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &relaycommon.StopFinishReason
}

@@ -126,7 +126,7 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {

func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
choice.Delta.SetContentString(zhipuResponse)
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),

@@ -138,7 +138,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStream

func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.Delta.SetContentString("")
choice.FinishReason = &relaycommon.StopFinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,

@@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
@@ -38,7 +38,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()

apiType := constant.ChannelType2APIType(channelType)
apiType, _ := constant.ChannelType2APIType(channelType)

info := &RelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),

@@ -19,13 +19,16 @@ const (
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere

APITypeDummy // this one is only for count, do not add any channel after this
)

func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
func ChannelType2APIType(channelType int) (int, bool) {
apiType := -1
switch channelType {
case common.ChannelTypeOpenAI:
apiType = APITypeOpenAI
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:

@@ -52,6 +55,11 @@ func ChannelType2APIType(channelType int) int {
apiType = APITypePerplexity
case common.ChannelTypeAws:
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
}
return apiType
if apiType == -1 {
return APITypeOpenAI, false
}
return apiType, true
}
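ChannelType2APIType now follows Go's comma-ok convention: the boolean reports whether the channel type was recognized, and unknown types still come back as APITypeOpenAI so existing callers keep working. GenRelayInfo above simply discards the flag; a caller that wants to log the fallback could look like this sketch (stand-in names, not project code):

    package main

    import "log"

    const apiTypeOpenAI = 0

    // channelType2APIType is a stand-in for the function in the hunk above.
    func channelType2APIType(channelType int) (int, bool) {
    	switch channelType {
    	case 1:
    		return apiTypeOpenAI, true
    	}
    	// Unknown: fall back to OpenAI but say so.
    	return apiTypeOpenAI, false
    }

    func main() {
    	if apiType, ok := channelType2APIType(99); !ok {
    		log.Printf("unknown channel type, defaulting to apiType=%d", apiType)
    	}
    }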
@@ -20,15 +20,6 @@ import (
"time"
)

var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}

func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")

@@ -59,9 +50,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if audioRequest.Voice == "" {
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
var err error
promptTokens := 0

@@ -100,6 +88,22 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}

succeed := false
defer func() {
if succeed {
return
}
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func() {
go func() {
// negative means add quota back for token & user
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
}()
}()
}
}()

// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {

@@ -163,6 +167,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK {
return relaycommon.RelayErrorHandler(resp)
}
succeed = true

var audioResponse dto.AudioResponse
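The new succeed flag collapses pre-consumed-quota rollback into one deferred cleanup: any return path that never reaches succeed = true refunds the pre-charge, and the success path skips the refund. A compact sketch of the idiom (the real code dispatches the refund on a goroutine; this version calls it inline so the example is deterministic):

    package main

    import "fmt"

    func returnPreConsumedQuota(quota int) {
    	fmt.Println("refunding", quota)
    }

    func relay(preConsumedQuota int, upstreamOK bool) error {
    	succeed := false
    	defer func() {
    		// Every early return before succeed = true lands here.
    		if succeed || preConsumedQuota <= 0 {
    			return
    		}
    		returnPreConsumedQuota(preConsumedQuota)
    	}()
    	if !upstreamOK {
    		return fmt.Errorf("upstream failed")
    	}
    	succeed = true
    	return nil
    }

    func main() {
    	_ = relay(500, false) // prints "refunding 500"
    }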
@@ -191,7 +196,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

@@ -34,7 +34,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}

if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
imageRequest.Model = "dall-e-3"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"

@@ -106,21 +106,26 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
requestBody = c.Request.Body
}

modelRatio := common.GetModelRatio(imageRequest.Model)
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
modelRatio := common.GetModelRatio(imageRequest.Model)
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
}
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)

sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
sizeRatio = 0.4
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
sizeRatio = 0.45
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
sizeRatio = 1
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2.5
sizeRatio = 2
}

qualityRatio := 1.0

@@ -131,7 +136,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
}

quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N

if userQuota-quota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
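The image quota is now anchored to a dollar price instead of a raw ratio: quota = int(modelPrice * groupRatio * common.QuotaPerUnit * sizeRatio * qualityRatio) * N. Assuming the conventional QuotaPerUnit of 500,000 (one dollar of quota), a model priced via the 0.0025 * modelRatio fallback with modelRatio 16 gives modelPrice 0.04, so a single standard-quality 1024x1024 image costs int(0.04 * 1 * 500000 * 1 * 1) = 20000 quota per image; larger sizes scale by the sizeRatio values above, and hd requests are scaled further by the qualityRatio set just before this hunk (its hd value is not shown in this excerpt).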
@@ -186,8 +191,15 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
quality := "normal"
if imageRequest.Quality == "hd" {
quality = "hd"
}
logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

@@ -155,9 +155,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice := common.GetModelPrice(modelName, true)
modelPrice, success := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
if !success {
defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1

@@ -202,7 +202,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

@@ -451,9 +454,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice := common.GetModelPrice(modelName, true)
modelPrice, success := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
if !success {
defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1

@@ -498,7 +501,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
}
relayInfo.UpstreamModelName = textRequest.Model
modelPrice := common.GetModelPrice(textRequest.Model, false)
modelPrice, success := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)

var preConsumedQuota int

@@ -108,7 +108,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}

if modelPrice == -1 {
if !success {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)

@@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
return nil
}

@@ -189,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
case relayconstant.RelayModeCompletions:
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
case relayconstant.RelayModeModerations:

@@ -257,19 +257,19 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu

func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64) {
modelPrice float64, usePrice bool) {

useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens

tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(textRequest.Model)

quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if !usePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
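Two quiet behavior fixes ride along here: int(...) truncation becomes int(math.Round(...)), and a non-zero ratio can no longer bill zero. Worked through with promptTokens 10, completionTokens 1, completionRatio 0.5 and ratio 0.25: the old path computed 10 + int(0.5) = 10 and then int(10 * 0.25) = 2, while the rounded path computes 10 + round(0.5) = 11 and then round(11 * 0.25) = 3; and if rounding ever produced 0 with a non-zero ratio, the guard lifts the charge to a minimum quota of 1.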
@@ -279,7 +279,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
totalTokens := promptTokens + completionTokens
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}

@@ -315,7 +315,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream)
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["model_price"] = modelPrice
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)

//if quota != 0 {
//

@@ -6,6 +6,7 @@ import (
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/cohere"
"one-api/relay/channel/gemini"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"

@@ -48,6 +49,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &perplexity.Adaptor{}
case constant.APITypeAws:
return &aws.Adaptor{}
case constant.APITypeCohere:
return &cohere.Adaptor{}
}
return nil
}
@@ -14,11 +14,13 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(middleware.GlobalAPIRateLimit())
{
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
//apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/pricing", middleware.CriticalRateLimit(), middleware.TryUserAuth(), controller.GetPricing)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)

@@ -63,7 +63,7 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
return false
}

func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool {
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}

@@ -73,5 +73,8 @@ func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool {
if openAIErr != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {
return false
}
return true
}
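The extra status argument narrows automatic re-enabling to channels the system itself disabled: only a channel currently in ChannelStatusAutoDisabled can be switched back on by a successful test, so a channel an operator disabled by hand stays off regardless of test results.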
@@ -165,7 +165,9 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
if !constant.MjAccountFilterEnabled {
delete(mapResult, "accountFilter")
}
if !constant.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}

@@ -174,11 +176,11 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
}
if constant.MjModeClearEnabled {
if prompt, ok := mapResult["prompt"].(string); ok {
prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1)
prompt = strings.Replace(prompt, "--turbo", "", -1)

mapResult["prompt"] = prompt
prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1)
prompt = strings.Replace(prompt, "--turbo", "", -1)

mapResult["prompt"] = prompt
}
}
reqBody, err := json.Marshal(mapResult)
@@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}

func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
tkm := 0
msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
if err != nil {
return 0, err, b
}
tkm += msgTokens
if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
}
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err, _ := CountTokenInput(countStr, model, false)
if err != nil {
return 0, err, false
}
tkm += 8
tkm += toolTokens
}

return tkm, nil, false
}

func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
//recover when panic
tokenEncoder := getTokenEncoder(model)

@@ -138,48 +173,31 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
var arrayContent []dto.MediaMessage
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err, false
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
if message.IsStringContent() {
stringContent := message.StringContent()
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
var err error
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if model == "glm-4v" {
imageTokenNum = 1047
} else {
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := dto.MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err = getImageToken(&imageUrl)
if err != nil {
return 0, err, false
}

@@ -214,7 +232,7 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
tkm, _, _ := CountTokenInput(message.Delta.Content, model, false)
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
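CountTokenChatRequest layers tool accounting on top of the message count: once any tools are present it adds a flat 8 tokens plus the tokenized concatenation of each tool's name, description and parameters. Note that countStr is assigned (not appended) at the top of the loop, so with several tools only the last one's text survives into the count; that reads like an oversight in the committed code. A reduced sketch of the same flow, with the tiktoken-based counter swapped for a whitespace count so the numbers are illustrative only:

    package main

    import (
    	"fmt"
    	"strings"
    )

    type toolFunc struct{ Name, Description, Parameters string }

    // countTokens stands in for getTokenNum/CountTokenInput above.
    func countTokens(s string) int { return len(strings.Fields(s)) }

    func countToolTokens(tools []toolFunc) int {
    	if len(tools) == 0 {
    		return 0
    	}
    	countStr := ""
    	for _, t := range tools {
    		countStr = t.Name // assignment, mirroring the hunk above
    		if t.Description != "" {
    			countStr += t.Description
    		}
    		if t.Parameters != "" {
    			countStr += t.Parameters
    		}
    	}
    	return 8 + countTokens(countStr)
    }

    func main() {
    	fmt.Println(countToolTokens([]toolFunc{{Name: "get_weather", Description: "look up current weather"}}))
    }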
@@ -22,6 +22,7 @@ import Log from './pages/Log';
import Chat from './pages/Chat';
import { Layout } from '@douyinfe/semi-ui';
import Midjourney from './pages/Midjourney';
import Pricing from './pages/Pricing/index.js';
// import Detail from './pages/Detail';

const Home = lazy(() => import('./pages/Home'));

@@ -219,6 +220,14 @@ function App() {
</PrivateRoute>
}
/>
<Route
path='/pricing'
element={
<Suspense fallback={<Loading></Loading>}>
<Pricing />
</Suspense>
}
/>
<Route
path='/about'
element={

@@ -31,6 +31,7 @@ import {
} from '@douyinfe/semi-ui';
import EditChannel from '../pages/Channel/EditChannel';
import { IconTreeTriangleDown } from '@douyinfe/semi-icons';
import { loadChannelModels } from './utils.js';

function renderTimestamp(timestamp) {
return <>{timestamp2string(timestamp)}</>;

@@ -354,27 +355,29 @@ const ChannelsTable = () => {
};

const copySelectedChannel = async (id) => {
const channelToCopy = channels.find(channel => String(channel.id) === String(id));
console.log(channelToCopy)
const channelToCopy = channels.find(
(channel) => String(channel.id) === String(id),
);
console.log(channelToCopy);
channelToCopy.name += '_复制';
channelToCopy.created_time = null;
channelToCopy.balance = 0;
channelToCopy.used_quota = 0;
if (!channelToCopy) {
showError("渠道未找到,请刷新页面后重试。");
return;
showError('渠道未找到,请刷新页面后重试。');
return;
}
try {
const newChannel = {...channelToCopy, id: undefined};
const response = await API.post('/api/channel/', newChannel);
if (response.data.success) {
showSuccess("渠道复制成功");
await refresh();
} else {
showError(response.data.message);
}
const newChannel = { ...channelToCopy, id: undefined };
const response = await API.post('/api/channel/', newChannel);
if (response.data.success) {
showSuccess('渠道复制成功');
await refresh();
} else {
showError(response.data.message);
}
} catch (error) {
showError("渠道复制失败: " + error.message);
showError('渠道复制失败: ' + error.message);
}
};

@@ -395,6 +398,7 @@ const ChannelsTable = () => {
showError(reason);
});
fetchGroups().then();
loadChannelModels().then();
}, []);

const manageChannel = async (id, action, record, value) => {
@@ -19,6 +19,7 @@ import TelegramLoginButton from 'react-telegram-login';

import { IconGithubLogo } from '@douyinfe/semi-icons';
import WeChatIcon from './WeChatIcon';
import { setUserData } from '../helpers/data.js';

const LoginForm = () => {
const [inputs, setInputs] = useState({

@@ -99,7 +100,7 @@ const LoginForm = () => {
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
showSuccess('登录成功!');
if (username === 'root' && password === '123456') {
Modal.error({

@@ -19,9 +19,15 @@ import {
Spin,
Table,
Tag,
Tooltip,
} from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
import { renderNumber, renderQuota, stringToColor } from '../helpers/render';
import {
renderModelPrice,
renderNumber,
renderQuota,
stringToColor,
} from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';

const { Header } = Layout;

@@ -292,16 +298,42 @@ const LogsTable = () => {
title: '详情',
dataIndex: 'content',
render: (text, record, index) => {
if (record.other === '') {
return (
<Paragraph
ellipsis={{
rows: 2,
showTooltip: {
type: 'popover',
opts: { style: { width: 240 } },
},
}}
style={{ maxWidth: 240 }}
>
{text}
</Paragraph>
);
}
let other = JSON.parse(record.other);
let content = renderModelPrice(
record.prompt_tokens,
record.completion_tokens,
other.model_ratio,
other.model_price,
other.completion_ratio,
other.group_ratio,
);
return (
<Paragraph
ellipsis={{
rows: 2,
showTooltip: { type: 'popover', opts: { style: { width: 240 } } },
}}
style={{ maxWidth: 240 }}
>
{text}
</Paragraph>
<Tooltip content={content}>
<Paragraph
ellipsis={{
rows: 2,
}}
style={{ maxWidth: 240 }}
>
{text}
</Paragraph>
</Tooltip>
);
},
},
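The details column now branches on record.other: an empty string means an old-format log and renders as before, while anything else is parsed as JSON and fed to renderModelPrice so the tooltip can show the exact price breakdown. This pairs with the backend hunks above that pass a map of ratios into RecordConsumeLog; presumably the model layer serializes that map into the log row's other column, though that change is not part of this excerpt.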
@@ -236,6 +236,31 @@ const renderTimestamp = (timestampInSeconds) => {

return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
};
// 修改renderDuration函数以包含颜色逻辑
function renderDuration(submit_time, finishTime) {
// 确保startTime和finishTime都是有效的时间戳
if (!submit_time || !finishTime) return 'N/A';

// 将时间戳转换为Date对象
const start = new Date(submit_time);
const finish = new Date(finishTime);

// 计算时间差(毫秒)
const durationMs = finish - start;

// 将时间差转换为秒,并保留一位小数
const durationSec = (durationMs / 1000).toFixed(1);

// 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
const color = durationSec > 60 ? 'red' : 'green';

// 返回带有样式的颜色标签
return (
<Tag color={color} size="large">
{durationSec} 秒
</Tag>
);
}

const LogsTable = () => {
const [isModalOpen, setIsModalOpen] = useState(false);

@@ -248,6 +273,15 @@ const LogsTable = () => {
return <div>{renderTimestamp(text / 1000)}</div>;
},
},
{
title: '花费时间',
dataIndex: 'finish_time', // 以finish_time作为dataIndex
key: 'finish_time',
render: (finish, record) => {
// 假设record.start_time是存在的,并且finish是完成时间的时间戳
return renderDuration(record.submit_time, finish);
},
},
{
title: '渠道',
dataIndex: 'channel_id',
web/src/components/ModelPricing.js (new file, 229 lines)
@@ -0,0 +1,229 @@
import React, { useContext, useEffect, useState } from 'react';
import { API, copy, showError, showSuccess } from '../helpers';

import { Banner, Layout, Modal, Table, Tag, Tooltip } from '@douyinfe/semi-ui';
import { stringToColor } from '../helpers/render.js';
import { UserContext } from '../context/User/index.js';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';

function renderQuotaType(type) {
// Ensure all cases are string literals by adding quotes.
switch (type) {
case 1:
return (
<Tag color='green' size='large'>
按次计费
</Tag>
);
case 0:
return (
<Tag color='blue' size='large'>
按量计费
</Tag>
);
default:
return (
<Tag color='white' size='large'>
未知
</Tag>
);
}
}

function renderAvailable(available) {
return available ? (
<Tag color='green' size='large'>
可用
</Tag>
) : (
<Tooltip content='您所在的分组不可用'>
<Tag color='red' size='large'>
不可用
</Tag>
</Tooltip>
);
}

const ModelPricing = () => {
const columns = [
{
title: '可用性',
dataIndex: 'available',
render: (text, record, index) => {
return renderAvailable(text);
},
},
{
title: '提供者',
dataIndex: 'owner_by',
render: (text, record, index) => {
return (
<>
<Tag color={stringToColor(text)} size='large'>
{text}
</Tag>
</>
);
},
},
{
title: '模型名称',
dataIndex: 'model_name', // 以finish_time作为dataIndex
render: (text, record, index) => {
return (
<>
<Tag
color={stringToColor(record.owner_by)}
size='large'
onClick={() => {
copyText(text);
}}
>
{text}
</Tag>
</>
);
},
},
{
title: '计费类型',
dataIndex: 'quota_type',
render: (text, record, index) => {
return renderQuotaType(parseInt(text));
},
},
{
title: '模型倍率',
dataIndex: 'model_ratio',
render: (text, record, index) => {
return <div>{record.quota_type === 0 ? text : 'N/A'}</div>;
},
},
{
title: '补全倍率',
dataIndex: 'completion_ratio',
render: (text, record, index) => {
let ratio = parseFloat(text.toFixed(3));
return <div>{record.quota_type === 0 ? ratio : 'N/A'}</div>;
},
},
{
title: '模型价格',
dataIndex: 'model_price',
render: (text, record, index) => {
let content = text;
if (record.quota_type === 0) {
let inputRatioPrice = record.model_ratio * 2.0 * record.group_ratio;
let completionRatioPrice =
record.model_ratio *
record.completion_ratio *
2.0 *
record.group_ratio;
content = (
<>
<Text>提示 ${inputRatioPrice} / 1M tokens</Text>
<br />
<Text>补全 ${completionRatioPrice} / 1M tokens</Text>
</>
);
} else {
let price = parseFloat(text) * record.group_ratio;
content = <>模型价格:${price}</>;
}
return <div>{content}</div>;
},
},
];

const [models, setModels] = useState([]);
const [loading, setLoading] = useState(true);
const [userState, userDispatch] = useContext(UserContext);
const [groupRatio, setGroupRatio] = useState(1);

const setModelsFormat = (models, groupRatio) => {
for (let i = 0; i < models.length; i++) {
models[i].key = i;
models[i].group_ratio = groupRatio;
}
// sort by quota_type
models.sort((a, b) => {
return a.quota_type - b.quota_type;
});

// sort by owner_by, openai is max, other use localeCompare
models.sort((a, b) => {
if (a.owner_by === 'openai') {
return -1;
} else if (b.owner_by === 'openai') {
return 1;
} else {
return a.owner_by.localeCompare(b.owner_by);
}
});

setModels(models);
};

const loadPricing = async () => {
setLoading(true);

let url = '';
url = `/api/pricing`;
const res = await API.get(url);
const { success, message, data, group_ratio } = res.data;
if (success) {
setGroupRatio(group_ratio);
setModelsFormat(data, group_ratio);
} else {
showError(message);
}
setLoading(false);
};

const refresh = async () => {
await loadPricing();
};

const copyText = async (text) => {
if (await copy(text)) {
showSuccess('已复制:' + text);
} else {
// setSearchKeyword(text);
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
}
};

useEffect(() => {
refresh().then();
}, []);

return (
<>
<Layout>
{userState.user ? (
<Banner
type='info'
description={`您的分组为:${userState.user.group},分组倍率为:${groupRatio}`}
/>
) : (
<Banner
type='warning'
description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio}`}
/>
)}
<Table
style={{ marginTop: 5 }}
columns={columns}
dataSource={models}
loading={loading}
pagination={{
pageSize: models.length,
showSizeChanger: false,
}}
/>
</Layout>
</>
);
};

export default ModelPricing;
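The 模型价格 column derives both dollar figures from the same convention used elsewhere in this change: one unit of model_ratio corresponds to $2 per 1M tokens, hence the 2.0 factor. For a hypothetical model with model_ratio 5, completion_ratio 3 and group_ratio 1, the page would show 提示 $10 / 1M tokens and 补全 $30 / 1M tokens.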
@@ -20,6 +20,7 @@ const OperationSetting = () => {
PreConsumedQuota: 0,
StreamCacheQueueLength: 0,
ModelRatio: '',
CompletionRatio: '',
ModelPrice: '',
GroupRatio: '',
TopUpLink: '',

@@ -38,6 +39,7 @@ const OperationSetting = () => {
StopOnSensitiveEnabled: '',
SensitiveWords: '',
MjNotifyEnabled: '',
MjAccountFilterEnabled: '',
MjModeClearEnabled: '',
MjForwardUrlEnabled: '',
DrawingEnabled: '',

@@ -67,6 +69,7 @@ const OperationSetting = () => {
if (
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);

@@ -156,6 +159,13 @@ const OperationSetting = () => {
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) {
if (!verifyJSON(inputs.CompletionRatio)) {
showError('模型补全倍率不是合法的 JSON 字符串');
return;
}
await updateOption('CompletionRatio', inputs.CompletionRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');

@@ -323,6 +333,12 @@ const OperationSetting = () => {
name='MjNotifyEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjAccountFilterEnabled === 'true'}
label='允许AccountFilter参数'
name='MjAccountFilterEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjForwardUrlEnabled === 'true'}
label='开启之后将上游地址替换为服务器地址'

@@ -586,6 +602,17 @@ const OperationSetting = () => {
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='模型补全倍率(仅对自定义模型有效)'
name='CompletionRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
value={inputs.CompletionRatio}
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='分组倍率'
@@ -23,10 +23,12 @@ import {
IconImage,
IconKey,
IconLayers,
IconPriceTag,
IconSetting,
IconUser,
} from '@douyinfe/semi-icons';
import { Layout, Nav } from '@douyinfe/semi-ui';
import { setStatusData } from '../helpers/data.js';

// HeaderBar Buttons

@@ -55,6 +57,7 @@ const SiderBar = () => {
about: '/about',
chat: '/chat',
detail: '/detail',
pricing: '/pricing',
};

const headerButtons = useMemo(

@@ -100,6 +103,12 @@ const SiderBar = () => {
to: '/topup',
icon: <IconCreditCard />,
},
{
text: '模型价格',
itemKey: 'pricing',
to: '/pricing',
icon: <IconPriceTag />,
},
{
text: '用户管理',
itemKey: 'user',

@@ -161,34 +170,8 @@ const SiderBar = () => {
}
const { success, data } = res.data;
if (success) {
localStorage.setItem('status', JSON.stringify(data));
statusDispatch({ type: 'set', payload: data });
localStorage.setItem('system_name', data.system_name);
localStorage.setItem('logo', data.logo);
localStorage.setItem('footer_html', data.footer_html);
localStorage.setItem('quota_per_unit', data.quota_per_unit);
localStorage.setItem('display_in_currency', data.display_in_currency);
localStorage.setItem('enable_drawing', data.enable_drawing);
localStorage.setItem('enable_data_export', data.enable_data_export);
localStorage.setItem(
'data_export_default_time',
data.data_export_default_time,
);
localStorage.setItem(
'default_collapse_sidebar',
data.default_collapse_sidebar,
);
localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled);
if (data.chat_link) {
localStorage.setItem('chat_link', data.chat_link);
} else {
localStorage.removeItem('chat_link');
}
if (data.chat_link2) {
localStorage.setItem('chat_link2', data.chat_link2);
} else {
localStorage.removeItem('chat_link2');
}
setStatusData(data);
} else {
showError('无法正常连接至服务器!');
}
@@ -189,7 +189,7 @@ const SystemSetting = () => {
if (inputs.EpayId !== '') {
await updateOption('EpayId', inputs.EpayId);
}
if (inputs.EpayKey !== '') {
if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') {
await updateOption('EpayKey', inputs.EpayKey);
}
await updateOption('Price', '' + inputs.Price);
@@ -235,6 +235,8 @@ const UsersTable = () => {
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [searchGroup, setSearchGroup] = useState('');
const [groupOptions, setGroupOptions] = useState([]);
const [userCount, setUserCount] = useState(ITEMS_PER_PAGE);
const [showAddUser, setShowAddUser] = useState(false);
const [showEditUser, setShowEditUser] = useState(false);

@@ -298,6 +300,7 @@ const UsersTable = () => {
.catch((reason) => {
showError(reason);
});
fetchGroups().then();
}, []);

const manageUser = async (username, action, record) => {

@@ -340,15 +343,15 @@ const UsersTable = () => {
}
};

const searchUsers = async () => {
if (searchKeyword === '') {
const searchUsers = async (searchKeyword, searchGroup) => {
if (searchKeyword === '' && searchGroup === '') {
// if keyword is blank, load files instead.
await loadUsers(0);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}`);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}`);
const { success, message, data } = res.data;
if (success) {
setUsers(data);

@@ -409,6 +412,25 @@ const UsersTable = () => {
}
};

const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
// add 'all' option
// res.data.data.unshift('all');
if (res === undefined) {
return;
}
setGroupOptions(
res.data.data.map((group) => ({
label: group,
value: group,
})),
);
} catch (error) {
showError(error.message);
}
};

return (
<>
<AddUser

@@ -422,17 +444,44 @@ const UsersTable = () => {
handleClose={closeEditUser}
editingUser={editingUser}
></EditUser>
<Form onSubmit={searchUsers}>
<Form.Input
label='搜索关键字'
icon='search'
field='keyword'
iconPosition='left'
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
<Form
onSubmit={() => {
searchUsers(searchKeyword, searchGroup);
}}
labelPosition='left'
>
<div style={{ display: 'flex' }}>
<Space>
<Form.Input
label='搜索关键字'
icon='search'
field='keyword'
iconPosition='left'
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
<Form.Select
field='group'
label='分组'
optionList={groupOptions}
onChange={(value) => {
setSearchGroup(value);
searchUsers(searchKeyword, value);
}}
/>
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
style={{ marginRight: 8 }}
>
查询
</Button>
</Space>
</div>
</Form>

<Table
@@ -18,3 +18,32 @@ export async function onGitHubOAuthClicked(github_client_id) {
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
);
}

let channelModels = undefined;
export async function loadChannelModels() {
const res = await API.get('/api/models');
const { success, data } = res.data;
if (!success) {
return;
}
channelModels = data;
localStorage.setItem('channel_models', JSON.stringify(data));
}

export function getChannelModels(type) {
if (channelModels !== undefined && type in channelModels) {
if (!channelModels[type]) {
return [];
}
return channelModels[type];
}
let models = localStorage.getItem('channel_models');
if (!models) {
return [];
}
channelModels = JSON.parse(models);
if (type in channelModels) {
return channelModels[type];
}
return [];
}
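loadChannelModels and getChannelModels form a two-level cache: the module-level channelModels variable answers repeat lookups in memory, and localStorage lets a freshly loaded page serve channel model lists before /api/models has been fetched again. One asymmetry worth noting: the in-memory path returns [] when the cached value is falsy, while the localStorage path returns whatever was stored; harmless for the array payloads this endpoint returns, but the two branches are not identical.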
@@ -50,6 +50,13 @@ export const CHANNEL_OPTIONS = [
color: 'orange',
label: 'Google Gemini',
},
{
key: 34,
text: 'Cohere',
value: 34,
color: 'purple',
label: 'Cohere',
},
{
key: 15,
text: '百度文心千帆',

@@ -79,13 +86,13 @@ export const CHANNEL_OPTIONS = [
label: '智谱 ChatGLM',
},
{
key: 16,
key: 26,
text: '智谱 GLM-4V',
value: 26,
color: 'purple',
label: '智谱 GLM-4V',
},
{ key: 16, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
web/src/helpers/data.js (new file, 33 lines)
@@ -0,0 +1,33 @@
export function setStatusData(data) {
localStorage.setItem('status', JSON.stringify(data));
localStorage.setItem('system_name', data.system_name);
localStorage.setItem('logo', data.logo);
localStorage.setItem('footer_html', data.footer_html);
localStorage.setItem('quota_per_unit', data.quota_per_unit);
localStorage.setItem('display_in_currency', data.display_in_currency);
localStorage.setItem('enable_drawing', data.enable_drawing);
localStorage.setItem('enable_data_export', data.enable_data_export);
localStorage.setItem(
'data_export_default_time',
data.data_export_default_time,
);
localStorage.setItem(
'default_collapse_sidebar',
data.default_collapse_sidebar,
);
localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled);
if (data.chat_link) {
localStorage.setItem('chat_link', data.chat_link);
} else {
localStorage.removeItem('chat_link');
}
if (data.chat_link2) {
localStorage.setItem('chat_link2', data.chat_link2);
} else {
localStorage.removeItem('chat_link2');
}
}

export function setUserData(data) {
localStorage.setItem('user', JSON.stringify(data));
}
@@ -1,4 +1,3 @@
|
||||
import { Label } from 'semantic-ui-react';
|
||||
import { Tag } from '@douyinfe/semi-ui';
|
||||
|
||||
export function renderText(text, limit) {
|
||||
@@ -135,6 +134,43 @@ export function renderQuota(quota, digits = 2) {
  return renderNumber(quota);
}

export function renderModelPrice(
  inputTokens,
  completionTokens,
  modelRatio,
  modelPrice = -1,
  completionRatio,
  groupRatio,
) {
  // 1 ratio = $0.002 / 1K tokens
  if (modelPrice !== -1) {
    return '模型价格:$' + modelPrice * groupRatio;
  } else {
    if (completionRatio === undefined) {
      completionRatio = 0;
    }
    let inputRatioPrice = modelRatio * 2.0 * groupRatio;
    let completionRatioPrice = modelRatio * completionRatio * 2.0 * groupRatio;
    let price =
      (inputTokens / 1000000) * inputRatioPrice +
      (completionTokens / 1000000) * completionRatioPrice;
    return (
      <>
        <article>
          <p>提示 ${inputRatioPrice} / 1M tokens</p>
          <p>补全 ${completionRatioPrice} / 1M tokens</p>
          <p></p>
          <p>
            提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
            {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $
            {price.toFixed(6)}
          </p>
        </article>
      </>
    );
  }
}

export function renderQuotaWithPrompt(quota, digits) {
  let displayInCurrency = localStorage.getItem('display_in_currency');
  displayInCurrency = displayInCurrency === 'true';

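The arithmetic above is easy to verify by hand. A worked example with hypothetical inputs (modelRatio 2.5, completionRatio 3, groupRatio 1, 1000 prompt and 500 completion tokens):

```js
// 1 ratio = $0.002 / 1K tokens = $2 / 1M tokens, hence the * 2.0 factor.
const inputRatioPrice = 2.5 * 2.0 * 1;          // $5 per 1M prompt tokens
const completionRatioPrice = 2.5 * 3 * 2.0 * 1; // $15 per 1M completion tokens
const price =
  (1000 / 1000000) * inputRatioPrice +          // $0.005
  (500 / 1000000) * completionRatioPrice;       // $0.0075
console.log(price.toFixed(6));                  // '0.012500'
```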
@@ -23,6 +23,7 @@ import {
  Banner,
} from '@douyinfe/semi-ui';
import { Divider } from 'semantic-ui-react';
import { getChannelModels, loadChannelModels } from '../../components/utils.js';

const MODEL_MAPPING_EXAMPLE = {
  'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
@@ -87,87 +88,9 @@ const EditChannel = (props) => {
  const [customModel, setCustomModel] = useState('');
  const handleInputChange = (name, value) => {
    setInputs((inputs) => ({ ...inputs, [name]: value }));
    if (name === 'type' && inputs.models.length === 0) {
    if (name === 'type') {
      let localModels = [];
      switch (value) {
        case 33:
        case 14:
          localModels = [
            'claude-instant-1.2',
            'claude-2',
            'claude-2.0',
            'claude-2.1',
            'claude-3-opus-20240229',
            'claude-3-sonnet-20240229',
            'claude-3-haiku-20240307',
          ];
          break;
        case 11:
          localModels = ['PaLM-2'];
          break;
        case 15:
          localModels = [
            'ERNIE-Bot',
            'ERNIE-Bot-turbo',
            'ERNIE-Bot-4',
            'Embedding-V1',
          ];
          break;
        case 17:
          localModels = [
            'qwen-turbo',
            'qwen-plus',
            'qwen-max',
            'qwen-max-longcontext',
            'text-embedding-v1',
          ];
          break;
        case 16:
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
          break;
        case 18:
          localModels = [
            'SparkDesk',
            'SparkDesk-v1.1',
            'SparkDesk-v2.1',
            'SparkDesk-v3.1',
            'SparkDesk-v3.5',
          ];
          break;
        case 19:
          localModels = [
            '360GPT_S2_V9',
            'embedding-bert-512-v1',
            'embedding_s1_v1',
            'semantic_similarity_s1_v1',
          ];
          break;
        case 23:
          localModels = ['hunyuan'];
          break;
        case 24:
          localModels = [
            'gemini-1.0-pro-001',
            'gemini-1.0-pro-vision-001',
            'gemini-1.5-pro',
            'gemini-1.5-pro-latest',
            'gemini-pro',
            'gemini-pro-vision',
          ];
          break;
        case 25:
          localModels = [
            'moonshot-v1-8k',
            'moonshot-v1-32k',
            'moonshot-v1-128k',
          ];
          break;
        case 26:
          localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
          break;
        case 31:
          localModels = ['yi-34b-chat-0205', 'yi-34b-chat-200k', 'yi-vl-plus'];
          break;
        case 2:
          localModels = [
            'mj_imagine',
@@ -176,6 +99,7 @@ const EditChannel = (props) => {
            'mj_blend',
            'mj_upscale',
            'mj_describe',
            'mj_uploads',
          ];
          break;
        case 5:
@@ -195,10 +119,17 @@ const EditChannel = (props) => {
            'mj_high_variation',
            'mj_low_variation',
            'mj_pan',
            'mj_uploads',
          ];
          break;
        default:
          localModels = getChannelModels(value);
          break;
      }
      setInputs((inputs) => ({ ...inputs, models: localModels }));
      if (inputs.models.length === 0) {
        setInputs((inputs) => ({ ...inputs, models: localModels }));
      }
      setBasicModels(localModels);
    }
    //setAutoBan
  };
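The hardcoded per-channel lists above collapse into the `default` branch. `getChannelModels` lives in `web/src/components/utils.js` and is not shown in this compare; a minimal sketch of what such a lookup might look like (shape and contents are assumptions, not the real implementation):

```js
// Hypothetical sketch only; the actual code is in components/utils.js.
const CHANNEL_MODELS = {
  14: ['claude-3-opus-20240229', 'claude-3-sonnet-20240229'],    // Anthropic
  25: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'], // Moonshot
};

export function getChannelModels(type) {
  // Fall back to an empty list for unknown channel types.
  return CHANNEL_MODELS[type] ?? [];
}
```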
@@ -234,6 +165,7 @@ const EditChannel = (props) => {
    } else {
      setAutoBan(true);
    }
    setBasicModels(getChannelModels(data.type));
    // console.log(data);
  } else {
    showError(message);
@@ -302,6 +234,9 @@ const EditChannel = (props) => {
    loadChannel().then(() => {});
  } else {
    setInputs(originInputs);
    let localModels = getChannelModels(inputs.type);
    setBasicModels(localModels);
    setInputs((inputs) => ({ ...inputs, models: localModels }));
  }
}, [props.editingChannel.id]);

@@ -363,24 +298,39 @@ const EditChannel = (props) => {
    }
  };

  const addCustomModel = () => {
  const addCustomModels = () => {
    if (customModel.trim() === '') return;
    if (inputs.models.includes(customModel)) return showError('该模型已存在!');
    // Split the string on commas, then trim whitespace around each model name
    const modelArray = customModel.split(',').map(model => model.trim());

    let localModels = [...inputs.models];
    localModels.push(customModel);
    let localModelOptions = [];
    localModelOptions.push({
      key: customModel,
      text: customModel,
      value: customModel,
    });
    setModelOptions((modelOptions) => {
      return [...modelOptions, ...localModelOptions];
    let localModelOptions = [...modelOptions];
    let hasError = false;

    modelArray.forEach(model => {
      // Check the model does not already exist and the name is non-empty
      if (model && !localModels.includes(model)) {
        localModels.push(model); // add to the model list
        localModelOptions.push({ // add to the dropdown options
          key: model,
          text: model,
          value: model,
        });
      } else if (model) {
        showError('某些模型已存在!');
        hasError = true;
      }
    });

    if (hasError) return; // abort if any duplicates were found

    // Commit the updated state
    setModelOptions(localModelOptions);
    setCustomModel('');
    handleInputChange('models', localModels);
  };

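To illustrate the new comma-separated behavior with a hypothetical input: each name is trimmed, empty entries are skipped, and a single duplicate aborts the whole add.

```js
// Hypothetical input, assuming neither model is on the channel yet:
const customModel = ' gpt-4 , gpt-4-32k ,gpt-4';
const modelArray = customModel.split(',').map(model => model.trim());
// -> ['gpt-4', 'gpt-4-32k', 'gpt-4']
// The first two are appended; by the time the third is checked, 'gpt-4'
// is already in localModels, so showError('某些模型已存在!') fires,
// hasError becomes true, and the early return discards everything.
```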
  return (
    <>
      <SideSheet
@@ -586,7 +536,7 @@ const EditChannel = (props) => {
                handleInputChange('models', basicModels);
              }}
            >
              填入基础模型
              填入相关模型
            </Button>
            <Button
              type='secondary'
@@ -607,7 +557,7 @@ const EditChannel = (props) => {
          </Space>
          <Input
            addonAfter={
              <Button type='primary' onClick={addCustomModel}>
              <Button type='primary' onClick={addCustomModels}>
                填入
              </Button>
            }

10 web/src/pages/Pricing/index.js Normal file
@@ -0,0 +1,10 @@
import React from 'react';
import ModelPricing from '../../components/ModelPricing.js';

const Pricing = () => (
  <>
    <ModelPricing />
  </>
);

export default Pricing;

@@ -1,12 +1,13 @@
import React, { useEffect, useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { API, isMobile, showError, showSuccess } from '../../helpers';
import { renderQuotaWithPrompt } from '../../helpers/render';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import {
  Button,
  Divider,
  Input,
  Modal,
  Select,
  SideSheet,
  Space,
@@ -17,6 +18,8 @@ import {
const EditUser = (props) => {
  const userId = props.editingUser.id;
  const [loading, setLoading] = useState(true);
  const [addQuotaModalOpen, setIsModalOpen] = useState(false);
  const [addQuotaLocal, setAddQuotaLocal] = useState('');
  const [inputs, setInputs] = useState({
    username: '',
    display_name: '',
@@ -107,6 +110,16 @@ const EditUser = (props) => {
    setLoading(false);
  };

  const addLocalQuota = () => {
    let newQuota = parseInt(quota) + parseInt(addQuotaLocal);
    setInputs((inputs) => ({ ...inputs, quota: newQuota }));
  };

  const openAddQuotaModal = () => {
    setAddQuotaLocal('0');
    setIsModalOpen(true);
  };

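Both operands are run through `parseInt`, so the add-quota modal also handles deductions; a quick hypothetical check of the arithmetic above:

```js
const quota = 10000;            // current remaining quota
const addQuotaLocal = '-2500';  // modal input arrives as a string
const newQuota = parseInt(quota) + parseInt(addQuotaLocal); // 7500
```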
  return (
    <>
      <SideSheet
@@ -192,14 +205,17 @@ const EditUser = (props) => {
        <div style={{ marginTop: 20 }}>
          <Typography.Text>{`剩余额度${renderQuotaWithPrompt(quota)}`}</Typography.Text>
        </div>
        <Input
          name='quota'
          placeholder={'请输入新的剩余额度'}
          onChange={(value) => handleInputChange('quota', value)}
          value={quota}
          type={'number'}
          autoComplete='new-password'
        />
        <Space>
          <Input
            name='quota'
            placeholder={'请输入新的剩余额度'}
            onChange={(value) => handleInputChange('quota', value)}
            value={quota}
            type={'number'}
            autoComplete='new-password'
          />
          <Button onClick={openAddQuotaModal}>添加额度</Button>
        </Space>
      </>
    )}
    <Divider style={{ marginTop: 20 }}>以下信息不可修改</Divider>
@@ -245,6 +261,30 @@ const EditUser = (props) => {
          />
        </Spin>
      </SideSheet>
      <Modal
        centered={true}
        visible={addQuotaModalOpen}
        onOk={() => {
          addLocalQuota();
          setIsModalOpen(false);
        }}
        onCancel={() => setIsModalOpen(false)}
        closable={null}
      >
        <div style={{ marginTop: 20 }}>
          <Typography.Text>{`新额度${renderQuota(quota)} + ${renderQuota(addQuotaLocal)} = ${renderQuota(quota + parseInt(addQuotaLocal))}`}</Typography.Text>
        </div>
        <Input
          name='addQuotaLocal'
          placeholder={'需要添加的额度(支持负数)'}
          onChange={(value) => {
            setAddQuotaLocal(value);
          }}
          value={addQuotaLocal}
          type={'number'}
          autoComplete='new-password'
        />
      </Modal>
    </>
  );
};