mirror of
				https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	Merge pull request #5769 from ryanhex53/fix-model-multi@
Custom model names can include the `@` symbol by itself.
This commit is contained in:
		@@ -1,8 +1,8 @@
 | 
			
		||||
import { NextRequest, NextResponse } from "next/server";
 | 
			
		||||
import { getServerSideConfig } from "../config/server";
 | 
			
		||||
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
 | 
			
		||||
import { isModelAvailableInServer } from "../utils/model";
 | 
			
		||||
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
 | 
			
		||||
import { getModelProvider, isModelAvailableInServer } from "../utils/model";
 | 
			
		||||
 | 
			
		||||
const serverConfig = getServerSideConfig();
 | 
			
		||||
 | 
			
		||||
@@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
 | 
			
		||||
        .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
 | 
			
		||||
        .forEach((m) => {
 | 
			
		||||
          const [fullName, displayName] = m.split("=");
 | 
			
		||||
          const [_, providerName] = fullName.split("@");
 | 
			
		||||
          const [_, providerName] = getModelProvider(fullName);
 | 
			
		||||
          if (providerName === "azure" && !displayName) {
 | 
			
		||||
            const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
 | 
			
		||||
              "deployments/",
 | 
			
		||||
 
 | 
			
		||||
@@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
 | 
			
		||||
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
 | 
			
		||||
 | 
			
		||||
import { isEmpty } from "lodash-es";
 | 
			
		||||
import { getModelProvider } from "../utils/model";
 | 
			
		||||
 | 
			
		||||
const localStorage = safeLocalStorage();
 | 
			
		||||
 | 
			
		||||
@@ -645,7 +646,7 @@ export function ChatActions(props: {
 | 
			
		||||
          onClose={() => setShowModelSelector(false)}
 | 
			
		||||
          onSelection={(s) => {
 | 
			
		||||
            if (s.length === 0) return;
 | 
			
		||||
            const [model, providerName] = s[0].split("@");
 | 
			
		||||
            const [model, providerName] = getModelProvider(s[0]);
 | 
			
		||||
            chatStore.updateCurrentSession((session) => {
 | 
			
		||||
              session.mask.modelConfig.model = model as ModelType;
 | 
			
		||||
              session.mask.modelConfig.providerName =
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
 | 
			
		||||
import { useAllModels } from "../utils/hooks";
 | 
			
		||||
import { groupBy } from "lodash-es";
 | 
			
		||||
import styles from "./model-config.module.scss";
 | 
			
		||||
import { getModelProvider } from "../utils/model";
 | 
			
		||||
 | 
			
		||||
export function ModelConfigList(props: {
 | 
			
		||||
  modelConfig: ModelConfig;
 | 
			
		||||
@@ -28,7 +29,9 @@ export function ModelConfigList(props: {
 | 
			
		||||
          value={value}
 | 
			
		||||
          align="left"
 | 
			
		||||
          onChange={(e) => {
 | 
			
		||||
            const [model, providerName] = e.currentTarget.value.split("@");
 | 
			
		||||
            const [model, providerName] = getModelProvider(
 | 
			
		||||
              e.currentTarget.value,
 | 
			
		||||
            );
 | 
			
		||||
            props.updateConfig((config) => {
 | 
			
		||||
              config.model = ModalConfigValidator.model(model);
 | 
			
		||||
              config.providerName = providerName as ServiceProvider;
 | 
			
		||||
@@ -247,7 +250,9 @@ export function ModelConfigList(props: {
 | 
			
		||||
          aria-label={Locale.Settings.CompressModel.Title}
 | 
			
		||||
          value={compressModelValue}
 | 
			
		||||
          onChange={(e) => {
 | 
			
		||||
            const [model, providerName] = e.currentTarget.value.split("@");
 | 
			
		||||
            const [model, providerName] = getModelProvider(
 | 
			
		||||
              e.currentTarget.value,
 | 
			
		||||
            );
 | 
			
		||||
            props.updateConfig((config) => {
 | 
			
		||||
              config.compressModel = ModalConfigValidator.model(model);
 | 
			
		||||
              config.compressProviderName = providerName as ServiceProvider;
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
 | 
			
		||||
import { createPersistStore } from "../utils/store";
 | 
			
		||||
import { ensure } from "../utils/clone";
 | 
			
		||||
import { DEFAULT_CONFIG } from "./config";
 | 
			
		||||
import { getModelProvider } from "../utils/model";
 | 
			
		||||
 | 
			
		||||
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
 | 
			
		||||
 | 
			
		||||
@@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
 | 
			
		||||
        .then((res) => {
 | 
			
		||||
          const defaultModel = res.defaultModel ?? "";
 | 
			
		||||
          if (defaultModel !== "") {
 | 
			
		||||
            const [model, providerName] = defaultModel.split("@");
 | 
			
		||||
            const [model, providerName] = getModelProvider(defaultModel);
 | 
			
		||||
            DEFAULT_CONFIG.modelConfig.model = model;
 | 
			
		||||
            DEFAULT_CONFIG.modelConfig.providerName = providerName;
 | 
			
		||||
            DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          return res;
 | 
			
		||||
 
 | 
			
		||||
@@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * get model name and provider from a formatted string,
 | 
			
		||||
 * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
 | 
			
		||||
 * @param modelWithProvider model name with provider separated by last `@` char,
 | 
			
		||||
 * @returns [model, provider] tuple, if no `@` char found, provider is undefined
 | 
			
		||||
 */
 | 
			
		||||
export function getModelProvider(modelWithProvider: string): [string, string?] {
 | 
			
		||||
  const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
 | 
			
		||||
  return [model, provider];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function collectModelTable(
 | 
			
		||||
  models: readonly LLMModel[],
 | 
			
		||||
  customModels: string,
 | 
			
		||||
@@ -79,10 +90,10 @@ export function collectModelTable(
 | 
			
		||||
        );
 | 
			
		||||
      } else {
 | 
			
		||||
        // 1. find model by name, and set available value
 | 
			
		||||
        const [customModelName, customProviderName] = name.split("@");
 | 
			
		||||
        const [customModelName, customProviderName] = getModelProvider(name);
 | 
			
		||||
        let count = 0;
 | 
			
		||||
        for (const fullName in modelTable) {
 | 
			
		||||
          const [modelName, providerName] = fullName.split("@");
 | 
			
		||||
          const [modelName, providerName] = getModelProvider(fullName);
 | 
			
		||||
          if (
 | 
			
		||||
            customModelName == modelName &&
 | 
			
		||||
            (customProviderName === undefined ||
 | 
			
		||||
@@ -102,7 +113,7 @@ export function collectModelTable(
 | 
			
		||||
        }
 | 
			
		||||
        // 2. if model not exists, create new model with available value
 | 
			
		||||
        if (count === 0) {
 | 
			
		||||
          let [customModelName, customProviderName] = name.split("@");
 | 
			
		||||
          let [customModelName, customProviderName] = getModelProvider(name);
 | 
			
		||||
          const provider = customProvider(
 | 
			
		||||
            customProviderName || customModelName,
 | 
			
		||||
          );
 | 
			
		||||
@@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
 | 
			
		||||
      for (const key of Object.keys(modelTable)) {
 | 
			
		||||
        if (
 | 
			
		||||
          modelTable[key].available &&
 | 
			
		||||
          key.split("@").shift() == defaultModel
 | 
			
		||||
          getModelProvider(key)[0] == defaultModel
 | 
			
		||||
        ) {
 | 
			
		||||
          modelTable[key].isDefault = true;
 | 
			
		||||
          break;
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										31
									
								
								test/model-provider.test.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								test/model-provider.test.ts
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
			
		||||
import { getModelProvider } from "../app/utils/model";
 | 
			
		||||
 | 
			
		||||
describe("getModelProvider", () => {
 | 
			
		||||
  test("should return model and provider when input contains '@'", () => {
 | 
			
		||||
    const input = "model@provider";
 | 
			
		||||
    const [model, provider] = getModelProvider(input);
 | 
			
		||||
    expect(model).toBe("model");
 | 
			
		||||
    expect(provider).toBe("provider");
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  test("should return model and undefined provider when input does not contain '@'", () => {
 | 
			
		||||
    const input = "model";
 | 
			
		||||
    const [model, provider] = getModelProvider(input);
 | 
			
		||||
    expect(model).toBe("model");
 | 
			
		||||
    expect(provider).toBeUndefined();
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  test("should handle multiple '@' characters correctly", () => {
 | 
			
		||||
    const input = "model@provider@extra";
 | 
			
		||||
    const [model, provider] = getModelProvider(input);
 | 
			
		||||
    expect(model).toBe("model@provider");
 | 
			
		||||
    expect(provider).toBe("extra");
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  test("should return empty strings when input is empty", () => {
 | 
			
		||||
    const input = "";
 | 
			
		||||
    const [model, provider] = getModelProvider(input);
 | 
			
		||||
    expect(model).toBe("");
 | 
			
		||||
    expect(provider).toBeUndefined();
 | 
			
		||||
  });
 | 
			
		||||
});
 | 
			
		||||
		Reference in New Issue
	
	Block a user