Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2024-06-17 03:16:36 +00:00
13 changed files with 74 additions and 73 deletions

View File

@@ -132,10 +132,10 @@ var ValidThemes = map[string]bool{
// All duration's unit is seconds // All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration // Shouldn't larger then RateLimitKeyExpirationDuration
var ( var (
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 240)
GlobalApiRateLimitDuration int64 = 3 * 60 GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 120)
GlobalWebRateLimitDuration int64 = 3 * 60 GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10 UploadRateLimitNum = 10

View File

@@ -43,11 +43,19 @@ func SysLog(s string) {
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func SysLogf(format string, a ...any) {
SysLog(fmt.Sprintf(format, a...))
}
func SysError(s string) { func SysError(s string) {
t := time.Now() t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func SysErrorf(format string, a ...any) {
SysError(fmt.Sprintf(format, a...))
}
func Debug(ctx context.Context, msg string) { func Debug(ctx context.Context, msg string) {
if config.DebugEnabled { if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg) logHelper(ctx, loggerDEBUG, msg)

View File

@@ -12,8 +12,6 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
) )
@@ -117,7 +115,6 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
ctx := c.Request.Context()
if !config.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
@@ -182,28 +179,6 @@ func Register(c *gin.Context) {
}) })
return return
} }
go func() {
err := user.ValidateAndFill()
if err != nil {
logger.Errorf(ctx, "user.ValidateAndFill failed: %v", err)
return
}
cleanToken := model.Token{
UserId: user.Id,
Name: "default",
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: -1,
UnlimitedQuota: true,
}
err = cleanToken.Insert()
if err != nil {
logger.Errorf(ctx, "cleanToken.Insert failed: %v", err)
return
}
}()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",

View File

@@ -26,7 +26,7 @@ var buildFS embed.FS
func main() { func main() {
logger.SetupLogger() logger.SetupLogger()
logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) logger.SysLogf("One API %s started", common.Version)
if os.Getenv("GIN_MODE") != "debug" { if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }

View File

@@ -91,26 +91,28 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.BaseURL, channel.GetBaseURL()) c.Set(ctxkey.BaseURL, channel.GetBaseURL())
cfg, _ := channel.LoadConfig() cfg, _ := channel.LoadConfig()
// this is for backward compatibility // this is for backward compatibility
if channel.Other != nil {
switch channel.Type { switch channel.Type {
case channeltype.Azure: case channeltype.Azure:
if cfg.APIVersion == "" { if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other cfg.APIVersion = *channel.Other
} }
case channeltype.Xunfei: case channeltype.Xunfei:
if cfg.APIVersion == "" { if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other cfg.APIVersion = *channel.Other
} }
case channeltype.Gemini: case channeltype.Gemini:
if cfg.APIVersion == "" { if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other cfg.APIVersion = *channel.Other
} }
case channeltype.AIProxyLibrary: case channeltype.AIProxyLibrary:
if cfg.LibraryID == "" { if cfg.LibraryID == "" {
cfg.LibraryID = channel.Other cfg.LibraryID = *channel.Other
} }
case channeltype.Ali: case channeltype.Ali:
if cfg.Plugin == "" { if cfg.Plugin == "" {
cfg.Plugin = channel.Other cfg.Plugin = *channel.Other
}
} }
} }
c.Set(ctxkey.Config, cfg) c.Set(ctxkey.Config, cfg)

View File

@@ -27,7 +27,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` // DEPRECATED: please save config to field Config Other *string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm" "gorm.io/gorm"
@@ -141,6 +142,22 @@ func (user *User) Insert(inviterId int) error {
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
// create default token
cleanToken := Token{
UserId: user.Id,
Name: "default",
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: -1,
UnlimitedQuota: true,
}
result.Error = cleanToken.Insert()
if result.Error != nil {
// do not block
logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
}
return nil return nil
} }

View File

@@ -244,8 +244,10 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText() choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason //choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse var response openai.ChatCompletionsStreamResponse
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
response.Created = helper.GetTimestamp()
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = "gemini" response.Model = "gemini"
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}

View File

@@ -28,14 +28,6 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
version := parseAPIVersionByModelName(meta.ActualModelName)
if version == "" {
version = a.meta.Config.APIVersion
}
if version == "" {
version = "v1.1"
}
a.meta.Config.APIVersion = version
// check DoResponse for auth part // check DoResponse for auth part
return nil return nil
} }
@@ -70,6 +62,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
if a.request == nil { if a.request == nil {
return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
} }
version := parseAPIVersionByModelName(meta.ActualModelName)
if version == "" {
version = a.meta.Config.APIVersion
}
if version == "" {
version = "v1.1"
}
a.meta.Config.APIVersion = version
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2]) err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2])
} else { } else {

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -29,11 +30,7 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages)) messages := make([]Message, 0, len(request.Messages))
var lastToolCalls []model.Tool
for _, message := range request.Messages { for _, message := range request.Messages {
if message.ToolCalls != nil {
lastToolCalls = message.ToolCalls
}
messages = append(messages, Message{ messages = append(messages, Message{
Role: message.Role, Role: message.Role,
Content: message.StringContent(), Content: message.StringContent(),
@@ -46,9 +43,10 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages xunfeiRequest.Payload.Message.Text = messages
if len(lastToolCalls) != 0 {
for _, toolCall := range lastToolCalls { if strings.HasPrefix(domain, "generalv3") {
xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) xunfeiRequest.Payload.Functions = &Functions{
Text: request.Tools,
} }
} }
@@ -204,7 +202,7 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq
} }
} }
if len(xunfeiResponse.Payload.Choices.Text) == 0 { if len(xunfeiResponse.Payload.Choices.Text) == 0 {
return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil return openai.ErrorWrapper(errors.New("xunfei empty response detected"), "xunfei_empty_response_detected", http.StatusInternalServerError), nil
} }
xunfeiResponse.Payload.Choices.Text[0].Content = content xunfeiResponse.Payload.Choices.Text[0].Content = content

View File

@@ -9,6 +9,10 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
} }
type Functions struct {
Text []model.Tool `json:"text,omitempty"`
}
type ChatRequest struct { type ChatRequest struct {
Header struct { Header struct {
AppId string `json:"app_id"` AppId string `json:"app_id"`
@@ -26,9 +30,7 @@ type ChatRequest struct {
Message struct { Message struct {
Text []Message `json:"text"` Text []Message `json:"text"`
} `json:"message"` } `json:"message"`
Functions struct { Functions *Functions `json:"functions,omitempty"`
Text []model.Function `json:"text,omitempty"`
} `json:"functions,omitempty"`
} `json:"payload"` } `json:"payload"`
} }

View File

@@ -163,7 +163,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
values.other = 'v2.1'; values.other = 'v2.1';
} }
if (values.key === '') { if (values.key === '') {
if (values.config.ak !== '' && values.config.sk !== '' && values.config.region !== '') { if (values.config.ak && values.config.sk && values.config.region) {
values.key = `${values.config.ak}|${values.config.sk}|${values.config.region}`; values.key = `${values.config.ak}|${values.config.sk}|${values.config.region}`;
} }
} }

View File

@@ -181,9 +181,6 @@ const EditChannel = () => {
if (localInputs.type === 3 && localInputs.other === '') { if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2024-03-01-preview'; localInputs.other = '2024-03-01-preview';
} }
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
}
let res; let res;
localInputs.models = localInputs.models.join(','); localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(','); localInputs.group = localInputs.groups.join(',');