Merge remote-tracking branch 'upstream/main' into dev

# Conflicts:
#	app/client/platforms/openai.ts
#	app/constant.ts
#	app/utils/model.ts
sijinhui
2024-08-06 13:08:13 +08:00
10 changed files with 202 additions and 68 deletions

View File

@@ -31,6 +31,7 @@ import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
+import { isDalle3 } from "../utils";
 
 export type ChatMessage = RequestMessage & {
   date: string;
@@ -95,12 +96,12 @@ function createEmptySession(): ChatSession {
   };
 }
 
 // if it is using gpt-* models, force to use 4o-mini to summarize
 const ChatFetchTaskPool: Record<string, any> = {};
 
 function getSummarizeModel(currentModel: string): {
-  name: string,
-  providerName: string | undefined,
+  name: string;
+  providerName: string | undefined;
 } {
   // if it is using gpt-* models, force to use 3.5 to summarize
   if (currentModel.startsWith("gpt")) {
@@ -117,18 +118,18 @@ function getSummarizeModel(currentModel: string): {
     return {
       name: summarizeModel?.name ?? currentModel,
       providerName: summarizeModel?.provider?.providerName,
-    }
+    };
   }
   if (currentModel.startsWith("gemini")) {
     return {
       name: GEMINI_SUMMARIZE_MODEL,
       providerName: ServiceProvider.Google,
-    }
+    };
   }
   return {
     name: currentModel,
     providerName: undefined,
-  }
+  };
 }
 
 function countMessages(msgs: ChatMessage[]) {
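
Note: getSummarizeModel() now returns an object rather than a bare model-name string. Below is a minimal, self-contained sketch of that return shape together with a hypothetical caller (buildSummarizeConfig is illustrative only, not part of this commit) that fetches the result once and destructures it, instead of calling the function twice as the config blocks further down in this diff do:

// Return shape introduced by this merge, mirrored from the diff above.
type SummarizeModelInfo = {
  name: string;
  providerName: string | undefined;
};

// Hypothetical helper: look the model up once and reuse both fields.
function buildSummarizeConfig(info: SummarizeModelInfo) {
  const { name, providerName } = info;
  return { model: name, providerName, stream: false };
}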
@@ -718,7 +719,7 @@ export const useChatStore = createPersistStore(
           set(() => ({}));
           extAttr?.setAutoScroll(true);
         } else {
-          const api: ClientApi = getClientApi(modelConfig.providerName)
+          const api: ClientApi = getClientApi(modelConfig.providerName);
           // console.log('-------', modelConfig, '-----', api)
 
           // make request
@@ -896,8 +897,13 @@ export const useChatStore = createPersistStore(
       const config = useAppConfig.getState();
       const session = get().currentSession();
       const modelConfig = session.mask.modelConfig;
+      // skip summarize when using dalle3?
+      if (isDalle3(modelConfig.model)) {
+        return;
+      }
 
-      const api: ClientApi = getClientApi(modelConfig.providerName);
+      const providerName = modelConfig.providerName;
+      const api: ClientApi = getClientApi(providerName);
 
       // remove error messages if any
       const messages = session.messages;
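
The isDalle3 guard above short-circuits session summarization for DALL·E 3, which produces images rather than chat text to summarize. A minimal sketch of the imported helper, assuming upstream NextChat's app/utils.ts implementation (a strict model-name check):

// Assumed shape of the helper imported from "../utils"; verify against
// the actual app/utils.ts in the upstream tree.
export function isDalle3(model: string) {
  return "dall-e-3" === model;
}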
@@ -919,8 +925,10 @@ export const useChatStore = createPersistStore(
         messages: topicMessages,
         config: {
           model: getSummarizeModel(session.mask.modelConfig.model).name,
-          providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
+          providerName: getSummarizeModel(session.mask.modelConfig.model)
+            .providerName,
           stream: false,
+          providerName,
         },
         onFinish(message) {
           get().updateCurrentSession(
@@ -982,7 +990,8 @@ export const useChatStore = createPersistStore(
           ...modelcfg,
           stream: true,
           model: getSummarizeModel(session.mask.modelConfig.model).name,
-          providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
+          providerName: getSummarizeModel(session.mask.modelConfig.model)
+            .providerName,
         },
         onUpdate(message) {
           session.memoryPrompt = message;

View File

@@ -1,4 +1,5 @@
 import { LLMModel } from "../client/api";
+import { DalleSize } from "../typing";
 import { getClientConfig } from "../config/client";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -66,6 +67,7 @@ export const DEFAULT_CONFIG = {
     compressMessageLengthThreshold: 4000,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
+    size: "1024x1024" as DalleSize,
   },
 };
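
The size default above relies on the DalleSize union imported from app/typing.ts. A sketch of that type, assuming the upstream definition at this revision (the three dimensions the DALL·E 3 API accepts):

// Assumed definition in app/typing.ts; confirm against the upstream file.
export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";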