feat: support set system prompt for channel (close #1920)

This commit is contained in:
JustSong 2024-11-10 14:53:34 +08:00
parent 92cd46d64f
commit 6eb0770a89
8 changed files with 44 additions and 1 deletions

View File

@ -20,4 +20,5 @@ const (
BaseURL = "base_url" BaseURL = "base_url"
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body" KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
) )

View File

@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ChannelName, channel.Name)
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
}
c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

View File

@ -37,6 +37,7 @@ type Channel struct {
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"` Config string `json:"config"`
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
} }
type ChannelConfig struct { type ChannelConfig struct {

View File

@ -1,5 +1,6 @@
package role package role
const ( const (
System = "system"
Assistant = "assistant" Assistant = "assistant"
) )

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/constant/role"
"math" "math"
"net/http" "net/http"
"strings" "strings"
@ -154,3 +155,22 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
} }
return false return false
} }
func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) {
if prompt == "" {
return
}
if len(request.Messages) == 0 {
return
}
if request.Messages[0].Role == role.System {
request.Messages[0].Content = prompt
logger.Infof(ctx, "rewrite system prompt")
return
}
request.Messages = append([]relaymodel.Message{{
Role: role.System,
Content: prompt,
}}, request.Messages...)
logger.Infof(ctx, "add system prompt")
}

View File

@ -36,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.OriginModelName = textRequest.Model meta.OriginModelName = textRequest.Model
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// set system prompt if not empty
setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)

View File

@ -30,6 +30,7 @@ type Meta struct {
ActualModelName string ActualModelName string
RequestURLPath string RequestURLPath string
PromptTokens int // only for DoResponse PromptTokens int // only for DoResponse
SystemPrompt string
} }
func GetByContext(c *gin.Context) *Meta { func GetByContext(c *gin.Context) *Meta {
@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta {
BaseURL: c.GetString(ctxkey.BaseURL), BaseURL: c.GetString(ctxkey.BaseURL),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
SystemPrompt: c.GetString(ctxkey.SystemPrompt),
} }
cfg, ok := c.Get(ctxkey.Config) cfg, ok := c.Get(ctxkey.Config)
if ok { if ok {

View File

@ -43,6 +43,7 @@ const EditChannel = () => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
system_prompt: '',
models: [], models: [],
groups: ['default'] groups: ['default']
}; };
@ -425,7 +426,7 @@ const EditChannel = () => {
) )
} }
{ {
inputs.type !== 43 && ( inputs.type !== 43 && (<>
<Form.Field> <Form.Field>
<Form.TextArea <Form.TextArea
label='模型重定向' label='模型重定向'
@ -437,6 +438,18 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.TextArea
label='系统提示词'
placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
name='system_prompt'
onChange={handleInputChange}
value={inputs.system_prompt}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
</>
) )
} }
{ {