mirror of
				https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
				synced 2025-11-04 16:23:41 +08:00 
			
		
		
		
	Merge pull request #2109 from Yidadaa/bugfix-0623
feat: close #1789 add user input template
This commit is contained in:
		@@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
 | 
			
		||||
 | 
			
		||||
import Locale from "../locales";
 | 
			
		||||
import { InputRange } from "./input-range";
 | 
			
		||||
import { List, ListItem, Select } from "./ui-lib";
 | 
			
		||||
import { ListItem, Select } from "./ui-lib";
 | 
			
		||||
 | 
			
		||||
export function ModelConfigList(props: {
 | 
			
		||||
  modelConfig: ModelConfig;
 | 
			
		||||
@@ -109,6 +109,21 @@ export function ModelConfigList(props: {
 | 
			
		||||
        ></InputRange>
 | 
			
		||||
      </ListItem>
 | 
			
		||||
 | 
			
		||||
      <ListItem
 | 
			
		||||
        title={Locale.Settings.InputTemplate.Title}
 | 
			
		||||
        subTitle={Locale.Settings.InputTemplate.SubTitle}
 | 
			
		||||
      >
 | 
			
		||||
        <input
 | 
			
		||||
          type="text"
 | 
			
		||||
          value={props.modelConfig.template}
 | 
			
		||||
          onChange={(e) =>
 | 
			
		||||
            props.updateConfig(
 | 
			
		||||
              (config) => (config.template = e.currentTarget.value),
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
        ></input>
 | 
			
		||||
      </ListItem>
 | 
			
		||||
 | 
			
		||||
      <ListItem
 | 
			
		||||
        title={Locale.Settings.HistoryCount.Title}
 | 
			
		||||
        subTitle={Locale.Settings.HistoryCount.SubTitle}
 | 
			
		||||
 
 | 
			
		||||
@@ -52,3 +52,10 @@ export const OpenaiPath = {
 | 
			
		||||
  UsagePath: "dashboard/billing/usage",
 | 
			
		||||
  SubsPath: "dashboard/billing/subscription",
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export const DEFAULT_INPUT_TEMPLATE = `
 | 
			
		||||
Act as a virtual assistant powered by model: '{{model}}', my input is:
 | 
			
		||||
'''
 | 
			
		||||
{{input}}
 | 
			
		||||
'''
 | 
			
		||||
`;
 | 
			
		||||
 
 | 
			
		||||
@@ -115,6 +115,11 @@ const cn = {
 | 
			
		||||
      SubTitle: "聊天内容的字体大小",
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    InputTemplate: {
 | 
			
		||||
      Title: "用户输入预处理",
 | 
			
		||||
      SubTitle: "用户最新的一条消息会填充到此模板",
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    Update: {
 | 
			
		||||
      Version: (x: string) => `当前版本:${x}`,
 | 
			
		||||
      IsLatest: "已是最新版本",
 | 
			
		||||
 
 | 
			
		||||
@@ -116,6 +116,12 @@ const en: LocaleType = {
 | 
			
		||||
      Title: "Font Size",
 | 
			
		||||
      SubTitle: "Adjust font size of chat content",
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    InputTemplate: {
 | 
			
		||||
      Title: "Input Template",
 | 
			
		||||
      SubTitle: "Newest message will be filled to this template",
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    Update: {
 | 
			
		||||
      Version: (x: string) => `Version: ${x}`,
 | 
			
		||||
      IsLatest: "Latest version",
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ export const BUILTIN_MASK_ID = 100000;
 | 
			
		||||
 | 
			
		||||
export const BUILTIN_MASK_STORE = {
 | 
			
		||||
  buildinId: BUILTIN_MASK_ID,
 | 
			
		||||
  masks: {} as Record<number, Mask>,
 | 
			
		||||
  masks: {} as Record<number, BuiltinMask>,
 | 
			
		||||
  get(id?: number) {
 | 
			
		||||
    if (!id) return undefined;
 | 
			
		||||
    return this.masks[id] as Mask | undefined;
 | 
			
		||||
@@ -21,6 +21,6 @@ export const BUILTIN_MASK_STORE = {
 | 
			
		||||
  },
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export const BUILTIN_MASKS: Mask[] = [...CN_MASKS, ...EN_MASKS].map((m) =>
 | 
			
		||||
  BUILTIN_MASK_STORE.add(m),
 | 
			
		||||
export const BUILTIN_MASKS: BuiltinMask[] = [...CN_MASKS, ...EN_MASKS].map(
 | 
			
		||||
  (m) => BUILTIN_MASK_STORE.add(m),
 | 
			
		||||
);
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,7 @@
 | 
			
		||||
import { ModelConfig } from "../store";
 | 
			
		||||
import { type Mask } from "../store/mask";
 | 
			
		||||
 | 
			
		||||
export type BuiltinMask = Omit<Mask, "id"> & {
 | 
			
		||||
  builtin: true;
 | 
			
		||||
export type BuiltinMask = Omit<Mask, "id" | "modelConfig"> & {
 | 
			
		||||
  builtin: Boolean;
 | 
			
		||||
  modelConfig: Partial<ModelConfig>;
 | 
			
		||||
};
 | 
			
		||||
 
 | 
			
		||||
@@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
 | 
			
		||||
 | 
			
		||||
import { trimTopic } from "../utils";
 | 
			
		||||
 | 
			
		||||
import Locale from "../locales";
 | 
			
		||||
import Locale, { getLang } from "../locales";
 | 
			
		||||
import { showToast } from "../components/ui-lib";
 | 
			
		||||
import { ModelType } from "./config";
 | 
			
		||||
import { ModelConfig, ModelType, useAppConfig } from "./config";
 | 
			
		||||
import { createEmptyMask, Mask } from "./mask";
 | 
			
		||||
import { StoreKey } from "../constant";
 | 
			
		||||
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
 | 
			
		||||
import { api, RequestMessage } from "../client/api";
 | 
			
		||||
import { ChatControllerPool } from "../client/controller";
 | 
			
		||||
import { prettyObject } from "../utils/format";
 | 
			
		||||
@@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
 | 
			
		||||
  return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
 | 
			
		||||
  const vars = {
 | 
			
		||||
    model: modelConfig.model,
 | 
			
		||||
    time: new Date().toLocaleString(),
 | 
			
		||||
    lang: getLang(),
 | 
			
		||||
    input: input,
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
 | 
			
		||||
 | 
			
		||||
  // must contains {{input}}
 | 
			
		||||
  const inputVar = "{{input}}";
 | 
			
		||||
  if (!output.includes(inputVar)) {
 | 
			
		||||
    output += "\n" + inputVar;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Object.entries(vars).forEach(([name, value]) => {
 | 
			
		||||
    output = output.replaceAll(`{{${name}}}`, value);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return output;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export const useChatStore = create<ChatStore>()(
 | 
			
		||||
  persist(
 | 
			
		||||
    (set, get) => ({
 | 
			
		||||
@@ -158,7 +181,16 @@ export const useChatStore = create<ChatStore>()(
 | 
			
		||||
        session.id = get().globalId;
 | 
			
		||||
 | 
			
		||||
        if (mask) {
 | 
			
		||||
          session.mask = { ...mask };
 | 
			
		||||
          const config = useAppConfig.getState();
 | 
			
		||||
          const globalModelConfig = config.modelConfig;
 | 
			
		||||
 | 
			
		||||
          session.mask = {
 | 
			
		||||
            ...mask,
 | 
			
		||||
            modelConfig: {
 | 
			
		||||
              ...globalModelConfig,
 | 
			
		||||
              ...mask.modelConfig,
 | 
			
		||||
            },
 | 
			
		||||
          };
 | 
			
		||||
          session.topic = mask.name;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -238,9 +270,12 @@ export const useChatStore = create<ChatStore>()(
 | 
			
		||||
        const session = get().currentSession();
 | 
			
		||||
        const modelConfig = session.mask.modelConfig;
 | 
			
		||||
 | 
			
		||||
        const userContent = fillTemplateWith(content, modelConfig);
 | 
			
		||||
        console.log("[User Input] fill with template: ", userContent);
 | 
			
		||||
 | 
			
		||||
        const userMessage: ChatMessage = createMessage({
 | 
			
		||||
          role: "user",
 | 
			
		||||
          content,
 | 
			
		||||
          content: userContent,
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        const botMessage: ChatMessage = createMessage({
 | 
			
		||||
@@ -250,31 +285,22 @@ export const useChatStore = create<ChatStore>()(
 | 
			
		||||
          model: modelConfig.model,
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        const systemInfo = createMessage({
 | 
			
		||||
          role: "system",
 | 
			
		||||
          content: `IMPORTANT: You are a virtual assistant powered by the ${
 | 
			
		||||
            modelConfig.model
 | 
			
		||||
          } model, now time is ${new Date().toLocaleString()}}`,
 | 
			
		||||
          id: botMessage.id! + 1,
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        // get recent messages
 | 
			
		||||
        const systemMessages = [];
 | 
			
		||||
        // if user define a mask with context prompts, wont send system info
 | 
			
		||||
        if (session.mask.context.length === 0) {
 | 
			
		||||
          systemMessages.push(systemInfo);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        const recentMessages = get().getMessagesWithMemory();
 | 
			
		||||
        const sendMessages = systemMessages.concat(
 | 
			
		||||
          recentMessages.concat(userMessage),
 | 
			
		||||
        );
 | 
			
		||||
        const sendMessages = recentMessages.concat(userMessage);
 | 
			
		||||
        const sessionIndex = get().currentSessionIndex;
 | 
			
		||||
        const messageIndex = get().currentSession().messages.length + 1;
 | 
			
		||||
 | 
			
		||||
        // save user's and bot's message
 | 
			
		||||
        get().updateCurrentSession((session) => {
 | 
			
		||||
          session.messages = session.messages.concat([userMessage, botMessage]);
 | 
			
		||||
          const savedUserMessage = {
 | 
			
		||||
            ...userMessage,
 | 
			
		||||
            content,
 | 
			
		||||
          };
 | 
			
		||||
          session.messages = session.messages.concat([
 | 
			
		||||
            savedUserMessage,
 | 
			
		||||
            botMessage,
 | 
			
		||||
          ]);
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        // make request
 | 
			
		||||
@@ -350,55 +376,62 @@ export const useChatStore = create<ChatStore>()(
 | 
			
		||||
      getMessagesWithMemory() {
 | 
			
		||||
        const session = get().currentSession();
 | 
			
		||||
        const modelConfig = session.mask.modelConfig;
 | 
			
		||||
        const clearContextIndex = session.clearContextIndex ?? 0;
 | 
			
		||||
        const messages = session.messages.slice();
 | 
			
		||||
        const totalMessageCount = session.messages.length;
 | 
			
		||||
 | 
			
		||||
        // wont send cleared context messages
 | 
			
		||||
        const clearedContextMessages = session.messages.slice(
 | 
			
		||||
          session.clearContextIndex ?? 0,
 | 
			
		||||
        );
 | 
			
		||||
        const messages = clearedContextMessages.filter((msg) => !msg.isError);
 | 
			
		||||
        const n = messages.length;
 | 
			
		||||
 | 
			
		||||
        const context = session.mask.context.slice();
 | 
			
		||||
        // in-context prompts
 | 
			
		||||
        const contextPrompts = session.mask.context.slice();
 | 
			
		||||
 | 
			
		||||
        // long term memory
 | 
			
		||||
        if (
 | 
			
		||||
        const shouldSendLongTermMemory =
 | 
			
		||||
          modelConfig.sendMemory &&
 | 
			
		||||
          session.memoryPrompt &&
 | 
			
		||||
          session.memoryPrompt.length > 0
 | 
			
		||||
        ) {
 | 
			
		||||
          const memoryPrompt = get().getMemoryPrompt();
 | 
			
		||||
          context.push(memoryPrompt);
 | 
			
		||||
        }
 | 
			
		||||
          session.memoryPrompt.length > 0 &&
 | 
			
		||||
          session.lastSummarizeIndex <= clearContextIndex;
 | 
			
		||||
        const longTermMemoryPrompts = shouldSendLongTermMemory
 | 
			
		||||
          ? [get().getMemoryPrompt()]
 | 
			
		||||
          : [];
 | 
			
		||||
        const longTermMemoryStartIndex = session.lastSummarizeIndex;
 | 
			
		||||
 | 
			
		||||
        // get short term and unmemorized long term memory
 | 
			
		||||
        const shortTermMemoryMessageIndex = Math.max(
 | 
			
		||||
        // short term memory
 | 
			
		||||
        const shortTermMemoryStartIndex = Math.max(
 | 
			
		||||
          0,
 | 
			
		||||
          n - modelConfig.historyMessageCount,
 | 
			
		||||
          totalMessageCount - modelConfig.historyMessageCount,
 | 
			
		||||
        );
 | 
			
		||||
        const longTermMemoryMessageIndex = session.lastSummarizeIndex;
 | 
			
		||||
 | 
			
		||||
        // try to concat history messages
 | 
			
		||||
        // lets concat send messages, including 4 parts:
 | 
			
		||||
        // 1. long term memory: summarized memory messages
 | 
			
		||||
        // 2. pre-defined in-context prompts
 | 
			
		||||
        // 3. short term memory: latest n messages
 | 
			
		||||
        // 4. newest input message
 | 
			
		||||
        const memoryStartIndex = Math.min(
 | 
			
		||||
          shortTermMemoryMessageIndex,
 | 
			
		||||
          longTermMemoryMessageIndex,
 | 
			
		||||
          longTermMemoryStartIndex,
 | 
			
		||||
          shortTermMemoryStartIndex,
 | 
			
		||||
        );
 | 
			
		||||
        const threshold = modelConfig.max_tokens;
 | 
			
		||||
        // and if user has cleared history messages, we should exclude the memory too.
 | 
			
		||||
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
 | 
			
		||||
        const maxTokenThreshold = modelConfig.max_tokens;
 | 
			
		||||
 | 
			
		||||
        // get recent messages as many as possible
 | 
			
		||||
        // get recent messages as much as possible
 | 
			
		||||
        const reversedRecentMessages = [];
 | 
			
		||||
        for (
 | 
			
		||||
          let i = n - 1, count = 0;
 | 
			
		||||
          i >= memoryStartIndex && count < threshold;
 | 
			
		||||
          let i = totalMessageCount - 1, tokenCount = 0;
 | 
			
		||||
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
 | 
			
		||||
          i -= 1
 | 
			
		||||
        ) {
 | 
			
		||||
          const msg = messages[i];
 | 
			
		||||
          if (!msg || msg.isError) continue;
 | 
			
		||||
          count += estimateTokenLength(msg.content);
 | 
			
		||||
          tokenCount += estimateTokenLength(msg.content);
 | 
			
		||||
          reversedRecentMessages.push(msg);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // concat
 | 
			
		||||
        const recentMessages = context.concat(reversedRecentMessages.reverse());
 | 
			
		||||
        // concat all messages
 | 
			
		||||
        const recentMessages = [
 | 
			
		||||
          ...longTermMemoryPrompts,
 | 
			
		||||
          ...contextPrompts,
 | 
			
		||||
          ...reversedRecentMessages.reverse(),
 | 
			
		||||
        ];
 | 
			
		||||
 | 
			
		||||
        return recentMessages;
 | 
			
		||||
      },
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
import { create } from "zustand";
 | 
			
		||||
import { persist } from "zustand/middleware";
 | 
			
		||||
import { getClientConfig } from "../config/client";
 | 
			
		||||
import { StoreKey } from "../constant";
 | 
			
		||||
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
 | 
			
		||||
 | 
			
		||||
export enum SubmitKey {
 | 
			
		||||
  Enter = "Enter",
 | 
			
		||||
@@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = {
 | 
			
		||||
    sendMemory: true,
 | 
			
		||||
    historyMessageCount: 4,
 | 
			
		||||
    compressMessageLengthThreshold: 1000,
 | 
			
		||||
    template: DEFAULT_INPUT_TEMPLATE,
 | 
			
		||||
  },
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -176,15 +177,16 @@ export const useAppConfig = create<ChatConfigStore>()(
 | 
			
		||||
    }),
 | 
			
		||||
    {
 | 
			
		||||
      name: StoreKey.Config,
 | 
			
		||||
      version: 3,
 | 
			
		||||
      version: 3.1,
 | 
			
		||||
      migrate(persistedState, version) {
 | 
			
		||||
        if (version === 3) return persistedState as any;
 | 
			
		||||
        if (version === 3.1) return persistedState as any;
 | 
			
		||||
 | 
			
		||||
        const state = persistedState as ChatConfig;
 | 
			
		||||
        state.modelConfig.sendMemory = true;
 | 
			
		||||
        state.modelConfig.historyMessageCount = 4;
 | 
			
		||||
        state.modelConfig.compressMessageLengthThreshold = 1000;
 | 
			
		||||
        state.modelConfig.frequency_penalty = 0;
 | 
			
		||||
        state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
 | 
			
		||||
        state.dontShowMaskSplashScreen = false;
 | 
			
		||||
 | 
			
		||||
        return state;
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@ import { persist } from "zustand/middleware";
 | 
			
		||||
import { BUILTIN_MASKS } from "../masks";
 | 
			
		||||
import { getLang, Lang } from "../locales";
 | 
			
		||||
import { DEFAULT_TOPIC, ChatMessage } from "./chat";
 | 
			
		||||
import { ModelConfig, ModelType, useAppConfig } from "./config";
 | 
			
		||||
import { ModelConfig, useAppConfig } from "./config";
 | 
			
		||||
import { StoreKey } from "../constant";
 | 
			
		||||
 | 
			
		||||
export type Mask = {
 | 
			
		||||
@@ -89,7 +89,18 @@ export const useMaskStore = create<MaskStore>()(
 | 
			
		||||
        const userMasks = Object.values(get().masks).sort(
 | 
			
		||||
          (a, b) => b.id - a.id,
 | 
			
		||||
        );
 | 
			
		||||
        return userMasks.concat(BUILTIN_MASKS);
 | 
			
		||||
        const config = useAppConfig.getState();
 | 
			
		||||
        const buildinMasks = BUILTIN_MASKS.map(
 | 
			
		||||
          (m) =>
 | 
			
		||||
            ({
 | 
			
		||||
              ...m,
 | 
			
		||||
              modelConfig: {
 | 
			
		||||
                ...config.modelConfig,
 | 
			
		||||
                ...m.modelConfig,
 | 
			
		||||
              },
 | 
			
		||||
            } as Mask),
 | 
			
		||||
        );
 | 
			
		||||
        return userMasks.concat(buildinMasks);
 | 
			
		||||
      },
 | 
			
		||||
      search(text) {
 | 
			
		||||
        return Object.values(get().masks);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user