From e93e489ea9ee1ae437d0b0976e0b0d504fb78c5d Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Thu, 18 Apr 2024 03:37:25 +0000 Subject: [PATCH] fix: model id --- common/constants.go | 1 + controller/relay.go | 2 +- middleware/distributor.go | 2 +- model/main.go | 5 +++++ relay/adaptor/aws/main.go | 6 +++--- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/common/constants.go b/common/constants.go index e4466a57..cdcec59b 100644 --- a/common/constants.go +++ b/common/constants.go @@ -11,4 +11,5 @@ var ( CtxKeyRequestModel string = "request_model" CtxKeyRawRequest string = "raw_request" CtxKeyConvertedRequest string = "converted_request" + CtxKeyOriginModel string = "origin_model" ) diff --git a/controller/relay.go b/controller/relay.go index 8278bf23..51ded1d1 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -55,7 +55,7 @@ func Relay(c *gin.Context) { lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") - originalModel := c.GetString("original_model") + originalModel := c.GetString(common.CtxKeyOriginModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes diff --git a/middleware/distributor.go b/middleware/distributor.go index 7abb0be2..b6952d42 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -79,7 +79,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("model_mapping", channel.GetModelMapping()) - c.Set("original_model", modelName) // for retry + c.Set(common.CtxKeyOriginModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // this is for backward compatibility diff --git a/model/main.go b/model/main.go index e5124a4c..3c6ab79d 100644 --- a/model/main.go +++ b/model/main.go @@ -93,6 +93,11 @@ func chooseDB(envName string) (*gorm.DB, error) { func InitDB(envName string) (db *gorm.DB, err error) { db, err = chooseDB(envName) + + if config.DebugEnabled { + db = db.Debug() + } + if err == nil { if config.DebugSQLEnabled { db = db.Debug() diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 0c0643ed..16ea6809 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -81,7 +81,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return wrapErr(errors.Wrap(err, "newAwsClient")), nil } - awsModelId, err := awsModelID(channel.Models) + awsModelId, err := awsModelID(c.GetString(common.CtxKeyOriginModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil } @@ -148,7 +148,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt return wrapErr(errors.Wrap(err, "newAwsClient")), nil } - awsModelId, err := awsModelID(channel.Models) + awsModelId, err := awsModelID(c.GetString(common.CtxKeyOriginModel)) if err != nil { return wrapErr(errors.Wrap(err, "awsModelID")), nil } @@ -211,7 +211,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt return true } response.Id = id - response.Model = c.GetString("original_model") + response.Model = c.GetString(common.CtxKeyOriginModel) response.Created = createdTime jsonStr, err := json.Marshal(response) if err != nil {