diff --git a/app/api/auth.ts b/app/api/auth.ts index b41e34e05..16c8034eb 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -1,7 +1,7 @@ import { NextRequest } from "next/server"; import { getServerSideConfig } from "../config/server"; import md5 from "spark-md5"; -import { ACCESS_CODE_PREFIX } from "../constant"; +import { ACCESS_CODE_PREFIX, ModelProvider } from "../constant"; function getIP(req: NextRequest) { let ip = req.ip ?? req.headers.get("x-real-ip"); @@ -16,15 +16,15 @@ function getIP(req: NextRequest) { function parseApiKey(bearToken: string) { const token = bearToken.trim().replaceAll("Bearer ", "").trim(); - const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); + const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX); return { - accessCode: isOpenAiKey ? "" : token.slice(ACCESS_CODE_PREFIX.length), - apiKey: isOpenAiKey ? token : "", + accessCode: isApiKey ? "" : token.slice(ACCESS_CODE_PREFIX.length), + apiKey: isApiKey ? token : "", }; } -export function auth(req: NextRequest) { +export function auth(req: NextRequest, modelProvider: ModelProvider) { const authToken = req.headers.get("Authorization") ?? ""; // check if it is openai api key or user token @@ -49,22 +49,23 @@ export function auth(req: NextRequest) { if (serverConfig.hideUserApiKey && !!apiKey) { return { error: true, - msg: "you are not allowed to access openai with your own api key", + msg: "you are not allowed to access with your own api key", }; } // if user does not provide an api key, inject system api key if (!apiKey) { - const serverApiKey = serverConfig.isAzure - ? serverConfig.azureApiKey - : serverConfig.apiKey; + const serverConfig = getServerSideConfig(); - if (serverApiKey) { + const systemApiKey = + modelProvider === ModelProvider.GeminiPro + ? serverConfig.googleApiKey + : serverConfig.isAzure + ? serverConfig.azureApiKey + : serverConfig.apiKey; + if (systemApiKey) { console.log("[Auth] use system api key"); - req.headers.set( - "Authorization", - `${serverConfig.isAzure ? "" : "Bearer "}${serverApiKey}`, - ); + req.headers.set("Authorization", `Bearer ${systemApiKey}`); } else { console.log("[Auth] admin did not provide an api key"); } diff --git a/app/api/common.ts b/app/api/common.ts index 13cfab03c..2d89ea1e5 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -9,8 +9,21 @@ const serverConfig = getServerSideConfig(); export async function requestOpenai(req: NextRequest) { const controller = new AbortController(); - const authValue = req.headers.get("Authorization") ?? ""; - const authHeaderName = serverConfig.isAzure ? "api-key" : "Authorization"; + var authValue, + authHeaderName = ""; + if (serverConfig.isAzure) { + authValue = + req.headers + .get("Authorization") + ?.trim() + .replaceAll("Bearer ", "") + .trim() ?? ""; + + authHeaderName = "api-key"; + } else { + authValue = req.headers.get("Authorization") ?? ""; + authHeaderName = "Authorization"; + } let path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll( "/api/openai/", diff --git a/app/api/google/[...path]/route.ts b/app/api/google/[...path]/route.ts index 217556784..869bd5076 100644 --- a/app/api/google/[...path]/route.ts +++ b/app/api/google/[...path]/route.ts @@ -1,7 +1,7 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; import { getServerSideConfig } from "@/app/config/server"; -import { GEMINI_BASE_URL, Google } from "@/app/constant"; +import { GEMINI_BASE_URL, Google, ModelProvider } from "@/app/constant"; async function handle( req: NextRequest, @@ -39,10 +39,18 @@ async function handle( 10 * 60 * 1000, ); + const authResult = auth(req, ModelProvider.GeminiPro); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + const bearToken = req.headers.get("Authorization") ?? ""; const token = bearToken.trim().replaceAll("Bearer ", "").trim(); const key = token ? token : serverConfig.googleApiKey; + if (!key) { return NextResponse.json( { @@ -56,7 +64,6 @@ async function handle( } const fetchUrl = `${baseUrl}/${path}?key=${key}`; - const fetchOptions: RequestInit = { headers: { "Content-Type": "application/json", diff --git a/app/api/openai/[...path]/route.ts b/app/api/openai/[...path]/route.ts index 2addd53a5..77059c151 100644 --- a/app/api/openai/[...path]/route.ts +++ b/app/api/openai/[...path]/route.ts @@ -1,6 +1,6 @@ import { type OpenAIListModelResponse } from "@/app/client/platforms/openai"; import { getServerSideConfig } from "@/app/config/server"; -import { OpenaiPath } from "@/app/constant"; +import { ModelProvider, OpenaiPath } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; @@ -45,7 +45,7 @@ async function handle( ); } - const authResult = auth(req); + const authResult = auth(req, ModelProvider.GPT); if (authResult.error) { return NextResponse.json(authResult, { status: 401, @@ -75,4 +75,22 @@ export const GET = handle; export const POST = handle; export const runtime = "edge"; -export const preferredRegion = ['arn1', 'bom1', 'cdg1', 'cle1', 'cpt1', 'dub1', 'fra1', 'gru1', 'hnd1', 'iad1', 'icn1', 'kix1', 'lhr1', 'pdx1', 'sfo1', 'sin1', 'syd1']; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index ec7d79564..c35e93cb3 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -21,10 +21,24 @@ export class GeminiProApi implements LLMApi { } async chat(options: ChatOptions): Promise { const messages = options.messages.map((v) => ({ - role: v.role.replace("assistant", "model").replace("system", "model"), + role: v.role.replace("assistant", "model").replace("system", "user"), parts: [{ text: v.content }], })); + // google requires that role in neighboring messages must not be the same + for (let i = 0; i < messages.length - 1; ) { + // Check if current and next item both have the role "model" + if (messages[i].role === messages[i + 1].role) { + // Concatenate the 'parts' of the current and next item + messages[i].parts = messages[i].parts.concat(messages[i + 1].parts); + // Remove the next item + messages.splice(i + 1, 1); + } else { + // Move to the next item + i++; + } + } + const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, @@ -43,14 +57,6 @@ export class GeminiProApi implements LLMApi { topP: modelConfig.top_p, // "topK": modelConfig.top_k, }, - // stream: options.config.stream, - // model: modelConfig.model, - // temperature: modelConfig.temperature, - // presence_penalty: modelConfig.presence_penalty, - // frequency_penalty: modelConfig.frequency_penalty, - // top_p: modelConfig.top_p, - // max_tokens: Math.max(modelConfig.max_tokens, 1024), - // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore. }; console.log("[Request] google payload: ", requestPayload); diff --git a/app/config/server.ts b/app/config/server.ts index 83c711242..c6251a5c2 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -65,6 +65,7 @@ export const getServerSideConfig = () => { } const isAzure = !!process.env.AZURE_URL; + const isGoogle = !!process.env.GOOGLE_API_KEY; const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? ""; const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); @@ -84,6 +85,7 @@ export const getServerSideConfig = () => { azureApiKey: process.env.AZURE_API_KEY, azureApiVersion: process.env.AZURE_API_VERSION, + isGoogle, googleApiKey: process.env.GOOGLE_API_KEY, googleUrl: process.env.GOOGLE_URL, diff --git a/app/store/chat.ts b/app/store/chat.ts index 1dcf4e646..4af5a52ac 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -389,24 +389,22 @@ export const useChatStore = createPersistStore( const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts; var systemPrompts: ChatMessage[] = []; - if (modelConfig.model !== "gemini-pro") { - systemPrompts = shouldInjectSystemPrompts - ? [ - createMessage({ - role: "system", - content: fillTemplateWith("", { - ...modelConfig, - template: DEFAULT_SYSTEM_TEMPLATE, - }), + systemPrompts = shouldInjectSystemPrompts + ? [ + createMessage({ + role: "system", + content: fillTemplateWith("", { + ...modelConfig, + template: DEFAULT_SYSTEM_TEMPLATE, }), - ] - : []; - if (shouldInjectSystemPrompts) { - console.log( - "[Global System Prompt] ", - systemPrompts.at(0)?.content ?? "empty", - ); - } + }), + ] + : []; + if (shouldInjectSystemPrompts) { + console.log( + "[Global System Prompt] ", + systemPrompts.at(0)?.content ?? "empty", + ); } // long term memory diff --git a/app/utils/model.ts b/app/utils/model.ts index c4a4833ed..b2a42ef02 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -10,24 +10,23 @@ export function collectModelTable( available: boolean; name: string; displayName: string; - provider: LLMModel["provider"]; + provider?: LLMModel["provider"]; // Marked as optional } > = {}; // default models - models.forEach( - (m) => - (modelTable[m.name] = { - ...m, - displayName: m.name, - }), - ); + models.forEach((m) => { + modelTable[m.name] = { + ...m, + displayName: m.name, // 'provider' is copied over if it exists + }; + }); // server custom models customModels .split(",") .filter((v) => !!v && v.length > 0) - .map((m) => { + .forEach((m) => { const available = !m.startsWith("-"); const nameConfig = m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; @@ -35,15 +34,15 @@ export function collectModelTable( // enable or disable all models if (name === "all") { - Object.values(modelTable).forEach((m) => (m.available = available)); + Object.values(modelTable).forEach((model) => (model.available = available)); + } else { + modelTable[name] = { + name, + displayName: displayName || name, + available, + provider: modelTable[name]?.provider, // Use optional chaining + }; } - - modelTable[name] = { - name, - displayName: displayName || name, - available, - provider: modelTable[name].provider, - }; }); return modelTable; }