Merge branch 'main' of github.com:ChatGPTNextWeb/ChatGPT-Next-Web into ChatGPTNextWeb-main

This commit is contained in:
ayman-makki
2024-10-04 13:44:30 +01:00
159 changed files with 19561 additions and 3377 deletions

View File

@@ -1,6 +1,7 @@
import {
ApiPath,
DEFAULT_API_HOST,
GoogleSafetySettingsThreshold,
ServiceProvider,
StoreKey,
} from "../constant";
@@ -12,10 +13,47 @@ import { DEFAULT_CONFIG } from "./config";
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
const DEFAULT_OPENAI_URL =
getClientConfig()?.buildMode === "export"
? DEFAULT_API_HOST + "/api/proxy/openai"
: ApiPath.OpenAI;
const isApp = getClientConfig()?.buildMode === "export";
const DEFAULT_OPENAI_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/openai"
: ApiPath.OpenAI;
const DEFAULT_GOOGLE_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/google"
: ApiPath.Google;
const DEFAULT_ANTHROPIC_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/anthropic"
: ApiPath.Anthropic;
const DEFAULT_BAIDU_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/baidu"
: ApiPath.Baidu;
const DEFAULT_BYTEDANCE_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/bytedance"
: ApiPath.ByteDance;
const DEFAULT_ALIBABA_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/alibaba"
: ApiPath.Alibaba;
const DEFAULT_TENCENT_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/tencent"
: ApiPath.Tencent;
const DEFAULT_MOONSHOT_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/moonshot"
: ApiPath.Moonshot;
const DEFAULT_STABILITY_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/stability"
: ApiPath.Stability;
const DEFAULT_IFLYTEK_URL = isApp
? DEFAULT_API_HOST + "/api/proxy/iflytek"
: ApiPath.Iflytek;
const DEFAULT_ACCESS_STATE = {
accessCode: "",
@@ -33,14 +71,46 @@ const DEFAULT_ACCESS_STATE = {
azureApiVersion: "2023-08-01-preview",
// google ai studio
googleUrl: "",
googleUrl: DEFAULT_GOOGLE_URL,
googleApiKey: "",
googleApiVersion: "v1",
googleSafetySettings: GoogleSafetySettingsThreshold.BLOCK_ONLY_HIGH,
// anthropic
anthropicUrl: DEFAULT_ANTHROPIC_URL,
anthropicApiKey: "",
anthropicApiVersion: "2023-06-01",
anthropicUrl: "",
// baidu
baiduUrl: DEFAULT_BAIDU_URL,
baiduApiKey: "",
baiduSecretKey: "",
// bytedance
bytedanceUrl: DEFAULT_BYTEDANCE_URL,
bytedanceApiKey: "",
// alibaba
alibabaUrl: DEFAULT_ALIBABA_URL,
alibabaApiKey: "",
// moonshot
moonshotUrl: DEFAULT_MOONSHOT_URL,
moonshotApiKey: "",
//stability
stabilityUrl: DEFAULT_STABILITY_URL,
stabilityApiKey: "",
// tencent
tencentUrl: DEFAULT_TENCENT_URL,
tencentSecretKey: "",
tencentSecretId: "",
// iflytek
iflytekUrl: DEFAULT_IFLYTEK_URL,
iflytekApiKey: "",
iflytekApiSecret: "",
// server config
needCode: true,
@@ -50,6 +120,9 @@ const DEFAULT_ACCESS_STATE = {
disableFastLink: false,
customModels: "",
defaultModel: "",
// tts config
edgeTTSVoiceName: "zh-CN-YunxiNeural",
};
export const useAccessStore = createPersistStore(
@@ -62,6 +135,12 @@ export const useAccessStore = createPersistStore(
return get().needCode;
},
edgeVoiceName() {
this.fetch();
return get().edgeTTSVoiceName;
},
isValidOpenAI() {
return ensure(get(), ["openaiApiKey"]);
},
@@ -78,6 +157,29 @@ export const useAccessStore = createPersistStore(
return ensure(get(), ["anthropicApiKey"]);
},
isValidBaidu() {
return ensure(get(), ["baiduApiKey", "baiduSecretKey"]);
},
isValidByteDance() {
return ensure(get(), ["bytedanceApiKey"]);
},
isValidAlibaba() {
return ensure(get(), ["alibabaApiKey"]);
},
isValidTencent() {
return ensure(get(), ["tencentSecretKey", "tencentSecretId"]);
},
isValidMoonshot() {
return ensure(get(), ["moonshotApiKey"]);
},
isValidIflytek() {
return ensure(get(), ["iflytekApiKey"]);
},
isAuthorized() {
this.fetch();
@@ -87,6 +189,12 @@ export const useAccessStore = createPersistStore(
this.isValidAzure() ||
this.isValidGoogle() ||
this.isValidAnthropic() ||
this.isValidBaidu() ||
this.isValidByteDance() ||
this.isValidAlibaba() ||
this.isValidTencent() ||
this.isValidMoonshot() ||
this.isValidIflytek() ||
!this.enabledAccessControl() ||
(this.enabledAccessControl() && ensure(get(), ["accessCode"]))
);
@@ -103,10 +211,13 @@ export const useAccessStore = createPersistStore(
})
.then((res) => res.json())
.then((res) => {
// Set default model from env request
let defaultModel = res.defaultModel ?? "";
DEFAULT_CONFIG.modelConfig.model =
defaultModel !== "" ? defaultModel : "gpt-3.5-turbo";
const defaultModel = res.defaultModel ?? "";
if (defaultModel !== "") {
const [model, providerName] = defaultModel.split("@");
DEFAULT_CONFIG.modelConfig.model = model;
DEFAULT_CONFIG.modelConfig.providerName = providerName;
}
return res;
})
.then((res: DangerConfig) => {

View File

@@ -1,28 +1,43 @@
import { trimTopic, getMessageTextContent } from "../utils";
import { getMessageTextContent, trimTopic } from "../utils";
import Locale, { getLang } from "../locales";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
import { nanoid } from "nanoid";
import type {
ClientApi,
MultimodalContent,
RequestMessage,
} from "../client/api";
import { getClientApi } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
import {
DEFAULT_INPUT_TEMPLATE,
DEFAULT_MODELS,
DEFAULT_SYSTEM_TEMPLATE,
KnowledgeCutOffDate,
ModelProvider,
StoreKey,
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
} from "../constant";
import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import Locale, { getLang } from "../locales";
import { isDalle3, safeLocalStorage } from "../utils";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { identifyDefaultClaudeModel } from "../utils/checkers";
import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access";
import { estimateTokenLength } from "../utils/token";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
const localStorage = safeLocalStorage();
export type ChatMessageTool = {
id: string;
index?: number;
type?: string;
function?: {
name: string;
arguments?: string;
};
content?: string;
isError?: boolean;
};
export type ChatMessage = RequestMessage & {
date: string;
@@ -30,6 +45,7 @@ export type ChatMessage = RequestMessage & {
isError?: boolean;
id: string;
model?: ModelType;
tools?: ChatMessageTool[];
};
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@@ -86,27 +102,6 @@ function createEmptySession(): ChatSession {
};
}
function getSummarizeModel(currentModel: string) {
// if it is using gpt-* models, force to use 3.5 to summarize
if (currentModel.startsWith("gpt")) {
const configStore = useAppConfig.getState();
const accessStore = useAccessStore.getState();
const allModel = collectModelsWithDefaultModel(
configStore.models,
[configStore.customModels, accessStore.customModels].join(","),
accessStore.defaultModel,
);
const summarizeModel = allModel.find(
(m) => m.name === SUMMARIZE_MODEL && m.available,
);
return summarizeModel?.name ?? currentModel;
}
if (currentModel.startsWith("gemini")) {
return GEMINI_SUMMARIZE_MODEL;
}
return currentModel;
}
function countMessages(msgs: ChatMessage[]) {
return msgs.reduce(
(pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -161,6 +156,7 @@ function fillTemplateWith(input: string, modelConfig: ModelConfig) {
const DEFAULT_CHAT_STATE = {
sessions: [createEmptySession()],
currentSessionIndex: 0,
lastInput: "",
};
export const useChatStore = createPersistStore(
@@ -174,6 +170,28 @@ export const useChatStore = createPersistStore(
}
const methods = {
forkSession() {
// 获取当前会话
const currentSession = get().currentSession();
if (!currentSession) return;
const newSession = createEmptySession();
newSession.topic = currentSession.topic;
newSession.messages = [...currentSession.messages];
newSession.mask = {
...currentSession.mask,
modelConfig: {
...currentSession.mask.modelConfig,
},
};
set((state) => ({
currentSessionIndex: 0,
sessions: [newSession, ...state.sessions],
}));
},
clearSessions() {
set(() => ({
sessions: [createEmptySession()],
@@ -363,15 +381,7 @@ export const useChatStore = createPersistStore(
]);
});
var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(modelConfig.model)) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(modelConfig.providerName);
// make request
api.llm.chat({
messages: sendMessages,
@@ -393,8 +403,24 @@ export const useChatStore = createPersistStore(
}
ChatControllerPool.remove(session.id, botMessage.id);
},
onBeforeTool(tool: ChatMessageTool) {
(botMessage.tools = botMessage?.tools || []).push(tool);
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onAfterTool(tool: ChatMessageTool) {
botMessage?.tools?.forEach((t, i, tools) => {
if (tool.id == t.id) {
tools[i] = { ...tool };
}
});
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onError(error) {
const isAborted = error.message.includes("aborted");
const isAborted = error.message?.includes?.("aborted");
botMessage.content +=
"\n\n" +
prettyObject({
@@ -450,7 +476,8 @@ export const useChatStore = createPersistStore(
// system prompts, to get close to OpenAI Web ChatGPT
const shouldInjectSystemPrompts =
modelConfig.enableInjectSystemPrompts &&
session.mask.modelConfig.model.startsWith("gpt-");
(session.mask.modelConfig.model.startsWith("gpt-") ||
session.mask.modelConfig.model.startsWith("chatgpt-"));
var systemPrompts: ChatMessage[] = [];
systemPrompts = shouldInjectSystemPrompts
@@ -542,43 +569,53 @@ export const useChatStore = createPersistStore(
});
},
summarizeSession() {
summarizeSession(refreshTitle: boolean = false) {
const config = useAppConfig.getState();
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
var api: ClientApi;
if (modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(modelConfig.model)) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
// skip summarize when using dalle3?
if (isDalle3(modelConfig.model)) {
return;
}
const providerName = modelConfig.compressProviderName;
const api: ClientApi = getClientApi(providerName);
// remove error messages if any
const messages = session.messages;
// should summarize topic after chating more than 50 words
const SUMMARIZE_MIN_LEN = 50;
if (
config.enableAutoGenerateTitle &&
session.topic === DEFAULT_TOPIC &&
countMessages(messages) >= SUMMARIZE_MIN_LEN
(config.enableAutoGenerateTitle &&
session.topic === DEFAULT_TOPIC &&
countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
refreshTitle
) {
const topicMessages = messages.concat(
createMessage({
role: "user",
content: Locale.Store.Prompt.Topic,
}),
const startIndex = Math.max(
0,
messages.length - modelConfig.historyMessageCount,
);
const topicMessages = messages
.slice(
startIndex < messages.length ? startIndex : messages.length - 1,
messages.length,
)
.concat(
createMessage({
role: "user",
content: Locale.Store.Prompt.Topic,
}),
);
api.llm.chat({
messages: topicMessages,
config: {
model: getSummarizeModel(session.mask.modelConfig.model),
model: modelConfig.compressModel,
stream: false,
providerName,
},
onFinish(message) {
if (!isValidMessage(message)) return;
get().updateCurrentSession(
(session) =>
(session.topic =
@@ -637,7 +674,7 @@ export const useChatStore = createPersistStore(
config: {
...modelcfg,
stream: true,
model: getSummarizeModel(session.mask.modelConfig.model),
model: modelConfig.compressModel,
},
onUpdate(message) {
session.memoryPrompt = message;
@@ -654,6 +691,10 @@ export const useChatStore = createPersistStore(
},
});
}
function isValidMessage(message: any): boolean {
return typeof message === "string" && !message.startsWith("```json");
}
},
updateStat(message: ChatMessage) {
@@ -670,17 +711,23 @@ export const useChatStore = createPersistStore(
set(() => ({ sessions }));
},
clearAllData() {
async clearAllData() {
await indexedDBStorage.clear();
localStorage.clear();
location.reload();
},
setLastInput(lastInput: string) {
set({
lastInput,
});
},
};
return methods;
},
{
name: StoreKey.Chat,
version: 3.1,
version: 3.2,
migrate(persistedState, version) {
const state = persistedState as any;
const newState = JSON.parse(
@@ -727,6 +774,16 @@ export const useChatStore = createPersistStore(
});
}
// add default summarize model for every session
if (version < 3.2) {
newState.sessions.forEach((s) => {
const config = useAppConfig.getState();
s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
s.mask.modelConfig.compressProviderName =
config.modelConfig.compressProviderName;
});
}
return newState as any;
},
},

View File

@@ -1,14 +1,25 @@
import { LLMModel } from "../client/api";
import { DalleSize, DalleQuality, DalleStyle } from "../typing";
import { getClientConfig } from "../config/client";
import {
DEFAULT_INPUT_TEMPLATE,
DEFAULT_MODELS,
DEFAULT_SIDEBAR_WIDTH,
DEFAULT_TTS_ENGINE,
DEFAULT_TTS_ENGINES,
DEFAULT_TTS_MODEL,
DEFAULT_TTS_MODELS,
DEFAULT_TTS_VOICE,
DEFAULT_TTS_VOICES,
StoreKey,
ServiceProvider,
} from "../constant";
import { createPersistStore } from "../utils/store";
export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
export type TTSModelType = (typeof DEFAULT_TTS_MODELS)[number];
export type TTSVoiceType = (typeof DEFAULT_TTS_VOICES)[number];
export type TTSEngineType = (typeof DEFAULT_TTS_ENGINES)[number];
export enum SubmitKey {
Enter = "Enter",
@@ -32,12 +43,15 @@ export const DEFAULT_CONFIG = {
submitKey: SubmitKey.Enter,
avatar: "1f603",
fontSize: 14,
fontFamily: "",
theme: Theme.Auto as Theme,
tightBorder: !!config?.isApp,
sendPreviewBubble: true,
enableAutoGenerateTitle: true,
sidebarWidth: DEFAULT_SIDEBAR_WIDTH,
enableArtifacts: true, // show artifacts config
disablePromptHint: false,
dontShowMaskSplashScreen: false, // dont show splash screen when create chat
@@ -47,7 +61,8 @@ export const DEFAULT_CONFIG = {
models: DEFAULT_MODELS as any as LLMModel[],
modelConfig: {
model: "gpt-3.5-turbo" as ModelType,
model: "gpt-4o-mini" as ModelType,
providerName: "OpenAI" as ServiceProvider,
temperature: 0.5,
top_p: 1,
max_tokens: 4000,
@@ -56,14 +71,29 @@ export const DEFAULT_CONFIG = {
sendMemory: true,
historyMessageCount: 4,
compressMessageLengthThreshold: 1000,
compressModel: "gpt-4o-mini" as ModelType,
compressProviderName: "OpenAI" as ServiceProvider,
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
size: "1024x1024" as DalleSize,
quality: "standard" as DalleQuality,
style: "vivid" as DalleStyle,
},
ttsConfig: {
enable: false,
autoplay: false,
engine: DEFAULT_TTS_ENGINE,
model: DEFAULT_TTS_MODEL,
voice: DEFAULT_TTS_VOICE,
speed: 1.0,
},
};
export type ChatConfig = typeof DEFAULT_CONFIG;
export type ModelConfig = ChatConfig["modelConfig"];
export type TTSConfig = ChatConfig["ttsConfig"];
export function limitNumber(
x: number,
@@ -78,6 +108,21 @@ export function limitNumber(
return Math.min(max, Math.max(min, x));
}
export const TTSConfigValidator = {
engine(x: string) {
return x as TTSEngineType;
},
model(x: string) {
return x as TTSModelType;
},
voice(x: string) {
return x as TTSVoiceType;
},
speed(x: number) {
return limitNumber(x, 0.25, 4.0, 1.0);
},
};
export const ModalConfigValidator = {
model(x: string) {
return x as ModelType;
@@ -116,12 +161,12 @@ export const useAppConfig = createPersistStore(
for (const model of oldModels) {
model.available = false;
modelMap[model.name] = model;
modelMap[`${model.name}@${model?.provider?.id}`] = model;
}
for (const model of newModels) {
model.available = true;
modelMap[model.name] = model;
modelMap[`${model.name}@${model?.provider?.id}`] = model;
}
set(() => ({
@@ -133,7 +178,22 @@ export const useAppConfig = createPersistStore(
}),
{
name: StoreKey.Config,
version: 3.9,
version: 4,
merge(persistedState, currentState) {
const state = persistedState as ChatConfig | undefined;
if (!state) return { ...currentState };
const models = currentState.models.slice();
state.models.forEach((pModel) => {
const idx = models.findIndex(
(v) => v.name === pModel.name && v.provider === pModel.provider,
);
if (idx !== -1) models[idx] = pModel;
else models.push(pModel);
});
return { ...currentState, ...state, models: models };
},
migrate(persistedState, version) {
const state = persistedState as ChatConfig;
@@ -171,6 +231,13 @@ export const useAppConfig = createPersistStore(
: config?.template ?? DEFAULT_INPUT_TEMPLATE;
}
if (version < 4) {
state.modelConfig.compressModel =
DEFAULT_CONFIG.modelConfig.compressModel;
state.modelConfig.compressProviderName =
DEFAULT_CONFIG.modelConfig.compressProviderName;
}
return state as any;
},
},

View File

@@ -2,3 +2,4 @@ export * from "./chat";
export * from "./update";
export * from "./access";
export * from "./config";
export * from "./plugin";

View File

@@ -17,13 +17,18 @@ export type Mask = {
modelConfig: ModelConfig;
lang: Lang;
builtin: boolean;
plugin?: string[];
enableArtifacts?: boolean;
};
export const DEFAULT_MASK_STATE = {
masks: {} as Record<string, Mask>,
language: undefined as Lang | undefined,
};
export type MaskState = typeof DEFAULT_MASK_STATE;
export type MaskState = typeof DEFAULT_MASK_STATE & {
language?: Lang | undefined;
};
export const DEFAULT_MASK_AVATAR = "gpt-bot";
export const createEmptyMask = () =>
@@ -37,6 +42,7 @@ export const createEmptyMask = () =>
lang: getLang(),
builtin: false,
createdAt: Date.now(),
plugin: [],
}) as Mask;
export const useMaskStore = createPersistStore(
@@ -99,6 +105,11 @@ export const useMaskStore = createPersistStore(
search(text: string) {
return Object.values(get().masks);
},
setLanguage(language: Lang | undefined) {
set({
language,
});
},
}),
{
name: StoreKey.Mask,

271
app/store/plugin.ts Normal file
View File

@@ -0,0 +1,271 @@
import OpenAPIClientAxios from "openapi-client-axios";
import { StoreKey } from "../constant";
import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { getClientConfig } from "../config/client";
import yaml from "js-yaml";
import { adapter, getOperationId } from "../utils";
import { useAccessStore } from "./access";
const isApp = getClientConfig()?.isApp;
export type Plugin = {
id: string;
createdAt: number;
title: string;
version: string;
content: string;
builtin: boolean;
authType?: string;
authLocation?: string;
authHeader?: string;
authToken?: string;
};
export type FunctionToolItem = {
type: string;
function: {
name: string;
description?: string;
parameters: Object;
};
};
type FunctionToolServiceItem = {
api: OpenAPIClientAxios;
length: number;
tools: FunctionToolItem[];
funcs: Record<string, Function>;
};
export const FunctionToolService = {
tools: {} as Record<string, FunctionToolServiceItem>,
add(plugin: Plugin, replace = false) {
if (!replace && this.tools[plugin.id]) return this.tools[plugin.id];
const headerName = (
plugin?.authType == "custom" ? plugin?.authHeader : "Authorization"
) as string;
const tokenValue =
plugin?.authType == "basic"
? `Basic ${plugin?.authToken}`
: plugin?.authType == "bearer"
? `Bearer ${plugin?.authToken}`
: plugin?.authToken;
const authLocation = plugin?.authLocation || "header";
const definition = yaml.load(plugin.content) as any;
const serverURL = definition?.servers?.[0]?.url;
const baseURL = !isApp ? "/api/proxy" : serverURL;
const headers: Record<string, string | undefined> = {
"X-Base-URL": !isApp ? serverURL : undefined,
};
if (authLocation == "header") {
headers[headerName] = tokenValue;
}
// try using openaiApiKey for Dalle3 Plugin.
if (!tokenValue && plugin.id === "dalle3") {
const openaiApiKey = useAccessStore.getState().openaiApiKey;
if (openaiApiKey) {
headers[headerName] = `Bearer ${openaiApiKey}`;
}
}
const api = new OpenAPIClientAxios({
definition: yaml.load(plugin.content) as any,
axiosConfigDefaults: {
adapter: (window.__TAURI__ ? adapter : ["xhr"]) as any,
baseURL,
headers,
},
});
try {
api.initSync();
} catch (e) {}
const operations = api.getOperations();
return (this.tools[plugin.id] = {
api,
length: operations.length,
tools: operations.map((o) => {
// @ts-ignore
const parameters = o?.requestBody?.content["application/json"]
?.schema || {
type: "object",
properties: {},
};
if (!parameters["required"]) {
parameters["required"] = [];
}
if (o.parameters instanceof Array) {
o.parameters.forEach((p) => {
// @ts-ignore
if (p?.in == "query" || p?.in == "path") {
// const name = `${p.in}__${p.name}`
// @ts-ignore
const name = p?.name;
parameters["properties"][name] = {
// @ts-ignore
type: p.schema.type,
// @ts-ignore
description: p.description,
};
// @ts-ignore
if (p.required) {
parameters["required"].push(name);
}
}
});
}
return {
type: "function",
function: {
name: getOperationId(o),
description: o.description || o.summary,
parameters: parameters,
},
} as FunctionToolItem;
}),
funcs: operations.reduce((s, o) => {
// @ts-ignore
s[getOperationId(o)] = function (args) {
const parameters: Record<string, any> = {};
if (o.parameters instanceof Array) {
o.parameters.forEach((p) => {
// @ts-ignore
parameters[p?.name] = args[p?.name];
// @ts-ignore
delete args[p?.name];
});
}
if (authLocation == "query") {
parameters[headerName] = tokenValue;
} else if (authLocation == "body") {
args[headerName] = tokenValue;
}
// @ts-ignore if o.operationId is null, then using o.path and o.method
return api.client.paths[o.path][o.method](
parameters,
args,
api.axiosConfigDefaults,
);
};
return s;
}, {}),
});
},
get(id: string) {
return this.tools[id];
},
};
export const createEmptyPlugin = () =>
({
id: nanoid(),
title: "",
version: "1.0.0",
content: "",
builtin: false,
createdAt: Date.now(),
}) as Plugin;
export const DEFAULT_PLUGIN_STATE = {
plugins: {} as Record<string, Plugin>,
};
export const usePluginStore = createPersistStore(
{ ...DEFAULT_PLUGIN_STATE },
(set, get) => ({
create(plugin?: Partial<Plugin>) {
const plugins = get().plugins;
const id = plugin?.id || nanoid();
plugins[id] = {
...createEmptyPlugin(),
...plugin,
id,
builtin: false,
};
set(() => ({ plugins }));
get().markUpdate();
return plugins[id];
},
updatePlugin(id: string, updater: (plugin: Plugin) => void) {
const plugins = get().plugins;
const plugin = plugins[id];
if (!plugin) return;
const updatePlugin = { ...plugin };
updater(updatePlugin);
plugins[id] = updatePlugin;
FunctionToolService.add(updatePlugin, true);
set(() => ({ plugins }));
get().markUpdate();
},
delete(id: string) {
const plugins = get().plugins;
delete plugins[id];
set(() => ({ plugins }));
get().markUpdate();
},
getAsTools(ids: string[]) {
const plugins = get().plugins;
const selected = (ids || [])
.map((id) => plugins[id])
.filter((i) => i)
.map((p) => FunctionToolService.add(p));
return [
// @ts-ignore
selected.reduce((s, i) => s.concat(i.tools), []),
selected.reduce((s, i) => Object.assign(s, i.funcs), {}),
];
},
get(id?: string) {
return get().plugins[id ?? 1145141919810];
},
getAll() {
return Object.values(get().plugins).sort(
(a, b) => b.createdAt - a.createdAt,
);
},
}),
{
name: StoreKey.Plugin,
version: 1,
onRehydrateStorage(state) {
// Skip store rehydration on server side
if (typeof window === "undefined") {
return;
}
fetch("./plugins.json")
.then((res) => res.json())
.then((res) => {
Promise.all(
res.map((item: any) =>
// skip get schema
state.get(item.id)
? item
: fetch(item.schema)
.then((res) => res.text())
.then((content) => ({
...item,
content,
}))
.catch((e) => item),
),
).then((builtinPlugins: any) => {
builtinPlugins
.filter((item: any) => item?.content)
.forEach((item: any) => {
const plugin = state.create(item);
state.updatePlugin(plugin.id, (plugin) => {
const tool = FunctionToolService.add(plugin, true);
plugin.title = tool.api.definition.info.title;
plugin.version = tool.api.definition.info.version;
plugin.builtin = true;
});
});
});
});
},
},
);

View File

@@ -1,7 +1,7 @@
import Fuse from "fuse.js";
import { getLang } from "../locales";
import { StoreKey } from "../constant";
import { nanoid } from "nanoid";
import { StoreKey } from "../constant";
import { getLang } from "../locales";
import { createPersistStore } from "../utils/store";
export interface Prompt {
@@ -147,6 +147,11 @@ export const usePromptStore = createPersistStore(
},
onRehydrateStorage(state) {
// Skip store rehydration on server side
if (typeof window === "undefined") {
return;
}
const PROMPT_URL = "./prompts.json";
type PromptList = Array<[string, string]>;
@@ -154,7 +159,7 @@ export const usePromptStore = createPersistStore(
fetch(PROMPT_URL)
.then((res) => res.json())
.then((res) => {
let fetchPrompts = [res.en, res.cn];
let fetchPrompts = [res.en, res.tw, res.cn];
if (getLang() === "cn") {
fetchPrompts = fetchPrompts.reverse();
}
@@ -175,7 +180,8 @@ export const usePromptStore = createPersistStore(
const allPromptsForSearch = builtinPrompts
.reduce((pre, cur) => pre.concat(cur), [])
.filter((v) => !!v.title && !!v.content);
SearchService.count.builtin = res.en.length + res.cn.length;
SearchService.count.builtin =
res.en.length + res.cn.length + res.tw.length;
SearchService.init(allPromptsForSearch, userPrompts);
});
},

163
app/store/sd.ts Normal file
View File

@@ -0,0 +1,163 @@
import {
Stability,
StoreKey,
ACCESS_CODE_PREFIX,
ApiPath,
} from "@/app/constant";
import { getBearerToken } from "@/app/client/api";
import { createPersistStore } from "@/app/utils/store";
import { nanoid } from "nanoid";
import { uploadImage, base64Image2Blob } from "@/app/utils/chat";
import { models, getModelParamBasicData } from "@/app/components/sd/sd-panel";
import { useAccessStore } from "./access";
const defaultModel = {
name: models[0].name,
value: models[0].value,
};
const defaultParams = getModelParamBasicData(models[0].params({}), {});
const DEFAULT_SD_STATE = {
currentId: 0,
draw: [],
currentModel: defaultModel,
currentParams: defaultParams,
};
export const useSdStore = createPersistStore<
{
currentId: number;
draw: any[];
currentModel: typeof defaultModel;
currentParams: any;
},
{
getNextId: () => number;
sendTask: (data: any, okCall?: Function) => void;
updateDraw: (draw: any) => void;
setCurrentModel: (model: any) => void;
setCurrentParams: (data: any) => void;
}
>(
DEFAULT_SD_STATE,
(set, _get) => {
function get() {
return {
..._get(),
...methods,
};
}
const methods = {
getNextId() {
const id = ++_get().currentId;
set({ currentId: id });
return id;
},
sendTask(data: any, okCall?: Function) {
data = { ...data, id: nanoid(), status: "running" };
set({ draw: [data, ..._get().draw] });
this.getNextId();
this.stabilityRequestCall(data);
okCall?.();
},
stabilityRequestCall(data: any) {
const accessStore = useAccessStore.getState();
let prefix: string = ApiPath.Stability as string;
let bearerToken = "";
if (accessStore.useCustomConfig) {
prefix = accessStore.stabilityUrl || (ApiPath.Stability as string);
bearerToken = getBearerToken(accessStore.stabilityApiKey);
}
if (!bearerToken && accessStore.enabledAccessControl()) {
bearerToken = getBearerToken(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
const headers = {
Accept: "application/json",
Authorization: bearerToken,
};
const path = `${prefix}/${Stability.GeneratePath}/${data.model}`;
const formData = new FormData();
for (let paramsKey in data.params) {
formData.append(paramsKey, data.params[paramsKey]);
}
fetch(path, {
method: "POST",
headers,
body: formData,
})
.then((response) => response.json())
.then((resData) => {
if (resData.errors && resData.errors.length > 0) {
this.updateDraw({
...data,
status: "error",
error: resData.errors[0],
});
this.getNextId();
return;
}
const self = this;
if (resData.finish_reason === "SUCCESS") {
uploadImage(base64Image2Blob(resData.image, "image/png"))
.then((img_data) => {
console.debug("uploadImage success", img_data, self);
self.updateDraw({
...data,
status: "success",
img_data,
});
})
.catch((e) => {
console.error("uploadImage error", e);
self.updateDraw({
...data,
status: "error",
error: JSON.stringify(e),
});
});
} else {
self.updateDraw({
...data,
status: "error",
error: JSON.stringify(resData),
});
}
this.getNextId();
})
.catch((error) => {
this.updateDraw({ ...data, status: "error", error: error.message });
console.error("Error:", error);
this.getNextId();
});
},
updateDraw(_draw: any) {
const draw = _get().draw || [];
draw.some((item, index) => {
if (item.id === _draw.id) {
draw[index] = _draw;
set(() => ({ draw }));
return true;
}
});
},
setCurrentModel(model: any) {
set({ currentModel: model });
},
setCurrentParams(data: any) {
set({
currentParams: data,
});
},
};
return methods;
},
{
name: StoreKey.SdList,
version: 1.0,
},
);

View File

@@ -1,5 +1,4 @@
import { getClientConfig } from "../config/client";
import { Updater } from "../typing";
import { ApiPath, STORAGE_KEY, StoreKey } from "../constant";
import { createPersistStore } from "../utils/store";
import {
@@ -100,15 +99,17 @@ export const useSyncStore = createPersistStore(
const remoteState = await client.get(config.username);
if (!remoteState || remoteState === "") {
await client.set(config.username, JSON.stringify(localState));
console.log("[Sync] Remote state is empty, using local state instead.");
return
console.log(
"[Sync] Remote state is empty, using local state instead.",
);
return;
} else {
const parsedRemoteState = JSON.parse(
await client.get(config.username),
) as AppState;
mergeAppState(localState, parsedRemoteState);
setLocalAppState(localState);
}
}
} catch (e) {
console.log("[Sync] failed to get remote state", e);
throw e;

View File

@@ -8,8 +8,6 @@ import { getClientConfig } from "../config/client";
import { createPersistStore } from "../utils/store";
import ChatGptIcon from "../icons/chatgpt.png";
import Locale from "../locales";
import { use } from "react";
import { useAppConfig } from ".";
import { ClientApi } from "../client/api";
const ONE_MINUTE = 60 * 1000;