diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 90556b3a..115558a5 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -20,4 +20,5 @@ const ( BaseURL = "base_url" AvailableModels = "available_models" KeyRequestBody = "key_request_body" + SystemPrompt = "system_prompt" ) diff --git a/middleware/distributor.go b/middleware/distributor.go index e2f75110..0aceb29d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) + if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { + c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) + } c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/model/channel.go b/model/channel.go index 759dfd4f..4b0f4b01 100644 --- a/model/channel.go +++ b/model/channel.go @@ -37,6 +37,7 @@ type Channel struct { ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Config string `json:"config"` + SystemPrompt *string `json:"system_prompt" gorm:"type:text"` } type ChannelConfig struct { diff --git a/relay/constant/role/define.go b/relay/constant/role/define.go index 972488c5..5097c97e 100644 --- a/relay/constant/role/define.go +++ b/relay/constant/role/define.go @@ -1,5 +1,6 @@ package role const ( + System = "system" Assistant = "assistant" ) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 87d22f13..4d03a045 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/songquanpeng/one-api/relay/constant/role" "math" "net/http" "strings" @@ -154,3 +155,22 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { } return false } + +func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) { + if prompt == "" { + return + } + if len(request.Messages) == 0 { + return + } + if request.Messages[0].Role == role.System { + request.Messages[0].Content = prompt + logger.Infof(ctx, "rewrite system prompt") + return + } + request.Messages = append([]relaymodel.Message{{ + Role: role.System, + Content: prompt, + }}, request.Messages...) + logger.Infof(ctx, "add system prompt") +} diff --git a/relay/controller/text.go b/relay/controller/text.go index 57f98812..0a2f4b54 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -36,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { meta.OriginModelName = textRequest.Model textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model + // set system prompt if not empty + setSystemPrompt(ctx, textRequest, meta.SystemPrompt) // get model ratio & group ratio modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index b1761e9a..bcbe1045 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -30,6 +30,7 @@ type Meta struct { ActualModelName string RequestURLPath string PromptTokens int // only for DoResponse + SystemPrompt string } func GetByContext(c *gin.Context) *Meta { @@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta { BaseURL: c.GetString(ctxkey.BaseURL), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), + SystemPrompt: c.GetString(ctxkey.SystemPrompt), } cfg, ok := c.Get(ctxkey.Config) if ok { diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index b967907e..f10658c3 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -43,6 +43,7 @@ const EditChannel = () => { base_url: '', other: '', model_mapping: '', + system_prompt: '', models: [], groups: ['default'] }; @@ -425,7 +426,7 @@ const EditChannel = () => { ) } { - inputs.type !== 43 && ( + inputs.type !== 43 && (<>