fix: model id

This commit is contained in:
Laisky.Cai 2024-04-18 03:37:25 +00:00
parent 03184457f9
commit e93e489ea9
5 changed files with 11 additions and 5 deletions

View File

@ -11,4 +11,5 @@ var (
CtxKeyRequestModel string = "request_model"
CtxKeyRawRequest string = "raw_request"
CtxKeyConvertedRequest string = "converted_request"
CtxKeyOriginModel string = "origin_model"
)

View File

@ -55,7 +55,7 @@ func Relay(c *gin.Context) {
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
originalModel := c.GetString(common.CtxKeyOriginModel)
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
retryTimes := config.RetryTimes

View File

@ -79,7 +79,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry
c.Set(common.CtxKeyOriginModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility

View File

@ -93,6 +93,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName)
if config.DebugEnabled {
db = db.Debug()
}
if err == nil {
if config.DebugSQLEnabled {
db = db.Debug()

View File

@ -81,7 +81,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(channel.Models)
awsModelId, err := awsModelID(c.GetString(common.CtxKeyOriginModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
@ -148,7 +148,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(channel.Models)
awsModelId, err := awsModelID(c.GetString(common.CtxKeyOriginModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
@ -211,7 +211,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
return true
}
response.Id = id
response.Model = c.GetString("original_model")
response.Model = c.GetString(common.CtxKeyOriginModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {