Compare commits

...

30 Commits

Author SHA1 Message Date
1808837298@qq.com
1cff3c100a Merge remote-tracking branch 'origin/main' 2024-05-08 16:57:23 +08:00
1808837298@qq.com
d7a343e2f6 feat: update model ratio 2024-05-08 16:57:11 +08:00
Calcium-Ion
637801fba5 Merge pull request #232 from kakingone/add-mj-usetime (add-mj-use-time) 2024-05-08 16:51:17 +08:00
1808837298@qq.com
2bf404507f fix: update user (#230) 2024-05-08 16:46:06 +08:00
kakingone
675de89c69 --amend 2024-05-06 17:18:04 +08:00
1808837298@qq.com
16b9aacb06 feat: log completionRatio 2024-05-03 12:26:17 +08:00
1808837298@qq.com
cad380eb16 feat: able to set AccountFilter 2024-05-01 01:37:47 +08:00
1808837298@qq.com
234e39ddeb feat: update midjourney task info update timeout 2024-05-01 01:32:01 +08:00
1808837298@qq.com
7fb6420e66 fix: aws claude system 2024-04-29 00:06:25 +08:00
CaIon
5425b5bfc3 fix: aws claude 2024-04-28 20:45:34 +08:00
CaIon
21f32605c8 feat: safe send channel 2024-04-28 16:17:16 +08:00
CaIon
1c6fd87909 fix: standardize claude response format 2024-04-26 02:56:35 +08:00
CaIon
d1c8947851 fix: standardize claude response format 2024-04-25 23:57:39 +08:00
CaIon
7d2d525051 fix: claude stream mode missing role 2024-04-25 22:57:11 +08:00
CaIon
be4809b95a feat: log status code 2024-04-25 20:47:18 +08:00
CaIon
e2edd5e7e5 fix: claude 2024-04-25 20:37:50 +08:00
1808837298@qq.com
a14fa1adb1 feat: tidy claude prompt handling 2024-04-25 16:04:53 +08:00
CaIon
2cb10b003a fix typo 2024-04-24 22:53:58 +08:00
CaIon
86b17fcce8 chore: update model ratio 2024-04-24 22:08:54 +08:00
CaIon
08b5336431 fix: update user 2024-04-24 19:40:29 +08:00
CaIon
20aaf30785 feat: update model ratio 2024-04-24 18:53:21 +08:00
CaIon
bfcaccc2e3 feat: support cohere (close #195) 2024-04-24 18:49:56 +08:00
CaIon
3f448ba4fc feat: log more details for dalle models 2024-04-24 15:14:16 +08:00
CaIon
408c2bdd9b chore: remove unused code 2024-04-24 15:13:53 +08:00
CaIon
b1b38a6bd4 fix: audio pre-consumed quota not refunded 2024-04-24 15:08:15 +08:00
Calcium-Ion
608ec28761 Merge pull request #213 from iszcz/pr (user management page: add group search) 2024-04-24 14:51:08 +08:00
CaIon
a3ccc92f55 fix: close #218 2024-04-24 14:44:24 +08:00
CaIon
77e7d11151 fix: fix update payment setting 2024-04-24 00:01:54 +08:00
CaIon
783e8fd74a refactor: rework billing code 2024-04-23 23:51:27 +08:00
iszcz
79cf70683f user management page: add group search 2024-04-20 02:13:11 +08:00
41 changed files with 652 additions and 141 deletions

View File

@@ -207,6 +207,7 @@ const (
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
)
var ChannelBaseURLs = []string{
@@ -244,5 +245,5 @@ var ChannelBaseURLs = []string{
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
}

View File

@@ -16,7 +16,22 @@ func SafeGoroutine(f func()) {
}()
}
func SafeSend(ch chan bool, value bool) (closed bool) {
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occurred. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}
func SafeSendString(ch chan string, value string) (closed bool) {
defer func() {
// Recover from panic if one occurred. A panic would mean the channel was closed.
if recover() != nil {
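Usage note for SafeSendBool above: it turns the send-on-closed-channel panic into a boolean return. A minimal sketch of a caller (the channel setup here is hypothetical; the real call site appears in the openai relay handler below):
stopChan := make(chan bool)
close(stopChan)
// a plain send would panic here; SafeSendBool recovers and reports it
if closed := common.SafeSendBool(stopChan, true); closed {
	// receiver is gone, drop the signal instead of crashing the goroutine
}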

View File

@@ -2,6 +2,7 @@ package common
import (
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -98,3 +99,13 @@ func LogQuota(quota int) string {
return fmt.Sprintf("%d 点额度", quota)
}
}
// LogJson is for testing purposes only
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)
if err != nil {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
}
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
}
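A one-line usage sketch (the gin context c and the claudeRequest value are hypothetical; the helper is marked test-only above):
// logs: "claude request | {...struct serialized as JSON...}"
common.LogJson(c.Request.Context(), "claude request", claudeRequest)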

View File

@@ -102,9 +102,17 @@ var DefaultModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// USD prices already converted at a rate of 7.2 CNY per USD
"yi-34b-chat-0205": 0.018,
"yi-34b-chat-200k": 0.0864,
"yi-vl-plus": 0.0432,
"yi-34b-chat-0205": 0.018,
"yi-34b-chat-200k": 0.0864,
"yi-vl-plus": 0.0432,
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.25,
"command-r-plus ": 1.5,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
}
var DefaultModelPrice = map[string]float64{
@@ -224,6 +232,19 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gemini-") {
return 3
}
if strings.HasPrefix(name, "command") {
switch name {
case "command-r":
return 3
case "command-r-plus":
return 5
default:
return 2
}
}
if strings.HasPrefix(name, "deepseek") {
return 2
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
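A quick check of the new branches above (a sketch; each call exercises the logic just added):
GetCompletionRatio("command-r")      // 3 (explicit case)
GetCompletionRatio("command-r-plus") // 5 (explicit case)
GetCompletionRatio("command-light")  // 2 (command prefix default)
GetCompletionRatio("deepseek-chat")  // 2 (deepseek prefix)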

View File

@@ -1,6 +1,7 @@
package constant
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true

View File

@@ -86,7 +86,7 @@ func UpdateMidjourneyTaskBulk() {
continue
}
// set the timeout
timeout := time.Second * 5
timeout := time.Second * 15
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// create a new request with the timeout-aware context
req = req.WithContext(ctx)

View File

@@ -124,7 +124,7 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
@@ -160,7 +160,7 @@ func RelayMidjourney(c *gin.Context) {
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
}
}

View File

@@ -216,7 +216,8 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
users, err := model.SearchUsers(keyword)
group := c.Query("group")
users, err := model.SearchUsers(keyword, group)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -452,7 +453,7 @@ func UpdateUser(c *gin.Context) {
updatedUser.Password = "" // rollback to what it should be
}
updatePassword := updatedUser.Password != ""
if err := updatedUser.Update(updatePassword); err != nil {
if err := updatedUser.Edit(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -725,7 +726,7 @@ func ManageUser(c *gin.Context) {
user.Role = common.RoleCommonUser
}
if err := user.UpdateAll(false); err != nil {
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -43,6 +43,10 @@ type OpenAIFunction struct {
Parameters any `json:"parameters,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
return int64(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil

View File

@@ -54,17 +54,33 @@ type OpenAIEmbeddingResponse struct {
}
type ChatCompletionsStreamResponseChoice struct {
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
Logprobs *any `json:"logprobs"`
FinishReason *string `json:"finish_reason"`
Index int `json:"index"`
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content string `json:"content"`
Content *string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
return c.Content == nil && len(c.ToolCalls) == 0
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
if c.Content == nil {
return ""
}
return *c.Content
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
@@ -80,11 +96,12 @@ type FunctionCall struct {
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
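The move from Content string to Content *string is what makes the new helpers necessary: a nil pointer omits the field from the marshalled chunk, while an explicit empty string survives. A minimal sketch, assuming the struct tags above:
var d ChatCompletionsStreamResponseChoiceDelta
b, _ := json.Marshal(d) // {} : content is absent while the pointer is nil
d.SetContentString("")
b, _ = json.Marshal(d) // {"content":""} : present but explicitly empty
_ = b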
type ChatCompletionsStreamResponseSimple struct {

View File

@@ -92,6 +92,7 @@ func InitOptionMap() {
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
@@ -197,6 +198,8 @@ func updateOptionMap(key string, value string) (err error) {
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled":
constant.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":

View File

@@ -73,25 +73,34 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
return users, err
}
func SearchUsers(keyword string) ([]*User, error) {
func SearchUsers(keyword string, group string) ([]*User, error) {
var users []*User
var err error
// try to convert the keyword into an integer ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// if the conversion succeeds, search users by ID
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
// if the conversion succeeds, search users by ID and the optional group
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
if group != "" {
query = query.Where("`group` = ?", group) // backtick-quote group, it is a reserved keyword
}
err = query.Find(&users).Error
if err != nil || len(users) > 0 {
// if a user was found by ID, or an error occurred, return the result or error
return users, err
}
}
// if the ID conversion failed or no user was found, fall back to fuzzy search on the other fields
err = DB.Unscoped().Omit("password").
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
Find(&users).Error
err = nil
query := DB.Unscoped().Omit("password")
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
err = query.Find(&users).Error
return users, err
}
@@ -235,7 +244,7 @@ func (user *User) Update(updatePassword bool) error {
return err
}
func (user *User) UpdateAll(updatePassword bool) error {
func (user *User) Edit(updatePassword bool) error {
var err error
if updatePassword {
user.Password, err = common.Password2Hash(user.Password)
@@ -244,8 +253,17 @@ func (user *User) UpdateAll(updatePassword bool) error {
}
}
newUser := *user
updates := map[string]interface{}{
"username": newUser.Username,
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
}
if updatePassword {
updates["password"] = newUser.Password
}
DB.First(&user, user.Id)
err = DB.Model(user).Select("*").Updates(newUser).Error
err = DB.Model(user).Updates(updates).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
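The backticks around group are required because GROUP is a reserved word in MySQL. A hedged sketch of the two paths through the new SearchUsers (the "vip" group name is illustrative):
byId, _ := model.SearchUsers("42", "vip") // numeric keyword: exact-ID lookup, filtered to group "vip"
fuzzy, _ := model.SearchUsers("alice", "") // non-numeric keyword: fuzzy match on username/email/display_name
_, _ = byId, fuzzy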

View File

@@ -136,7 +136,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
@@ -199,7 +199,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -5,6 +5,7 @@ import "one-api/relay/channel/claude"
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
System string `json:"system"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`

View File

@@ -156,6 +156,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
var usage relaymodel.Usage
var id string
var model string
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
@@ -188,6 +189,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model

View File

@@ -57,7 +57,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
choice.Delta.SetContentString(baiduResponse.Result)
if baiduResponse.IsEnd {
choice.FinishReason = &relaycommon.StopFinishReason
}

View File

@@ -24,15 +24,16 @@ type ClaudeMessage struct {
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
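Keeping both fields lets one struct serve both Anthropic endpoints: the legacy text-completions API expects max_tokens_to_sample, while the messages API expects max_tokens. A sketch of the two request shapes (field values are illustrative only):
completionReq := ClaudeRequest{
	Model:             "claude-2",
	Prompt:            "\n\nHuman: Hello\n\nAssistant:",
	MaxTokensToSample: 4096, // legacy completions field, defaulted below when zero
}
messagesReq := ClaudeRequest{
	Model:     "claude-3-opus-20240229",
	System:    "You are a helpful assistant.",
	MaxTokens: 1024, // messages-API field
}
_, _ = completionReq, messagesReq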

View File

@@ -20,7 +20,7 @@ func stopReasonClaude2OpenAI(reason string) string {
case "end_turn":
return "stop"
case "max_tokens":
return "length"
return "max_tokens"
default:
return reason
}
@@ -30,15 +30,14 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 4096
}
prompt := ""
for _, message := range textRequest.Messages {
@@ -73,13 +72,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {
if message.Role == "system" {
if i != 0 {
message.Role = "user"
}
}
//if message.Role == "system" {
// if i != 0 {
// message.Role = "user"
// }
//}
if message.Role == "" {
message.Role = "user"
textRequest.Messages[i].Role = "user"
}
fmtMessage := dto.Message{
Role: message.Role,
@@ -98,13 +97,24 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
fmtMessage.Content = content
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = &message
lastMessage = &textRequest.Messages[i]
}
claudeMessages := make([]ClaudeMessage, 0)
for _, message := range formatMessages {
if message.Role == "system" {
claudeRequest.System = message.StringContent()
if message.IsStringContent() {
claudeRequest.System = message.StringContent()
} else {
contents := message.ParseContent()
content := ""
for _, ctx := range contents {
if ctx.Type == "text" {
content += ctx.Text
}
}
claudeRequest.System = content
}
} else {
claudeMessage := ClaudeMessage{
Role: message.Role,
@@ -149,7 +159,6 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
return &claudeRequest, nil
}
@@ -161,7 +170,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.Content = claudeResponse.Completion
choice.Delta.SetContentString(claudeResponse.Completion)
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
@@ -171,9 +180,13 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
return nil, nil
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
choice.Delta.SetContentString(claudeResponse.Delta.Text)
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
@@ -182,12 +195,15 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil, nil
} else {
return nil, nil
}
}
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
}
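For orientation, the branches above track Anthropic's streaming event sequence roughly as follows (a summary of this code, not of the full upstream protocol):

message_start        -> emit a chunk with role "assistant" and empty content; capture id, model, usage
content_block_start  -> skipped (returns nil, nil)
content_block_delta  -> emit a content chunk from Delta.Text at the given index
message_delta        -> emit the finish_reason chunk; capture usage
message_stop / other -> skipped (returns nil, nil)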

View File

@@ -0,0 +1,52 @@
package cohere
import (
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package cohere
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
}
var ChannelName = "cohere"

View File

@@ -0,0 +1,44 @@
package cohere
type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
}
type ChatHistory struct {
Role string `json:"role"`
Message string `json:"message"`
}
type CohereResponse struct {
IsFinished bool `json:"is_finished"`
EventType string `json:"event_type"`
Text string `json:"text,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Response *CohereResponseResult `json:"response"`
}
type CohereResponseResult struct {
ResponseId string `json:"response_id"`
FinishReason string `json:"finish_reason,omitempty"`
Text string `json:"text"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`
}
type CohereBilledUnits struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type CohereTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -0,0 +1,189 @@
package cohere
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
cohereReq := CohereRequest{
Model: textRequest.Model,
ChatHistory: []ChatHistory{},
Message: "",
Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(),
}
if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000
}
for _, msg := range textRequest.Messages {
if msg.Role == "user" {
cohereReq.Message = msg.StringContent()
} else {
var role string
if msg.Role == "assistant" {
role = "CHATBOT"
} else if msg.Role == "system" {
role = "SYSTEM"
} else {
role = "USER"
}
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
Role: role,
Message: msg.StringContent(),
})
}
}
return &cohereReq
}
func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
return "stop"
case "MAX_TOKENS":
return "max_tokens"
default:
return reason
}
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
Index: 0,
FinishReason: &finishReason,
},
}
if cohereResp.Response != nil {
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
}
} else {
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
Content: &cohereResp.Text,
},
Index: 0,
},
}
responseText += cohereResp.Text
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
}
return nil, usage
}
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
var openaiResp dto.TextResponse
openaiResp.Id = cohereResp.ResponseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion"
openaiResp.Model = modelName
openaiResp.Usage = usage
content, _ := json.Marshal(cohereResp.Text)
openaiResp.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
Message: dto.Message{Content: content, Role: "assistant"},
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
},
}
jsonResponse, err := json.Marshal(openaiResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
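A hedged sketch of what requestOpenAI2Cohere above produces: user turns set Message (so only the most recent user message survives; earlier ones are overwritten rather than kept in chat_history), and non-user turns are mapped onto Cohere role names. Values here are illustrative:
req := CohereRequest{
	Model:     "command-r",
	Message:   "What's the weather today?", // latest user turn
	Stream:    true,
	MaxTokens: 4000, // default applied when the client omits max_tokens
	ChatHistory: []ChatHistory{
		{Role: "SYSTEM", Message: "Be brief."},            // from role "system"
		{Role: "CHATBOT", Message: "Hi, how can I help?"}, // from role "assistant"
	},
}
_ = req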

View File

@@ -151,7 +151,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.Delta.SetContentString(geminiResponse.GetResponseText())
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
@@ -203,7 +203,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",

View File

@@ -50,7 +50,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
common.SafeSendString(dataChan, data)
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
@@ -68,7 +68,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.Content)
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -84,7 +84,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.Content)
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -123,7 +123,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
// wait data out
time.Sleep(2 * time.Second)
}
common.SafeSend(stopChan, true)
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {

View File

@@ -61,7 +61,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
}
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse

View File

@@ -86,7 +86,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
}
if len(TencentResponse.Choices) > 0 {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &relaycommon.StopFinishReason
}
@@ -138,7 +138,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
responseText += response.Choices[0].Delta.GetContentString()
}
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -87,7 +87,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo
}
}
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &relaycommon.StopFinishReason
}

View File

@@ -126,7 +126,7 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
choice.Delta.SetContentString(zhipuResponse)
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
@@ -138,7 +138,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStream
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.Delta.SetContentString("")
choice.FinishReason = &relaycommon.StopFinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,

View File

@@ -19,6 +19,7 @@ const (
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -52,6 +53,8 @@ func ChannelType2APIType(channelType int) int {
apiType = APITypePerplexity
case common.ChannelTypeAws:
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
}
return apiType
}

View File

@@ -20,15 +20,6 @@ import (
"time"
)
var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
@@ -59,9 +50,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if audioRequest.Voice == "" {
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
var err error
promptTokens := 0
@@ -100,6 +88,22 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}
succeed := false
defer func() {
if succeed {
return
}
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func() {
go func() {
// negative means add quota back for token & user
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
}()
}()
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
@@ -163,6 +167,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK {
return relaycommon.RelayErrorHandler(resp)
}
succeed = true
var audioResponse dto.AudioResponse
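The deferred block above is a refund-on-failure pattern: quota is pre-consumed up front and returned unless the relay reaches the success flag. Reduced to its shape (a sketch, with the nested defer flattened):
succeed := false
defer func() {
	if !succeed && preConsumedQuota > 0 {
		// negative amount adds the quota back for the token and user
		go returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
	}
}()
// ... relay the upstream request; any early return leaves succeed == false ...
succeed = true // set only after a 200 response, so the defer refunds nothing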

View File

@@ -34,7 +34,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
imageRequest.Model = "dall-e-3"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
@@ -186,7 +186,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
quality := "normal"
if imageRequest.Quality == "hd" {
quality = "hd"
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")

View File

@@ -264,10 +264,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
completionTokens := usage.CompletionTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
@@ -279,7 +279,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
totalTokens := promptTokens + completionTokens
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
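A worked example of the ratio branch above (modelPrice == -1), using the new Cohere numbers: 1000 prompt tokens and 500 completion tokens on command-r, with completionRatio 3 and ratio (modelRatio * groupRatio) = 0.25 * 1:
quota := 1000 + int(float64(500)*3.0) // 2500 ratio-weighted tokens
quota = int(float64(quota) * 0.25)    // 625 quota units charged
_ = quota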

View File

@@ -6,6 +6,7 @@ import (
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/cohere"
"one-api/relay/channel/gemini"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
@@ -48,6 +49,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &perplexity.Adaptor{}
case constant.APITypeAws:
return &aws.Adaptor{}
case constant.APITypeCohere:
return &cohere.Adaptor{}
}
return nil
}

View File

@@ -165,7 +165,9 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
if !constant.MjAccountFilterEnabled {
delete(mapResult, "accountFilter")
}
if !constant.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
@@ -174,11 +176,11 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
}
if constant.MjModeClearEnabled {
if prompt, ok := mapResult["prompt"].(string); ok {
prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1)
prompt = strings.Replace(prompt, "--turbo", "", -1)
mapResult["prompt"] = prompt
prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1)
prompt = strings.Replace(prompt, "--turbo", "", -1)
mapResult["prompt"] = prompt
}
}
reqBody, err := json.Marshal(mapResult)

View File

@@ -128,7 +128,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count tools token fail: %s", err.Error())), false
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
}
countStr := ""
for _, tool := range openaiTools {
@@ -173,48 +173,31 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
var arrayContent []dto.MediaMessage
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err, false
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
if message.IsStringContent() {
stringContent := message.StringContent()
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
var err error
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if model == "glm-4v" {
imageTokenNum = 1047
} else {
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := dto.MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err = getImageToken(&imageUrl)
if err != nil {
return 0, err, false
}
@@ -249,7 +232,7 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
tkm, _, _ := CountTokenInput(message.Delta.Content, model, false)
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {

View File

@@ -236,6 +236,31 @@ const renderTimestamp = (timestampInSeconds) => {
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // formatted output
};
// renderDuration modified to include color logic
function renderDuration(submit_time, finishTime) {
// make sure submit_time and finishTime are both valid timestamps
if (!submit_time || !finishTime) return 'N/A';
// convert the timestamps into Date objects
const start = new Date(submit_time);
const finish = new Date(finishTime);
// compute the difference in milliseconds
const durationMs = finish - start;
// convert the difference to seconds, keeping one decimal place
const durationSec = (durationMs / 1000).toFixed(1);
// color: red if over 60 seconds, green if 60 seconds or less
const color = durationSec > 60 ? 'red' : 'green';
// return a styled color tag
return (
<Tag color={color} size="large">
{durationSec}
</Tag>
);
}
const LogsTable = () => {
const [isModalOpen, setIsModalOpen] = useState(false);
@@ -248,6 +273,15 @@ const LogsTable = () => {
return <div>{renderTimestamp(text / 1000)}</div>;
},
},
{
title: '花费时间',
dataIndex: 'finish_time', // use finish_time as the dataIndex
key: 'finish_time',
render: (finish, record) => {
// assumes record.submit_time exists and finish is the completion timestamp
return renderDuration(record.submit_time, finish);
},
},
{
title: '渠道',
dataIndex: 'channel_id',

View File

@@ -38,6 +38,7 @@ const OperationSetting = () => {
StopOnSensitiveEnabled: '',
SensitiveWords: '',
MjNotifyEnabled: '',
MjAccountFilterEnabled: '',
MjModeClearEnabled: '',
MjForwardUrlEnabled: '',
DrawingEnabled: '',
@@ -323,6 +324,12 @@ const OperationSetting = () => {
name='MjNotifyEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjAccountFilterEnabled === 'true'}
label='允许AccountFilter参数'
name='MjAccountFilterEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjForwardUrlEnabled === 'true'}
label='开启之后将上游地址替换为服务器地址'

View File

@@ -189,7 +189,7 @@ const SystemSetting = () => {
if (inputs.EpayId !== '') {
await updateOption('EpayId', inputs.EpayId);
}
if (inputs.EpayKey !== '') {
if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') {
await updateOption('EpayKey', inputs.EpayKey);
}
await updateOption('Price', '' + inputs.Price);

View File

@@ -235,6 +235,8 @@ const UsersTable = () => {
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [searchGroup, setSearchGroup] = useState('');
const [groupOptions, setGroupOptions] = useState([]);
const [userCount, setUserCount] = useState(ITEMS_PER_PAGE);
const [showAddUser, setShowAddUser] = useState(false);
const [showEditUser, setShowEditUser] = useState(false);
@@ -298,6 +300,7 @@ const UsersTable = () => {
.catch((reason) => {
showError(reason);
});
fetchGroups().then();
}, []);
const manageUser = async (username, action, record) => {
@@ -340,15 +343,15 @@ const UsersTable = () => {
}
};
const searchUsers = async () => {
if (searchKeyword === '') {
const searchUsers = async (searchKeyword, searchGroup) => {
if (searchKeyword === '' && searchGroup === '') {
// if keyword and group are both blank, load users instead.
await loadUsers(0);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}`);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}`);
const { success, message, data } = res.data;
if (success) {
setUsers(data);
@@ -409,6 +412,25 @@ const UsersTable = () => {
}
};
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
// add 'all' option
// res.data.data.unshift('all');
if (res === undefined) {
return;
}
setGroupOptions(
res.data.data.map((group) => ({
label: group,
value: group,
})),
);
} catch (error) {
showError(error.message);
}
};
return (
<>
<AddUser
@@ -422,17 +444,44 @@ const UsersTable = () => {
handleClose={closeEditUser}
editingUser={editingUser}
></EditUser>
<Form onSubmit={searchUsers}>
<Form.Input
label='搜索关键字'
icon='search'
field='keyword'
iconPosition='left'
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
<Form
onSubmit={() => {
searchUsers(searchKeyword, searchGroup);
}}
labelPosition='left'
>
<div style={{ display: 'flex' }}>
<Space>
<Form.Input
label='搜索关键字'
icon='search'
field='keyword'
iconPosition='left'
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
<Form.Select
field='group'
label='分组'
optionList={groupOptions}
onChange={(value) => {
setSearchGroup(value);
searchUsers(searchKeyword, value);
}}
/>
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
style={{ marginRight: 8 }}
>
查询
</Button>
</Space>
</div>
</Form>
<Table

View File

@@ -50,6 +50,13 @@ export const CHANNEL_OPTIONS = [
color: 'orange',
label: 'Google Gemini',
},
{
key: 34,
text: 'Cohere',
value: 34,
color: 'purple',
label: 'Cohere',
},
{
key: 15,
text: '百度文心千帆',

View File

@@ -155,6 +155,16 @@ const EditChannel = (props) => {
'gemini-pro-vision',
];
break;
case 34:
localModels = [
'command-r',
'command-r-plus',
'command-light',
'command-light-nightly',
'command',
'command-nightly',
];
break;
case 25:
localModels = [
'moonshot-v1-8k',