diff --git a/__debug_bin1148682411 b/__debug_bin1148682411 new file mode 100755 index 00000000..06be031f Binary files /dev/null and b/__debug_bin1148682411 differ diff --git a/common/helper/helper.go b/common/helper/helper.go index 662de16c..22bc3700 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -1,6 +1,7 @@ package helper import ( + "context" "fmt" "html/template" "log" @@ -107,6 +108,18 @@ func GenRequestID() string { return GetTimeString() + random.GetRandomNumberString(8) } +func SetRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, RequestIdKey, id) +} + +func GetRequestID(ctx context.Context) string { + rawRequestId := ctx.Value(RequestIdKey) + if rawRequestId == nil { + return "" + } + return rawRequestId.(string) +} + func GetResponseID(c *gin.Context) string { logID := c.GetString(RequestIdKey) return fmt.Sprintf("chatcmpl-%s", logID) diff --git a/common/logger/logger.go b/common/logger/logger.go index d1022932..1e3bc254 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -7,19 +7,25 @@ import ( "log" "os" "path/filepath" + "runtime" + "strings" "sync" "time" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" ) +type loggerLevel string + const ( - loggerDEBUG = "DEBUG" - loggerINFO = "INFO" - loggerWarn = "WARN" - loggerError = "ERR" + loggerDEBUG loggerLevel = "DEBUG" + loggerINFO loggerLevel = "INFO" + loggerWarn loggerLevel = "WARN" + loggerError loggerLevel = "ERROR" + loggerFatal loggerLevel = "FATAL" ) var setupLogOnce sync.Once @@ -44,27 +50,26 @@ func SetupLogger() { } func SysLog(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) + logHelper(nil, loggerINFO, s) } func SysLogf(format string, a ...any) { - SysLog(fmt.Sprintf(format, a...)) + logHelper(nil, loggerINFO, fmt.Sprintf(format, a...)) } func SysError(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) + logHelper(nil, loggerError, s) } func SysErrorf(format string, a ...any) { - SysError(fmt.Sprintf(format, a...)) + logHelper(nil, loggerError, fmt.Sprintf(format, a...)) } func Debug(ctx context.Context, msg string) { - if config.DebugEnabled { - logHelper(ctx, loggerDEBUG, msg) + if !config.DebugEnabled { + return } + logHelper(ctx, loggerDEBUG, msg) } func Info(ctx context.Context, msg string) { @@ -80,37 +85,65 @@ func Error(ctx context.Context, msg string) { } func Debugf(ctx context.Context, format string, a ...any) { - Debug(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...)) } func Infof(ctx context.Context, format string, a ...any) { - Info(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...)) } func Warnf(ctx context.Context, format string, a ...any) { - Warn(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...)) } func Errorf(ctx context.Context, format string, a ...any) { - Error(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerError, fmt.Sprintf(format, a...)) } -func logHelper(ctx context.Context, level string, msg string) { +func FatalLog(s string) { + logHelper(nil, loggerFatal, s) +} + +func FatalLogf(format string, a ...any) { + logHelper(nil, loggerFatal, fmt.Sprintf(format, a...)) +} + +func logHelper(ctx context.Context, level loggerLevel, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(helper.RequestIdKey) - if id == nil { - id = helper.GenRequestID() + var requestId string + if ctx != nil { + rawRequestId := helper.GetRequestID(ctx) + if rawRequestId != "" { + requestId = fmt.Sprintf(" | %s", rawRequestId) + } } + lineInfo, funcName := getLineInfo() now := time.Now() - _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg) SetupLogger() + if level == loggerFatal { + os.Exit(1) + } } -func FatalLog(v ...any) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) - os.Exit(1) +func getLineInfo() (string, string) { + funcName := "[unknown] " + pc, file, line, ok := runtime.Caller(3) + if ok { + if fn := runtime.FuncForPC(pc); fn != nil { + parts := strings.Split(fn.Name(), ".") + funcName = "[" + parts[len(parts)-1] + "] " + } + } else { + file = "unknown" + line = 0 + } + parts := strings.Split(file, "one-api/") + if len(parts) > 1 { + file = parts[1] + } + return fmt.Sprintf(" | %s:%d", file, line), funcName } diff --git a/controller/auth/github.go b/controller/auth/github.go index 95d44822..8fc97400 100644 --- a/controller/auth/github.go +++ b/controller/auth/github.go @@ -4,6 +4,10 @@ import ( "bytes" "encoding/json" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/pkg/errors" @@ -12,9 +16,6 @@ import ( "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type GitHubOAuthResponse struct { @@ -81,6 +82,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } func GitHubOAuth(c *gin.Context) { + ctx := c.Request.Context() session := sessions.Default(c) state := c.Query("state") if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { @@ -136,7 +138,7 @@ func GitHubOAuth(c *gin.Context) { user.Role = model.RoleCommonUser user.Status = model.UserStatusEnabled - if err := user.Insert(0); err != nil { + if err := user.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/controller/auth/lark.go b/controller/auth/lark.go index 35f269e9..4fa4435d 100644 --- a/controller/auth/lark.go +++ b/controller/auth/lark.go @@ -80,6 +80,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { } func LarkOAuth(c *gin.Context) { + ctx := c.Request.Context() session := sessions.Default(c) state := c.Query("state") if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { @@ -126,7 +127,7 @@ func LarkOAuth(c *gin.Context) { user.Role = model.RoleCommonUser user.Status = model.UserStatusEnabled - if err := user.Insert(0); err != nil { + if err := user.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go index b82c7a08..8ada7628 100644 --- a/controller/auth/oidc.go +++ b/controller/auth/oidc.go @@ -11,6 +11,7 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/controller" @@ -88,6 +89,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } func OidcAuth(c *gin.Context) { + ctx := c.Request.Context() session := sessions.Default(c) state := c.Query("state") if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { @@ -143,7 +145,7 @@ func OidcAuth(c *gin.Context) { } else { user.DisplayName = "OIDC User" } - err := user.Insert(0) + err := user.Insert(ctx, 0) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/auth/wechat.go b/controller/auth/wechat.go index 859ae262..78b45b58 100644 --- a/controller/auth/wechat.go +++ b/controller/auth/wechat.go @@ -53,6 +53,7 @@ func getWeChatIdByCode(code string) (string, error) { } func WeChatAuth(c *gin.Context) { + ctx := c.Request.Context() if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "The administrator has not enabled login and registration via WeChat", @@ -88,7 +89,7 @@ func WeChatAuth(c *gin.Context) { user.Role = model.RoleCommonUser user.Status = model.UserStatusEnabled - if err := user.Insert(0); err != nil { + if err := user.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/controller/user.go b/controller/user.go index b6fac67e..4a21782a 100644 --- a/controller/user.go +++ b/controller/user.go @@ -115,6 +115,7 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { + ctx := c.Request.Context() if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "The administrator has turned off new user registration", @@ -172,7 +173,7 @@ func Register(c *gin.Context) { if config.EmailVerificationEnabled { cleanUser.Email = user.Email } - if err := cleanUser.Insert(inviterId); err != nil { + if err := cleanUser.Insert(ctx, inviterId); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -377,6 +378,7 @@ func GetSelf(c *gin.Context) { } func UpdateUser(c *gin.Context) { + ctx := c.Request.Context() var updatedUser model.User err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) if err != nil || updatedUser.Id == 0 { @@ -431,7 +433,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("The administrator changed the user quota from %s to %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("Admin changed user quota from %s to %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, @@ -550,6 +552,7 @@ func DeleteSelf(c *gin.Context) { } func CreateUser(c *gin.Context) { + ctx := c.Request.Context() var user model.User err := json.NewDecoder(c.Request.Body).Decode(&user) if err != nil || user.Username == "" || user.Password == "" { @@ -583,7 +586,7 @@ func CreateUser(c *gin.Context) { Password: user.Password, DisplayName: user.DisplayName, } - if err := cleanUser.Insert(0); err != nil { + if err := cleanUser.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -762,6 +765,7 @@ type topUpRequest struct { } func TopUp(c *gin.Context) { + ctx := c.Request.Context() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { @@ -772,7 +776,7 @@ func TopUp(c *gin.Context) { return } id := c.GetInt("id") - quota, err := model.Redeem(req.Key, id) + quota, err := model.Redeem(ctx, req.Key, id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -795,6 +799,7 @@ type adminTopUpRequest struct { } func AdminTopUp(c *gin.Context) { + ctx := c.Request.Context() req := adminTopUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { @@ -815,7 +820,7 @@ func AdminTopUp(c *gin.Context) { if req.Remark == "" { req.Remark = fmt.Sprintf("Recharged via API %s", common.LogQuota(int64(req.Quota))) } - model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/middleware/distributor.go b/middleware/distributor.go index 2cac55cd..c697e729 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -21,6 +21,7 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { + ctx := c.Request.Context() userId := c.GetInt(ctxkey.Id) userGroup, _ := model.CacheGetUserGroup(userId) c.Set(ctxkey.Group, userGroup) @@ -56,6 +57,7 @@ func Distribute() func(c *gin.Context) { return } } + logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id) SetupContextForSelectedChannel(c, channel, requestModel) c.Next() } diff --git a/middleware/request-id.go b/middleware/request-id.go index c1f3adc2..c55984bf 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -1,8 +1,6 @@ package middleware import ( - "context" - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) @@ -11,7 +9,7 @@ func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := helper.GenRequestID() c.Set(helper.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) + ctx := helper.SetRequestID(c.Request.Context(), id) c.Request = c.Request.WithContext(ctx) c.Header(helper.RequestIdKey, id) c.Next() diff --git a/model/log.go b/model/log.go index 58fdd513..1fd7ee84 100644 --- a/model/log.go +++ b/model/log.go @@ -4,11 +4,12 @@ import ( "context" "fmt" + "gorm.io/gorm" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "gorm.io/gorm" ) type Log struct { @@ -24,6 +25,7 @@ type Log struct { PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` ChannelId int `json:"channel" gorm:"index"` + RequestId string `json:"request_id"` } const ( @@ -34,7 +36,18 @@ const ( LogTypeSystem ) -func RecordLog(userId int, logType int, content string) { +func recordLogHelper(ctx context.Context, log *Log) { + requestId := helper.GetRequestID(ctx) + log.RequestId = requestId + err := LOG_DB.Create(log).Error + if err != nil { + logger.Error(ctx, "failed to record log: "+err.Error()) + return + } + logger.Infof(ctx, "record log: %+v", log) +} + +func RecordLog(ctx context.Context, userId int, logType int, content string) { if logType == LogTypeConsume && !config.LogConsumeEnabled { return } @@ -45,13 +58,10 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := LOG_DB.Create(log).Error - if err != nil { - logger.SysError("failed to record log: " + err.Error()) - } + recordLogHelper(ctx, log) } -func RecordTopupLog(userId int, content string, quota int) { +func RecordTopupLog(ctx context.Context, userId int, content string, quota int) { log := &Log{ UserId: userId, Username: GetUsernameById(userId), @@ -60,14 +70,10 @@ func RecordTopupLog(userId int, content string, quota int) { Content: content, Quota: quota, } - err := LOG_DB.Create(log).Error - if err != nil { - logger.SysError("failed to record log: " + err.Error()) - } + recordLogHelper(ctx, log) } func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { - logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { return } @@ -84,10 +90,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke Quota: int(quota), ChannelId: channelId, } - err := LOG_DB.Create(log).Error - if err != nil { - logger.Error(ctx, "failed to record log: "+err.Error()) - } + recordLogHelper(ctx, log) } func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { diff --git a/model/redemption.go b/model/redemption.go index b3f80947..7171e408 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -1,6 +1,7 @@ package model import ( + "context" "fmt" "github.com/pkg/errors" @@ -49,7 +50,7 @@ func GetRedemptionById(id int) (*Redemption, error) { return &redemption, err } -func Redeem(key string, userId int) (quota int64, err error) { +func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) { if key == "" { return 0, errors.New("No redemption code provided") } @@ -83,7 +84,7 @@ func Redeem(key string, userId int) (quota int64, err error) { if err != nil { return 0, errors.New("Redeem failed, " + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("Recharge %s through redemption code", common.LogQuota(redemption.Quota))) + RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("Recharged %s using redemption code", common.LogQuota(redemption.Quota))) return redemption.Quota, nil } diff --git a/model/user.go b/model/user.go index 9ff8a093..31816d85 100644 --- a/model/user.go +++ b/model/user.go @@ -1,6 +1,7 @@ package model import ( + "context" "fmt" "strings" @@ -115,7 +116,7 @@ func DeleteUserById(id int) (err error) { return user.Delete() } -func (user *User) Insert(inviterId int) error { +func (user *User) Insert(ctx context.Context, inviterId int) error { var err error if user.Password != "" { user.Password, err = common.Password2Hash(user.Password) @@ -131,16 +132,16 @@ func (user *User) Insert(inviterId int) error { return result.Error } if config.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("New user registration gives %s", common.LogQuota(config.QuotaForNewUser))) + RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) } if inviterId != 0 { if config.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("Use invitation code to give %s", common.LogQuota(config.QuotaForInvitee))) + RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("Gifted %s for using invitation code", common.LogQuota(config.QuotaForInvitee))) } if config.QuotaForInviter > 0 { _ = IncreaseUserQuota(inviterId, config.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("Invite users to give %s", common.LogQuota(config.QuotaForInviter))) + RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("Gifted %s for inviting user", common.LogQuota(config.QuotaForInviter))) } } // create default token diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index 9d086eed..cae2d328 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) // https://cloud.tencent.com/document/api/1729/101837 @@ -53,10 +54,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if err != nil { return nil, err } - tencentRequest := ConvertRequest(*request) + var convertedRequest any + switch relayMode { + case relaymode.Embeddings: + a.Action = "GetEmbedding" + convertedRequest = ConvertEmbeddingRequest(*request) + default: + a.Action = "ChatCompletions" + convertedRequest = ConvertRequest(*request) + } // we have to calculate the sign here - a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) - return tencentRequest, nil + a.Sign = GetSign(convertedRequest, a, secretId, secretKey) + return convertedRequest, nil } func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { @@ -76,7 +85,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met err, responseText = StreamHandler(c, resp) usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } else { - err, usage = Handler(c, resp) + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } } return } diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go index e8631e5f..7997bfd6 100644 --- a/relay/adaptor/tencent/constants.go +++ b/relay/adaptor/tencent/constants.go @@ -6,4 +6,5 @@ var ModelList = []string{ "hunyuan-standard-256K", "hunyuan-pro", "hunyuan-vision", + "hunyuan-embedding", } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index c402e543..8bf8e469 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -15,8 +15,10 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" @@ -44,8 +46,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { } } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + InputList: request.ParseInput(), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var tencentResponseP EmbeddingResponseP + err := json.NewDecoder(resp.Body).Decode(&tencentResponseP) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + tencentResponse := tencentResponseP.Response + if tencentResponse.Error.Code != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: tencentResponse.Error.Message, + Code: tencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + requestModel := c.GetString(ctxkey.RequestModel) + fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse) + fullTextResponse.Model = requestModel + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), + Model: "hunyuan-embedding", + Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens}, + } + + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ + Id: response.ReqID, Object: "chat.completion", Created: helper.GetTimestamp(), Usage: model.Usage{ @@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } TencentResponse = responseP.Response - if TencentResponse.Error.Code != 0 { + if TencentResponse.Error.Code != "" { return &model.ErrorWithStatusCode{ Error: model.Error{ Message: TencentResponse.Error.Message, @@ -195,7 +257,7 @@ func hmacSha256(s, key string) string { return string(hashed.Sum(nil)) } -func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { +func GetSign(req any, adaptor *Adaptor, secId, secKey string) string { // build canonical request string host := "hunyuan.tencentcloudapi.com" httpRequestMethod := "POST" diff --git a/relay/adaptor/tencent/model.go b/relay/adaptor/tencent/model.go index b34398fb..0d0638e5 100644 --- a/relay/adaptor/tencent/model.go +++ b/relay/adaptor/tencent/model.go @@ -41,16 +41,16 @@ type ChatRequest struct { // 1. Affects the diversity of the output text. The larger the value, the more diverse the generated text. // 2. The value range is [0.0, 1.0]. If not provided, the recommended value for each model is used. // 3. It is not recommended to use this unless necessary, as unreasonable values can affect the results. - TopP *float64 `json:"TopP"` + TopP *float64 `json:"TopP,omitempty"` // Description: // 1. Higher values make the output more random, while lower values make it more focused and deterministic. // 2. The value range is [0.0, 2.0]. If not provided, the recommended value for each model is used. // 3. It is not recommended to use this unless necessary, as unreasonable values can affect the results. - Temperature *float64 `json:"Temperature"` + Temperature *float64 `json:"Temperature,omitempty"` } type Error struct { - Code int `json:"Code"` + Code string `json:"Code"` Message string `json:"Message"` } @@ -67,15 +67,41 @@ type ResponseChoices struct { } type ChatResponse struct { - Choices []ResponseChoices `json:"Choices,omitempty"` // Results - Created int64 `json:"Created,omitempty"` // Unix timestamp string - Id string `json:"Id,omitempty"` // Session ID - Usage Usage `json:"Usage,omitempty"` // Token count - Error Error `json:"Error,omitempty"` // Error information. Note: This field may return null, indicating that no valid value was found. - Note string `json:"Note,omitempty"` // Note - ReqID string `json:"Req_id,omitempty"` // Unique request ID, returned with each request. Used for feedback on interface input parameters. + Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 + Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 + Id string `json:"Id,omitempty"` // 会话 id + Usage Usage `json:"Usage,omitempty"` // token 数量 + Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"Note,omitempty"` // 注释 + ReqID string `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 } type ChatResponseP struct { Response ChatResponse `json:"Response,omitempty"` } + +type EmbeddingRequest struct { + InputList []string `json:"InputList"` +} + +type EmbeddingData struct { + Embedding []float64 `json:"Embedding"` + Index int `json:"Index"` + Object string `json:"Object"` +} + +type EmbeddingUsage struct { + PromptTokens int `json:"PromptTokens"` + TotalTokens int `json:"TotalTokens"` +} + +type EmbeddingResponse struct { + Data []EmbeddingData `json:"Data"` + EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"` + RequestId string `json:"RequestId,omitempty"` + Error Error `json:"Error,omitempty"` +} + +type EmbeddingResponseP struct { + Response EmbeddingResponse `json:"Response,omitempty"` +} diff --git a/relay/billing/billing.go b/relay/billing/billing.go index 6017b1af..a97a3415 100644 --- a/relay/billing/billing.go +++ b/relay/billing/billing.go @@ -3,6 +3,7 @@ package billing import ( "context" "fmt" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" )