diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 4eb26277a..2d8ddd2ca 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -119,6 +119,7 @@ export class ChatGPTApi implements LLMApi {
         providerName: options.config.providerName,
       },
     };
+    console.log('-------', modelConfig, options)
     const requestPayload: RequestPayload = {
       messages,
       stream: options.config.stream,
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 3cc2bbe74..46269c897 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -97,7 +97,10 @@ function createEmptySession(): ChatSession {
 
 const ChatFetchTaskPool: Record<string, any> = {};
 
-function getSummarizeModel(currentModel: string) {
+function getSummarizeModel(currentModel: string): {
+  name: string,
+  providerName: string | undefined,
+} {
   // if it is using gpt-* models, force to use 3.5 to summarize
   if (currentModel.startsWith("gpt")) {
     const configStore = useAppConfig.getState();
@@ -110,12 +113,21 @@ function getSummarizeModel(currentModel: string) {
     const summarizeModel = allModel.find(
       (m) => m.name === SUMMARIZE_MODEL && m.available,
     );
-    return summarizeModel?.name ?? currentModel;
+    return {
+      name: summarizeModel?.name ?? currentModel,
+      providerName: summarizeModel?.provider?.providerName,
+    }
   }
   if (currentModel.startsWith("gemini")) {
-    return GEMINI_SUMMARIZE_MODEL;
+    return {
+      name: GEMINI_SUMMARIZE_MODEL,
+      providerName: ServiceProvider.Google,
+    }
+  }
+  return {
+    name: currentModel,
+    providerName: undefined,
   }
-  return currentModel;
 }
 
 function countMessages(msgs: ChatMessage[]) {
@@ -905,7 +917,8 @@ export const useChatStore = createPersistStore(
       api.llm.chat({
         messages: topicMessages,
         config: {
-          model: getSummarizeModel(session.mask.modelConfig.model),
+          model: getSummarizeModel(session.mask.modelConfig.model).name,
+          providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
           stream: false,
         },
         onFinish(message) {
@@ -967,7 +980,8 @@ export const useChatStore = createPersistStore(
         config: {
           ...modelcfg,
           stream: true,
-          model: getSummarizeModel(session.mask.modelConfig.model),
+          model: getSummarizeModel(session.mask.modelConfig.model).name,
+          providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
         },
         onUpdate(message) {
           session.memoryPrompt = message;
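
A minimal sketch of the pattern this patch introduces, assuming nothing beyond what the diff shows: getSummarizeModel now returns a { name, providerName } pair instead of a bare string, and each call site reads both fields. Note that the patch calls getSummarizeModel() twice per config object; resolving it once into a local avoids the duplicate lookup. The SummarizeModel type and buildSummarizeConfig helper below are illustrative stubs, not part of the patch or the NextChat codebase.

// Shape returned by the refactored getSummarizeModel, mirroring the patch.
interface SummarizeModel {
  name: string;
  providerName: string | undefined;
}

// Hypothetical helper: resolve the summarize model once, then reuse both
// fields when building the llm.chat config, instead of calling the
// resolver twice as the patched call sites do.
function buildSummarizeConfig(
  currentModel: string,
  resolve: (model: string) => SummarizeModel,
  stream: boolean,
): { model: string; providerName: string | undefined; stream: boolean } {
  const summarizeModel = resolve(currentModel);
  return {
    model: summarizeModel.name,
    providerName: summarizeModel.providerName,
    stream,
  };
}

A call site could then read, for example, config: { ...buildSummarizeConfig(session.mask.modelConfig.model, getSummarizeModel, false) }.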