fix: model id

This commit is contained in:
Laisky.Cai
2024-04-18 03:37:25 +00:00
parent 03184457f9
commit e93e489ea9
5 changed files with 11 additions and 5 deletions

View File

@@ -11,4 +11,5 @@ var (
CtxKeyRequestModel string = "request_model" CtxKeyRequestModel string = "request_model"
CtxKeyRawRequest string = "raw_request" CtxKeyRawRequest string = "raw_request"
CtxKeyConvertedRequest string = "converted_request" CtxKeyConvertedRequest string = "converted_request"
CtxKeyOriginModel string = "origin_model"
) )

View File

@@ -55,7 +55,7 @@ func Relay(c *gin.Context) {
lastFailedChannelId := channelId lastFailedChannelId := channelId
channelName := c.GetString("channel_name") channelName := c.GetString("channel_name")
group := c.GetString("group") group := c.GetString("group")
originalModel := c.GetString("original_model") originalModel := c.GetString(common.CtxKeyOriginModel)
go processChannelRelayError(ctx, channelId, channelName, bizErr) go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey) requestId := c.GetString(logger.RequestIdKey)
retryTimes := config.RetryTimes retryTimes := config.RetryTimes

View File

@@ -79,7 +79,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry c.Set(common.CtxKeyOriginModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility // this is for backward compatibility

View File

@@ -93,6 +93,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
func InitDB(envName string) (db *gorm.DB, err error) { func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName) db, err = chooseDB(envName)
if config.DebugEnabled {
db = db.Debug()
}
if err == nil { if err == nil {
if config.DebugSQLEnabled { if config.DebugSQLEnabled {
db = db.Debug() db = db.Debug()

View File

@@ -81,7 +81,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
return wrapErr(errors.Wrap(err, "newAwsClient")), nil return wrapErr(errors.Wrap(err, "newAwsClient")), nil
} }
awsModelId, err := awsModelID(channel.Models) awsModelId, err := awsModelID(c.GetString(common.CtxKeyOriginModel))
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil return wrapErr(errors.Wrap(err, "awsModelID")), nil
} }
@@ -148,7 +148,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
return wrapErr(errors.Wrap(err, "newAwsClient")), nil return wrapErr(errors.Wrap(err, "newAwsClient")), nil
} }
awsModelId, err := awsModelID(channel.Models) awsModelId, err := awsModelID(c.GetString(common.CtxKeyOriginModel))
if err != nil { if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil return wrapErr(errors.Wrap(err, "awsModelID")), nil
} }
@@ -211,7 +211,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
return true return true
} }
response.Id = id response.Id = id
response.Model = c.GetString("original_model") response.Model = c.GetString(common.CtxKeyOriginModel)
response.Created = createdTime response.Created = createdTime
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {