feat: optimize rag

This commit is contained in:
Hk-Gosuto
2024-07-07 15:41:58 +08:00
parent f260f11755
commit 712022d8c7
19 changed files with 332 additions and 176 deletions

View File

@@ -116,7 +116,7 @@ export abstract class LLMApi {
abstract speech(options: SpeechOptions): Promise<ArrayBuffer>;
abstract transcription(options: TranscriptionOptions): Promise<string>;
abstract toolAgentChat(options: AgentChatOptions): Promise<void>;
abstract createRAGStore(options: CreateRAGStoreOptions): Promise<void>;
abstract createRAGStore(options: CreateRAGStoreOptions): Promise<string>;
abstract usage(): Promise<LLMUsage>;
abstract models(): Promise<LLMModel[]>;
}

View File

@@ -89,7 +89,7 @@ export class ClaudeApi implements LLMApi {
toolAgentChat(options: AgentChatOptions): Promise<void> {
throw new Error("Method not implemented.");
}
createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
createRAGStore(options: CreateRAGStoreOptions): Promise<string> {
throw new Error("Method not implemented.");
}
extractMessage(res: any) {

View File

@@ -20,7 +20,7 @@ import {
} from "@/app/utils";
export class GeminiProApi implements LLMApi {
createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
createRAGStore(options: CreateRAGStoreOptions): Promise<string> {
throw new Error("Method not implemented.");
}
transcription(options: TranscriptionOptions): Promise<string> {

View File

@@ -373,7 +373,7 @@ export class ChatGPTApi implements LLMApi {
}
}
async createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
async createRAGStore(options: CreateRAGStoreOptions): Promise<string> {
try {
const accessStore = useAccessStore.getState();
const isAzure = accessStore.provider === ServiceProvider.Azure;
@@ -395,9 +395,12 @@ export class ChatGPTApi implements LLMApi {
};
const res = await fetch(path, chatPayload);
if (res.status !== 200) throw new Error(await res.text());
const resJson = await res.json();
return resJson.partial;
} catch (e) {
console.log("[Request] failed to make a chat reqeust", e);
options.onError?.(e as Error);
return "";
}
}

View File

@@ -1,10 +1,13 @@
import { getHeaders } from "../api";
import { getClientApi } from "@/app/utils";
import { ClientApi, getHeaders } from "../api";
import { ChatSession } from "@/app/store";
export interface FileInfo {
originalFilename: string;
fileName: string;
filePath: string;
size: number;
partial?: string;
}
export class FileApi {
@@ -31,4 +34,15 @@ export class FileApi {
filePath: resJson.filePath,
};
}
async uploadForRag(file: any, session: ChatSession): Promise<FileInfo> {
var fileInfo = await this.upload(file);
var api: ClientApi = getClientApi(session.mask.modelConfig.model);
let partial = await api.llm.createRAGStore({
chatSessionId: session.id,
fileInfos: [fileInfo],
});
fileInfo.partial = partial;
return fileInfo;
}
}