refactor: model management

This commit is contained in:
Junyan Qin
2026-01-01 02:00:24 +08:00
parent 96e40eaf25
commit 4528000fc4
20 changed files with 238 additions and 137 deletions

View File

@@ -1,6 +1,7 @@
import quart
import argon2
import asyncio
import traceback
from .. import group
@@ -141,8 +142,10 @@ class UserRouterGroup(group.RouterGroup):
}
)
except ValueError as e:
traceback.print_exc()
return self.fail(1, str(e))
except Exception as e:
traceback.print_exc()
return self.fail(2, f'OAuth callback failed: {str(e)}')
@self.route('/info', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)

View File

@@ -85,10 +85,17 @@ class LLMModelsService:
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
llm_model = await self.get_llm_model(model_data['uuid'])
await self.ap.model_mgr.load_llm_model(llm_model)
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
if runtime_provider is None:
raise Exception('provider not found')
# Check if default pipeline has no model bound
runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
persistence_model.LLMModel(**model_data),
runtime_provider,
)
self.ap.model_mgr.llm_models.append(runtime_llm_model)
# set the default pipeline model to this model
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
@@ -152,8 +159,16 @@ class LLMModelsService:
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
llm_model = await self.get_llm_model(model_uuid)
await self.ap.model_mgr.load_llm_model(llm_model)
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
if runtime_provider is None:
raise Exception('provider not found')
runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
persistence_model.LLMModel(**model_data),
runtime_provider,
)
self.ap.model_mgr.llm_models.append(runtime_llm_model)
async def delete_llm_model(self, model_uuid: str) -> None:
"""Delete an LLM model"""
@@ -174,10 +189,10 @@ class LLMModelsService:
if runtime_llm_model is None:
raise Exception('model not found')
else:
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
runtime_llm_model = await self.ap.model_mgr.init_temporary_runtime_llm_model(model_data)
extra_args = model_data.get('extra_args', {})
await runtime_llm_model.requester.invoke_llm(
await runtime_llm_model.provider.requester.invoke_llm(
query=None,
model=runtime_llm_model,
messages=[provider_message.Message(role='user', content='Hello, world! Please just reply a "Hello".')],
@@ -244,8 +259,15 @@ class EmbeddingModelsService:
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
)
embedding_model = await self.get_embedding_model(model_data['uuid'])
await self.ap.model_mgr.load_embedding_model(embedding_model)
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
if runtime_provider is None:
raise Exception('provider not found')
runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
persistence_model.EmbeddingModel(**model_data),
runtime_provider,
)
self.ap.model_mgr.embedding_models.append(runtime_embedding_model)
return model_data['uuid']
@@ -298,8 +320,16 @@ class EmbeddingModelsService:
)
await self.ap.model_mgr.remove_embedding_model(model_uuid)
embedding_model = await self.get_embedding_model(model_uuid)
await self.ap.model_mgr.load_embedding_model(embedding_model)
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
if runtime_provider is None:
raise Exception('provider not found')
runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
persistence_model.EmbeddingModel(**model_data),
runtime_provider,
)
self.ap.model_mgr.embedding_models.append(runtime_embedding_model)
async def delete_embedding_model(self, model_uuid: str) -> None:
"""Delete an embedding model"""
@@ -322,9 +352,9 @@ class EmbeddingModelsService:
if runtime_embedding_model is None:
raise Exception('model not found')
else:
runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data)
runtime_embedding_model = await self.ap.model_mgr.init_temporary_runtime_embedding_model(model_data)
await runtime_embedding_model.requester.invoke_embedding(
await runtime_embedding_model.provider.requester.invoke_embedding(
model=runtime_embedding_model,
input_text=['Hello, world!'],
extra_args={},

View File

@@ -61,6 +61,10 @@ class ModelProviderService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
)
# load to runtime
runtime_provider = await self.ap.model_mgr.load_provider(provider_data)
self.ap.model_mgr.provider_dict[runtime_provider.provider_entity.uuid] = runtime_provider
return provider_data['uuid']
async def update_provider(self, provider_uuid: str, provider_data: dict) -> None:
@@ -72,8 +76,7 @@ class ModelProviderService:
.where(persistence_model.ModelProvider.uuid == provider_uuid)
.values(**provider_data)
)
# Reload all models using this provider
await self.ap.model_mgr.load_models_from_db()
await self.ap.model_mgr.reload_provider(provider_uuid)
async def delete_provider(self, provider_uuid: str) -> None:
"""Delete a provider (only if no models reference it)"""
@@ -100,6 +103,8 @@ class ModelProviderService:
)
)
await self.ap.model_mgr.remove_provider(provider_uuid)
async def get_provider_model_counts(self, provider_uuid: str) -> dict:
"""Get count of models using this provider"""
llm_result = await self.ap.persistence_mgr.execute_async(
@@ -150,3 +155,12 @@ class ModelProviderService:
'api_keys': api_keys or [],
}
)
async def update_space_model_provider_api_keys(self, api_key: str) -> None:
"""Update Space model provider API keys"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
.values(api_keys=[api_key])
)
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')

View File

@@ -149,6 +149,7 @@ class UserService:
space_access_token_expires_at=expires_at,
)
)
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
return await self.get_user_by_space_account_uuid(space_account_uuid)
# Check if user with same email exists
@@ -167,6 +168,7 @@ class UserService:
space_access_token_expires_at=expires_at,
)
)
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
return await self.get_user_by_email(email)
# Check if system is already initialized
@@ -189,6 +191,7 @@ class UserService:
space_access_token_expires_at=expires_at,
)
)
await self.ap.provider_service.update_space_model_provider_api_keys(api_key)
return await self.get_user_by_space_account_uuid(space_account_uuid)

View File

@@ -125,25 +125,39 @@ class PersistenceManager:
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
async def write_space_model_providers(self):
space_models_gateway_api_url = self.ap.instance_config.data.get('space', {}).get(
'models_gateway_api_url', 'https://api.langbot.cloud/v1'
)
# write space model providers
result = await self.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.requester == 'space-chat-completions'
)
)
if result.first() is None:
exists_space_chat_completions_model_provider = result.first()
# api keys will be set/updated when the oauth callback
if exists_space_chat_completions_model_provider is None:
self.ap.logger.info('Creating space model providers...')
space_chat_completions_model_provider = {
'uuid': '00000000-0000-0000-0000-000000000000',
'name': 'LangBot Models',
'requester': 'space-chat-completions',
'base_url': 'https://api.langbot.cloud/v1',
'base_url': space_models_gateway_api_url,
'api_keys': [],
}
await self.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(space_chat_completions_model_provider)
)
else:
if exists_space_chat_completions_model_provider.base_url != space_models_gateway_api_url:
await self.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == exists_space_chat_completions_model_provider.uuid)
.values({'base_url': space_models_gateway_api_url})
)
# =================================

View File

@@ -324,7 +324,7 @@ class RuntimeConnectionHandler(handler.Handler):
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
result = await llm_model.requester.invoke_llm(
result = await llm_model.provider.requester.invoke_llm(
query=None,
model=llm_model,
messages=messages_obj,

View File

@@ -16,6 +16,9 @@ class ModelManager:
ap: app.Application
provider_dict: dict[str, requester.RuntimeProvider]
"""运行时模型提供商字典, uuid -> RuntimeProvider"""
llm_models: list[requester.RuntimeLLMModel]
embedding_models: list[requester.RuntimeEmbeddingModel]
@@ -51,23 +54,31 @@ class ModelManager:
self.embedding_models = []
# Load all providers first
self.provider_dict = {}
providers_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider)
)
providers = {p.uuid: p for p in providers_result.all()}
for provider in providers_result.all():
try:
runtime_provider = await self.load_provider(provider)
self.provider_dict[provider.uuid] = runtime_provider
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping provider {provider.uuid}')
continue
except Exception as e:
self.ap.logger.error(f'Failed to load provider {provider.uuid}: {e}\n{traceback.format_exc()}')
# Load LLM models
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
llm_models = result.all()
for llm_model in llm_models:
try:
provider = providers.get(llm_model.provider_uuid)
provider = self.provider_dict.get(llm_model.provider_uuid)
if provider is None:
self.ap.logger.warning(f'Provider {llm_model.provider_uuid} not found for model {llm_model.uuid}')
continue
await self.load_llm_model_with_provider(llm_model, provider)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
runtime_llm_model = await self.load_llm_model_with_provider(llm_model, provider)
self.llm_models.append(runtime_llm_model)
except Exception as e:
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
@@ -76,17 +87,14 @@ class ModelManager:
embedding_models = result.all()
for embedding_model in embedding_models:
try:
provider = providers.get(embedding_model.provider_uuid)
provider = self.provider_dict.get(embedding_model.provider_uuid)
if provider is None:
self.ap.logger.warning(
f'Provider {embedding_model.provider_uuid} not found for model {embedding_model.uuid}'
)
continue
await self.load_embedding_model_with_provider(embedding_model, provider)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
)
runtime_embedding_model = await self.load_embedding_model_with_provider(embedding_model, provider)
self.embedding_models.append(runtime_embedding_model)
except Exception as e:
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
@@ -149,123 +157,139 @@ class ModelManager:
preserve_uuid=True,
)
async def init_runtime_llm_model(
async def init_temporary_runtime_llm_model(
self,
model_info: dict,
):
) -> requester.RuntimeLLMModel:
"""Initialize runtime LLM model from dict (for testing)"""
provider_info = model_info.get('provider', {})
requester_name = provider_info.get('requester', '')
base_url = provider_info.get('base_url', '')
api_keys = provider_info.get('api_keys', [])
if requester_name not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(requester_name)
requester_cfg = {'base_url': base_url}
requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
# Create a temporary model entity
model_entity = persistence_model.LLMModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid='',
abilities=model_info.get('abilities', []),
extra_args=model_info.get('extra_args', {}),
)
runtime_provider = await self.load_provider(provider_info)
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_entity,
token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys),
requester=requester_inst,
model_entity=persistence_model.LLMModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid='',
abilities=model_info.get('abilities', []),
extra_args=model_info.get('extra_args', {}),
),
provider=runtime_provider,
)
return runtime_llm_model
async def init_runtime_embedding_model(
async def init_temporary_runtime_embedding_model(
self,
model_info: dict,
):
) -> requester.RuntimeEmbeddingModel:
"""Initialize runtime embedding model from dict (for testing)"""
provider_info = model_info.get('provider', {})
requester_name = provider_info.get('requester', '')
base_url = provider_info.get('base_url', '')
api_keys = provider_info.get('api_keys', [])
if requester_name not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(requester_name)
requester_cfg = {'base_url': base_url}
requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
model_entity = persistence_model.EmbeddingModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid='',
extra_args=model_info.get('extra_args', {}),
)
runtime_provider = await self.load_provider(provider_info)
runtime_embedding_model = requester.RuntimeEmbeddingModel(
model_entity=model_entity,
token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys),
requester=requester_inst,
model_entity=persistence_model.EmbeddingModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid='',
extra_args=model_info.get('extra_args', {}),
),
provider=runtime_provider,
)
return runtime_embedding_model
async def load_provider(
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
) -> requester.RuntimeProvider:
"""Load provider from dict"""
if isinstance(provider_info, sqlalchemy.Row):
provider_entity = persistence_model.ModelProvider(**provider_info._mapping)
elif isinstance(provider_info, dict):
provider_entity = persistence_model.ModelProvider(**provider_info)
else:
provider_entity = provider_info
if provider_entity.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
requester_inst = self.requester_dict[provider_entity.requester](
ap=self.ap, config={'base_url': provider_entity.base_url}
)
await requester_inst.initialize()
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
provider = requester.RuntimeProvider(
provider_entity=provider_entity,
token_mgr=token_mgr,
requester=requester_inst,
)
return provider
async def remove_provider(self, provider_uuid: str):
"""Remove provider
This method will not consider the models using this provider,
because the models should be removed by the caller.
"""
del self.provider_dict[provider_uuid]
async def reload_provider(self, provider_uuid: str):
"""Reload provider"""
provider_entity = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == provider_uuid
)
)
provider_entity = provider_entity.first()
if provider_entity is None:
raise provider_errors.ProviderNotFoundError(provider_uuid)
new_runtime_provider = await self.load_provider(provider_entity)
# update refs in runtime models
for model in self.llm_models:
if model.provider.provider_entity.uuid == provider_uuid:
model.provider = new_runtime_provider
for model in self.embedding_models:
if model.provider.provider_entity.uuid == provider_uuid:
model.provider = new_runtime_provider
# update ref in provider dict
self.provider_dict[provider_uuid] = new_runtime_provider
async def load_llm_model_with_provider(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row,
provider: persistence_model.ModelProvider | sqlalchemy.Row,
):
provider: requester.RuntimeProvider,
) -> requester.RuntimeLLMModel:
"""Load LLM model with provider info"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
if isinstance(provider, sqlalchemy.Row):
provider = persistence_model.ModelProvider(**provider._mapping)
if provider.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(provider.requester)
requester_cfg = {'base_url': provider.base_url}
requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []),
requester=requester_inst,
provider=provider,
)
self.llm_models.append(runtime_llm_model)
return runtime_llm_model
async def load_embedding_model_with_provider(
self,
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row,
provider: persistence_model.ModelProvider | sqlalchemy.Row,
):
provider: requester.RuntimeProvider,
) -> requester.RuntimeEmbeddingModel:
"""Load embedding model with provider info"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
if isinstance(provider, sqlalchemy.Row):
provider = persistence_model.ModelProvider(**provider._mapping)
if provider.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(provider.requester)
requester_cfg = {'base_url': provider.base_url}
requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
runtime_embedding_model = requester.RuntimeEmbeddingModel(
model_entity=model_info,
token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []),
requester=requester_inst,
provider=provider,
)
self.embedding_models.append(runtime_embedding_model)
return runtime_embedding_model
async def load_llm_model(self, model_info: dict):
"""Load LLM model from dict (with provider info)"""

View File

@@ -11,11 +11,11 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class RuntimeLLMModel:
"""运行时模型"""
class RuntimeProvider:
"""运行时模型提供商"""
model_entity: persistence_model.LLMModel
"""模型数据"""
provider_entity: persistence_model.ModelProvider
"""提供商数据"""
token_mgr: token.TokenManager
"""api key管理器"""
@@ -25,36 +25,49 @@ class RuntimeLLMModel:
def __init__(
self,
model_entity: persistence_model.LLMModel,
provider_entity: persistence_model.ModelProvider,
token_mgr: token.TokenManager,
requester: ProviderAPIRequester,
):
self.model_entity = model_entity
self.provider_entity = provider_entity
self.token_mgr = token_mgr
self.requester = requester
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
provider: RuntimeProvider
"""提供商实例"""
def __init__(
self,
model_entity: persistence_model.LLMModel,
provider: RuntimeProvider,
):
self.model_entity = model_entity
self.provider = provider
class RuntimeEmbeddingModel:
"""运行时 Embedding 模型"""
model_entity: persistence_model.EmbeddingModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: ProviderAPIRequester
"""请求器实例"""
provider: RuntimeProvider
"""提供商实例"""
def __init__(
self,
model_entity: persistence_model.EmbeddingModel,
token_mgr: token.TokenManager,
requester: ProviderAPIRequester,
provider: RuntimeProvider,
):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
self.provider = provider
class ProviderAPIRequester(metaclass=abc.ABCMeta):

View File

@@ -56,7 +56,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message:
self.client.api_key = model.token_mgr.get_token()
self.client.api_key = model.provider.token_mgr.get_token()
args = extra_args.copy()
args['model'] = model.model_entity.name
@@ -190,7 +190,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message:
self.client.api_key = model.token_mgr.get_token()
self.client.api_key = model.provider.token_mgr.get_token()
args = extra_args.copy()
args['model'] = model.model_entity.name

View File

@@ -30,7 +30,7 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name
@@ -117,7 +117,7 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
if is_use_dashscope_call:
response = dashscope.MultiModalConversation.call(
# 若没有配置环境变量请用百炼API Key将下行替换为api_key = "sk-xxx"
api_key=use_model.token_mgr.get_token(),
api_key=use_model.provider.token_mgr.get_token(),
model=use_model.model_entity.name,
messages=messages,
result_format='message',

View File

@@ -130,7 +130,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.MessageChunk:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name
@@ -251,7 +251,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name
@@ -337,7 +337,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
) -> list[list[float]]:
"""调用 Embedding API"""
self.client.api_key = model.token_mgr.get_token()
self.client.api_key = model.provider.token_mgr.get_token()
args = {
'model': model.model_entity.name,

View File

@@ -26,7 +26,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name

View File

@@ -29,7 +29,7 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.MessageChunk:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name

View File

@@ -109,7 +109,7 @@ class JieKouAIChatCompletions(chatcmpl.OpenAIChatCompletions):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name

View File

@@ -131,7 +131,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name
@@ -181,7 +181,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name

View File

@@ -27,7 +27,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name

View File

@@ -109,7 +109,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
extra_args: dict[str, typing.Any] = {},
remove_think: bool = False,
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token()
self.client.api_key = use_model.provider.token_mgr.get_token()
args = {}
args['model'] = use_model.model_entity.name

View File

@@ -130,7 +130,7 @@ class LocalAgentRunner(runner.RequestRunner):
if not is_stream:
# 非流式输出,直接请求
msg = await use_llm_model.requester.invoke_llm(
msg = await use_llm_model.provider.requester.invoke_llm(
query,
use_llm_model,
req_messages,
@@ -147,7 +147,7 @@ class LocalAgentRunner(runner.RequestRunner):
accumulated_content = '' # 从开始累积的所有内容
last_role = 'assistant'
msg_sequence = 1
async for msg in use_llm_model.requester.invoke_llm_stream(
async for msg in use_llm_model.provider.requester.invoke_llm_stream(
query,
use_llm_model,
req_messages,
@@ -250,7 +250,7 @@ class LocalAgentRunner(runner.RequestRunner):
last_role = 'assistant'
msg_sequence = first_end_sequence
async for msg in use_llm_model.requester.invoke_llm_stream(
async for msg in use_llm_model.provider.requester.invoke_llm_stream(
query,
use_llm_model,
req_messages,
@@ -306,7 +306,7 @@ class LocalAgentRunner(runner.RequestRunner):
)
else:
# 处理完所有调用,再次请求
msg = await use_llm_model.requester.invoke_llm(
msg = await use_llm_model.provider.requester.invoke_llm(
query,
use_llm_model,
req_messages,

View File

@@ -33,7 +33,7 @@ class Embedder(BaseService):
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
# get embeddings
embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
embeddings_list: list[list[float]] = await embedding_model.provider.requester.invoke_embedding(
model=embedding_model,
input_text=chunks,
extra_args={}, # TODO: add extra args

View File

@@ -19,7 +19,7 @@ class Retriever(base_service.BaseService):
f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
)
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
query_embedding: list[float] = await embedding_model.provider.requester.invoke_embedding(
model=embedding_model,
input_text=[query],
extra_args={}, # TODO: add extra args