Compare commits

..

4 Commits

Author SHA1 Message Date
JustSong
6ab87f8a08 feat: add warning in log when system prompt is reset
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
2024-11-10 17:18:46 +08:00
JustSong
833fa7ad6f feat: support set system_prompt for theme air & berry 2024-11-10 15:09:02 +08:00
JustSong
6eb0770a89 feat: support set system prompt for channel (close #1920) 2024-11-10 14:53:34 +08:00
JustSong
92cd46d64f feat: able to use ENFORCE_INCLUDE_USAGE to enforce include usage in response
Some checks are pending
CI / Unit tests (push) Waiting to run
CI / commit_lint (push) Waiting to run
2024-11-10 00:36:08 +08:00
13 changed files with 335 additions and 245 deletions

View File

@@ -400,6 +400,7 @@ graph LR
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage默认不开启可选值为 `true` 和 `false`。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -160,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "") var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)

View File

@@ -20,4 +20,5 @@ const (
BaseURL = "base_url" BaseURL = "base_url"
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body" KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
) )

View File

@@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ChannelName, channel.Name)
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
}
c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

View File

@@ -37,6 +37,7 @@ type Channel struct {
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"` Config string `json:"config"`
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
} }
type ChannelConfig struct { type ChannelConfig struct {

View File

@@ -1,5 +1,6 @@
package role package role
const ( const (
System = "system"
Assistant = "assistant" Assistant = "assistant"
) )

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/constant/role"
"math" "math"
"net/http" "net/http"
"strings" "strings"
@@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil return preConsumedQuota, nil
} }
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
if usage == nil { if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected") logger.Error(ctx, "usage is nil, which is unexpected")
return return
@@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
if err != nil { if err != nil {
logger.Error(ctx, "error update user quota cache: "+err.Error()) logger.Error(ctx, "error update user quota cache: "+err.Error())
} }
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) var extraLog string
if systemPromptReset {
extraLog = " (注意系统提示词已被重置)"
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota)
@@ -154,3 +159,23 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
} }
return false return false
} }
func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
if prompt == "" {
return false
}
if len(request.Messages) == 0 {
return false
}
if request.Messages[0].Role == role.System {
request.Messages[0].Content = prompt
logger.Infof(ctx, "rewrite system prompt")
return true
}
request.Messages = append([]relaymodel.Message{{
Role: role.System,
Content: prompt,
}}, request.Messages...)
logger.Infof(ctx, "add system prompt")
return true
}

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config"
"io" "io"
"net/http" "net/http"
@@ -35,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.OriginModelName = textRequest.Model meta.OriginModelName = textRequest.Model
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// set system prompt if not empty
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -79,12 +82,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return respErr return respErr
} }
// post-consume quota // post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
return nil return nil
} }
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai // no need to convert request for openai
return c.Request.Body, nil return c.Request.Body, nil
} }

View File

@@ -30,6 +30,7 @@ type Meta struct {
ActualModelName string ActualModelName string
RequestURLPath string RequestURLPath string
PromptTokens int // only for DoResponse PromptTokens int // only for DoResponse
SystemPrompt string
} }
func GetByContext(c *gin.Context) *Meta { func GetByContext(c *gin.Context) *Meta {
@@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta {
BaseURL: c.GetString(ctxkey.BaseURL), BaseURL: c.GetString(ctxkey.BaseURL),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
SystemPrompt: c.GetString(ctxkey.SystemPrompt),
} }
cfg, ok := c.Get(ctxkey.Config) cfg, ok := c.Get(ctxkey.Config)
if ok { if ok {

View File

@@ -43,6 +43,7 @@ const EditChannel = (props) => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
system_prompt: '',
models: [], models: [],
auto_ban: 1, auto_ban: 1,
groups: ['default'] groups: ['default']
@@ -496,6 +497,19 @@ const EditChannel = (props) => {
value={inputs.model_mapping} value={inputs.model_mapping}
autoComplete='new-password' autoComplete='new-password'
/> />
<div style={{ marginTop: 10 }}>
<Typography.Text strong>系统提示词</Typography.Text>
</div>
<TextArea
placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
name='system_prompt'
onChange={value => {
handleInputChange('system_prompt', value)
}}
autosize
value={inputs.system_prompt}
autoComplete='new-password'
/>
<Typography.Text style={{ <Typography.Text style={{
color: 'rgba(var(--semi-blue-5), 1)', color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none', userSelect: 'none',

View File

@@ -595,6 +595,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
<FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText> <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
)} )}
</FormControl> </FormControl>
<FormControl fullWidth error={Boolean(touched.system_prompt && errors.system_prompt)} sx={{ ...theme.typography.otherInput }}>
{/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
<TextField
multiline
id="channel-system_prompt-label"
label={inputLabel.system_prompt}
value={values.system_prompt}
name="system_prompt"
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-system_prompt-label"
minRows={5}
placeholder={inputPrompt.system_prompt}
/>
{touched.system_prompt && errors.system_prompt ? (
<FormHelperText error id="helper-tex-channel-system_prompt-label">
{errors.system_prompt}
</FormHelperText>
) : (
<FormHelperText id="helper-tex-channel-system_prompt-label"> {inputPrompt.system_prompt} </FormHelperText>
)}
</FormControl>
<DialogActions> <DialogActions>
<Button onClick={onCancel}>取消</Button> <Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary"> <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">

View File

@@ -18,6 +18,7 @@ const defaultConfig = {
other: '其他参数', other: '其他参数',
models: '模型', models: '模型',
model_mapping: '模型映射关系', model_mapping: '模型映射关系',
system_prompt: '系统提示词',
groups: '用户组', groups: '用户组',
config: null config: null
}, },
@@ -30,6 +31,7 @@ const defaultConfig = {
models: '请选择该渠道所支持的模型', models: '请选择该渠道所支持的模型',
model_mapping: model_mapping:
'请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}', '请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}',
system_prompt:"此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型",
groups: '请选择该渠道所支持的用户组', groups: '请选择该渠道所支持的用户组',
config: null config: null
}, },

View File

@@ -43,6 +43,7 @@ const EditChannel = () => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
system_prompt: '',
models: [], models: [],
groups: ['default'] groups: ['default']
}; };
@@ -425,7 +426,7 @@ const EditChannel = () => {
) )
} }
{ {
inputs.type !== 43 && ( inputs.type !== 43 && (<>
<Form.Field> <Form.Field>
<Form.TextArea <Form.TextArea
label='模型重定向' label='模型重定向'
@@ -437,6 +438,18 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.TextArea
label='系统提示词'
placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
name='system_prompt'
onChange={handleInputChange}
value={inputs.system_prompt}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
</>
) )
} }
{ {