Merge remote-tracking branch 'upstream/main' into dev

# Conflicts:
#	app/locales/ar.ts
#	app/locales/bn.ts
#	app/locales/cs.ts
#	app/locales/de.ts
#	app/locales/es.ts
#	app/locales/fr.ts
#	app/locales/id.ts
#	app/locales/it.ts
#	app/locales/jp.ts
#	app/locales/ko.ts
#	app/locales/no.ts
#	app/locales/pt.ts
#	app/locales/ru.ts
#	app/locales/sk.ts
#	app/locales/tr.ts
#	app/locales/vi.ts
#	app/store/chat.ts
#	app/store/config.ts
This commit is contained in:
sijinhui
2024-09-14 15:37:19 +08:00
14 changed files with 159 additions and 86 deletions

View File

@@ -1,9 +1,15 @@
import { trimTopic, getMessageTextContent } from "../utils";
import { getMessageTextContent, trimTopic } from "../utils";
import Locale, { getLang } from "../locales";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
import { nanoid } from "nanoid";
import type {
ClientApi,
MultimodalContent,
RequestMessage,
} from "../client/api";
import { getClientApi } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
import {
DEFAULT_INPUT_TEMPLATE,
DEFAULT_MODELS,
@@ -11,9 +17,9 @@ import {
KnowledgeCutOffDate,
ServiceProvider,
StoreKey,
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
} from "../constant";
import Locale, { getLang } from "../locales";
import { isDalle3, safeLocalStorage } from "../utils";
import {
getClientApi,
getHeaders,
@@ -26,13 +32,10 @@ import type {
} from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access";
import { isDalle3, safeLocalStorage } from "../utils";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
import { estimateTokenLength } from "../utils/token";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
const localStorage = safeLocalStorage();
@@ -114,39 +117,6 @@ function createEmptySession(): ChatSession {
// if it is using gpt-* models, force to use 4o-mini to summarize
const ChatFetchTaskPool: Record<string, any> = {};
function getSummarizeModel(currentModel: string): {
name: string;
providerName: string | undefined;
} {
// if it is using gpt-* models, force to use 4o-mini to summarize
if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
const configStore = useAppConfig.getState();
const accessStore = useAccessStore.getState();
const allModel = collectModelsWithDefaultModel(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel,
);
const summarizeModel = allModel.find(
(m) => m.name === SUMMARIZE_MODEL && m.available,
);
return {
name: summarizeModel?.name ?? currentModel,
providerName: summarizeModel?.provider?.providerName,
};
}
if (currentModel.startsWith("gemini")) {
return {
name: GEMINI_SUMMARIZE_MODEL,
providerName: ServiceProvider.Google,
};
}
return {
name: currentModel,
providerName: undefined,
};
}
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce(
(pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -935,7 +905,7 @@ export const useChatStore = createPersistStore(
return;
}
const providerName = modelConfig.providerName;
const providerName = modelConfig.compressProviderName;
const api: ClientApi = getClientApi(providerName);
// remove error messages if any
@@ -957,9 +927,7 @@ export const useChatStore = createPersistStore(
api.llm.chat({
messages: topicMessages,
config: {
model: getSummarizeModel(session.mask.modelConfig.model).name,
providerName: getSummarizeModel(session.mask.modelConfig.model)
.providerName,
model: modelConfig.compressModel,
stream: false,
},
onFinish(message) {
@@ -1021,9 +989,10 @@ export const useChatStore = createPersistStore(
config: {
...modelcfg,
stream: true,
model: getSummarizeModel(session.mask.modelConfig.model).name,
providerName: getSummarizeModel(session.mask.modelConfig.model)
.providerName,
model: modelConfig.compressModel,
// providerName: getSummarizeModel(session.mask.modelConfig.model)
// .providerName,
// TODO:
},
onUpdate(message) {
session.memoryPrompt = message;
@@ -1072,7 +1041,7 @@ export const useChatStore = createPersistStore(
},
{
name: StoreKey.Chat,
version: 3.1,
version: 3.2,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(
@@ -1119,6 +1088,16 @@ export const useChatStore = createPersistStore(
});
}
// add default summarize model for every session
if (version < 3.2) {
newState.sessions.forEach((s) => {
const config = useAppConfig.getState();
s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
s.mask.modelConfig.compressProviderName =
config.modelConfig.compressProviderName;
});
}
return newState as any;
},
},

View File

@@ -55,7 +55,7 @@ export const DEFAULT_CONFIG = {
dontUseModel: DISABLE_MODELS,
modelConfig: {
model: "gpt-3.5-turbo-0125" as ModelType,
model: "gpt-4o-mini" as ModelType,
providerName: "OpenAI" as ServiceProvider,
temperature: 0.8,
top_p: 1,
@@ -65,6 +65,8 @@ export const DEFAULT_CONFIG = {
sendMemory: true,
historyMessageCount: 5,
compressMessageLengthThreshold: 4000,
compressModel: "gpt-4o-mini" as ModelType,
compressProviderName: "OpenAI" as ServiceProvider,
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
size: "1024x1024" as DalleSize,
@@ -145,7 +147,7 @@ export const useAppConfig = createPersistStore(
}),
{
name: StoreKey.Config,
version: 3.993,
version: 4,
migrate(persistedState, version) {
const state = persistedState as ChatConfig;
@@ -190,6 +192,13 @@ export const useAppConfig = createPersistStore(
// : config?.template ?? DEFAULT_INPUT_TEMPLATE;
}
if (version < 4) {
state.modelConfig.compressModel =
DEFAULT_CONFIG.modelConfig.compressModel;
state.modelConfig.compressProviderName =
DEFAULT_CONFIG.modelConfig.compressProviderName;
}
return state as any;
},
},