Compare commits
9 Commits
copilot/bu
...
feat/litel
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
527c8dc76a | ||
|
|
8dd16aac51 | ||
|
|
d170bdd343 | ||
|
|
b33d05f99a | ||
|
|
de61b5d368 | ||
|
|
58c1916712 | ||
|
|
a8fba46040 | ||
|
|
3115d6f6dd | ||
|
|
323481d69b |
@@ -77,6 +77,7 @@ dependencies = [
|
|||||||
"pymilvus>=2.6.4",
|
"pymilvus>=2.6.4",
|
||||||
"pgvector>=0.4.1",
|
"pgvector>=0.4.1",
|
||||||
"botocore>=1.42.39",
|
"botocore>=1.42.39",
|
||||||
|
"litellm>=1.0.0",
|
||||||
]
|
]
|
||||||
keywords = [
|
keywords = [
|
||||||
"bot",
|
"bot",
|
||||||
|
|||||||
@@ -97,3 +97,51 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
|
|||||||
await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)
|
await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)
|
||||||
|
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
@group.group_class('models/rerank', '/api/v1/provider/models/rerank')
|
||||||
|
class RerankModelsRouterGroup(group.RouterGroup):
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||||
|
async def _() -> str:
|
||||||
|
if quart.request.method == 'GET':
|
||||||
|
provider_uuid = quart.request.args.get('provider_uuid')
|
||||||
|
if provider_uuid:
|
||||||
|
return self.success(
|
||||||
|
data={
|
||||||
|
'models': await self.ap.rerank_models_service.get_rerank_models_by_provider(provider_uuid)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self.success(data={'models': await self.ap.rerank_models_service.get_rerank_models()})
|
||||||
|
elif quart.request.method == 'POST':
|
||||||
|
json_data = await quart.request.json
|
||||||
|
model_uuid = await self.ap.rerank_models_service.create_rerank_model(json_data)
|
||||||
|
return self.success(data={'uuid': model_uuid})
|
||||||
|
|
||||||
|
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||||
|
async def _(model_uuid: str) -> str:
|
||||||
|
if quart.request.method == 'GET':
|
||||||
|
model = await self.ap.rerank_models_service.get_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return self.http_status(404, -1, 'model not found')
|
||||||
|
|
||||||
|
return self.success(data={'model': model})
|
||||||
|
elif quart.request.method == 'PUT':
|
||||||
|
json_data = await quart.request.json
|
||||||
|
|
||||||
|
await self.ap.rerank_models_service.update_rerank_model(model_uuid, json_data)
|
||||||
|
|
||||||
|
return self.success()
|
||||||
|
elif quart.request.method == 'DELETE':
|
||||||
|
await self.ap.rerank_models_service.delete_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
return self.success()
|
||||||
|
|
||||||
|
@self.route('/<model_uuid>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||||
|
async def _(model_uuid: str) -> str:
|
||||||
|
json_data = await quart.request.json
|
||||||
|
|
||||||
|
await self.ap.rerank_models_service.test_rerank_model(model_uuid, json_data)
|
||||||
|
|
||||||
|
return self.success()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
|||||||
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
||||||
provider['llm_count'] = counts['llm_count']
|
provider['llm_count'] = counts['llm_count']
|
||||||
provider['embedding_count'] = counts['embedding_count']
|
provider['embedding_count'] = counts['embedding_count']
|
||||||
|
provider['rerank_count'] = counts['rerank_count']
|
||||||
return self.success(data={'providers': providers})
|
return self.success(data={'providers': providers})
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
@@ -32,6 +33,7 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
|||||||
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
||||||
provider['llm_count'] = counts['llm_count']
|
provider['llm_count'] = counts['llm_count']
|
||||||
provider['embedding_count'] = counts['embedding_count']
|
provider['embedding_count'] = counts['embedding_count']
|
||||||
|
provider['rerank_count'] = counts['rerank_count']
|
||||||
return self.success(data={'provider': provider})
|
return self.success(data={'provider': provider})
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|||||||
@@ -367,3 +367,162 @@ class EmbeddingModelsService:
|
|||||||
input_text=['Hello, world!'],
|
input_text=['Hello, world!'],
|
||||||
extra_args={},
|
extra_args={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RerankModelsService:
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application) -> None:
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def get_rerank_models(self) -> list[dict]:
|
||||||
|
"""Get all rerank models with provider info"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||||
|
models = result.all()
|
||||||
|
|
||||||
|
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.ModelProvider)
|
||||||
|
)
|
||||||
|
providers = {p.uuid: p for p in providers_result.all()}
|
||||||
|
|
||||||
|
models_list = []
|
||||||
|
for model in models:
|
||||||
|
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||||
|
provider = providers.get(model.provider_uuid)
|
||||||
|
if provider:
|
||||||
|
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||||
|
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||||
|
models_list.append(model_dict)
|
||||||
|
|
||||||
|
return models_list
|
||||||
|
|
||||||
|
async def get_rerank_models_by_provider(self, provider_uuid: str) -> list[dict]:
|
||||||
|
"""Get rerank models by provider UUID"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||||
|
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
models = result.all()
|
||||||
|
return [self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, m) for m in models]
|
||||||
|
|
||||||
|
async def create_rerank_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
|
||||||
|
"""Create a new rerank model"""
|
||||||
|
if not preserve_uuid:
|
||||||
|
model_data['uuid'] = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if 'provider' in model_data:
|
||||||
|
provider_data = model_data.pop('provider')
|
||||||
|
if provider_data.get('uuid'):
|
||||||
|
model_data['provider_uuid'] = provider_data['uuid']
|
||||||
|
else:
|
||||||
|
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||||
|
requester=provider_data.get('requester', ''),
|
||||||
|
base_url=provider_data.get('base_url', ''),
|
||||||
|
api_keys=provider_data.get('api_keys', []),
|
||||||
|
)
|
||||||
|
model_data['provider_uuid'] = provider_uuid
|
||||||
|
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.insert(persistence_model.RerankModel).values(**model_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||||
|
if runtime_provider is None:
|
||||||
|
raise Exception('provider not found')
|
||||||
|
|
||||||
|
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||||
|
persistence_model.RerankModel(**model_data),
|
||||||
|
runtime_provider,
|
||||||
|
)
|
||||||
|
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||||
|
|
||||||
|
return model_data['uuid']
|
||||||
|
|
||||||
|
async def get_rerank_model(self, model_uuid: str) -> dict | None:
|
||||||
|
"""Get a single rerank model with provider info"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||||
|
)
|
||||||
|
model = result.first()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||||
|
|
||||||
|
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||||
|
persistence_model.ModelProvider.uuid == model.provider_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider = provider_result.first()
|
||||||
|
if provider:
|
||||||
|
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||||
|
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||||
|
|
||||||
|
return model_dict
|
||||||
|
|
||||||
|
async def update_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||||
|
"""Update an existing rerank model"""
|
||||||
|
if 'uuid' in model_data:
|
||||||
|
del model_data['uuid']
|
||||||
|
|
||||||
|
if 'provider' in model_data:
|
||||||
|
provider_data = model_data.pop('provider')
|
||||||
|
if provider_data.get('uuid'):
|
||||||
|
model_data['provider_uuid'] = provider_data['uuid']
|
||||||
|
else:
|
||||||
|
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||||
|
requester=provider_data.get('requester', ''),
|
||||||
|
base_url=provider_data.get('base_url', ''),
|
||||||
|
api_keys=provider_data.get('api_keys', []),
|
||||||
|
)
|
||||||
|
model_data['provider_uuid'] = provider_uuid
|
||||||
|
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.update(persistence_model.RerankModel)
|
||||||
|
.where(persistence_model.RerankModel.uuid == model_uuid)
|
||||||
|
.values(**model_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||||
|
if runtime_provider is None:
|
||||||
|
raise Exception('provider not found')
|
||||||
|
|
||||||
|
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||||
|
persistence_model.RerankModel(**model_data),
|
||||||
|
runtime_provider,
|
||||||
|
)
|
||||||
|
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||||
|
|
||||||
|
async def delete_rerank_model(self, model_uuid: str) -> None:
|
||||||
|
"""Delete a rerank model"""
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.delete(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||||
|
)
|
||||||
|
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
async def test_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||||
|
"""Test a rerank model"""
|
||||||
|
runtime_rerank_model: model_requester.RuntimeRerankModel | None = None
|
||||||
|
|
||||||
|
if model_uuid != '_':
|
||||||
|
for model in self.ap.model_mgr.rerank_models:
|
||||||
|
if model.model_entity.uuid == model_uuid:
|
||||||
|
runtime_rerank_model = model
|
||||||
|
break
|
||||||
|
if runtime_rerank_model is None:
|
||||||
|
raise Exception('model not found')
|
||||||
|
else:
|
||||||
|
runtime_rerank_model = await self.ap.model_mgr.init_temporary_runtime_rerank_model(model_data)
|
||||||
|
|
||||||
|
await runtime_rerank_model.provider.invoke_rerank(
|
||||||
|
model=runtime_rerank_model,
|
||||||
|
query='What is artificial intelligence?',
|
||||||
|
documents=[
|
||||||
|
'Artificial intelligence is a branch of computer science.',
|
||||||
|
'The weather is nice today.',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -98,6 +98,14 @@ class ModelProviderService:
|
|||||||
if embedding_result.first() is not None:
|
if embedding_result.first() is not None:
|
||||||
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
||||||
|
|
||||||
|
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||||
|
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rerank_result.first() is not None:
|
||||||
|
raise ValueError('Cannot delete provider: Rerank models still reference it')
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
||||||
persistence_model.ModelProvider.uuid == provider_uuid
|
persistence_model.ModelProvider.uuid == provider_uuid
|
||||||
@@ -122,7 +130,14 @@ class ModelProviderService:
|
|||||||
)
|
)
|
||||||
embedding_count = embedding_result.scalar() or 0
|
embedding_count = embedding_result.scalar() or 0
|
||||||
|
|
||||||
return {'llm_count': llm_count, 'embedding_count': embedding_count}
|
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(sqlalchemy.func.count())
|
||||||
|
.select_from(persistence_model.RerankModel)
|
||||||
|
.where(persistence_model.RerankModel.provider_uuid == provider_uuid)
|
||||||
|
)
|
||||||
|
rerank_count = rerank_result.scalar() or 0
|
||||||
|
|
||||||
|
return {'llm_count': llm_count, 'embedding_count': embedding_count, 'rerank_count': rerank_count}
|
||||||
|
|
||||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||||
"""Find existing provider or create new one"""
|
"""Find existing provider or create new one"""
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ class SpaceService:
|
|||||||
space_url = space_config['url']
|
space_url = space_config['url']
|
||||||
|
|
||||||
session = httpclient.get_session()
|
session = httpclient.get_session()
|
||||||
async with session.get(f'{space_url}/api/v1/models') as response:
|
async with session.get(f'{space_url}/api/v1/models', params={'page_size': 100}) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise ValueError(f'Failed to get models: {await response.text()}')
|
raise ValueError(f'Failed to get models: {await response.text()}')
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
|||||||
@@ -133,6 +133,8 @@ class Application:
|
|||||||
|
|
||||||
embedding_models_service: model_service.EmbeddingModelsService = None
|
embedding_models_service: model_service.EmbeddingModelsService = None
|
||||||
|
|
||||||
|
rerank_models_service: model_service.RerankModelsService = None
|
||||||
|
|
||||||
provider_service: provider_service.ModelProviderService = None
|
provider_service: provider_service.ModelProviderService = None
|
||||||
|
|
||||||
pipeline_service: pipeline_service.PipelineService = None
|
pipeline_service: pipeline_service.PipelineService = None
|
||||||
|
|||||||
@@ -61,6 +61,9 @@ class BuildAppStage(stage.BootingStage):
|
|||||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||||
ap.embedding_models_service = embedding_models_service_inst
|
ap.embedding_models_service = embedding_models_service_inst
|
||||||
|
|
||||||
|
rerank_models_service_inst = model_service.RerankModelsService(ap)
|
||||||
|
ap.rerank_models_service = rerank_models_service_inst
|
||||||
|
|
||||||
provider_service_inst = provider_service.ModelProviderService(ap)
|
provider_service_inst = provider_service.ModelProviderService(ap)
|
||||||
ap.provider_service = provider_service_inst
|
ap.provider_service = provider_service_inst
|
||||||
|
|
||||||
|
|||||||
@@ -59,3 +59,22 @@ class EmbeddingModel(Base):
|
|||||||
server_default=sqlalchemy.func.now(),
|
server_default=sqlalchemy.func.now(),
|
||||||
onupdate=sqlalchemy.func.now(),
|
onupdate=sqlalchemy.func.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RerankModel(Base):
|
||||||
|
"""Rerank model"""
|
||||||
|
|
||||||
|
__tablename__ = 'rerank_models'
|
||||||
|
|
||||||
|
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||||
|
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||||
|
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||||
|
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||||
|
updated_at = sqlalchemy.Column(
|
||||||
|
sqlalchemy.DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sqlalchemy.func.now(),
|
||||||
|
onupdate=sqlalchemy.func.now(),
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""add rerank_models table
|
||||||
|
|
||||||
|
Revision ID: 0003_add_rerank_models
|
||||||
|
Revises: 0002_sample
|
||||||
|
Create Date: 2026-04-19
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = '0003_add_rerank_models'
|
||||||
|
down_revision = '0002_sample'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Check if table already exists (may have been created by create_all())
|
||||||
|
conn = op.get_bind()
|
||||||
|
inspector = sa.inspect(conn)
|
||||||
|
if 'rerank_models' not in inspector.get_table_names():
|
||||||
|
op.create_table(
|
||||||
|
'rerank_models',
|
||||||
|
sa.Column('uuid', sa.String(255), primary_key=True, unique=True),
|
||||||
|
sa.Column('name', sa.String(255), nullable=False),
|
||||||
|
sa.Column('provider_uuid', sa.String(255), nullable=False),
|
||||||
|
sa.Column('extra_args', sa.JSON, nullable=False, server_default='{}'),
|
||||||
|
sa.Column('prefered_ranking', sa.Integer, nullable=False, server_default='0'),
|
||||||
|
sa.Column('created_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column('updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table('rerank_models')
|
||||||
@@ -4,12 +4,12 @@ import sqlalchemy
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from . import requester
|
from . import requester
|
||||||
|
from .requesters import litellmchat
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...discover import engine
|
from ...discover import engine
|
||||||
from . import token
|
from . import token
|
||||||
from ...entity.persistence import model as persistence_model
|
from ...entity.persistence import model as persistence_model
|
||||||
from ...entity.errors import provider as provider_errors
|
from ...entity.errors import provider as provider_errors
|
||||||
from async_lru import alru_cache
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
@@ -24,6 +24,8 @@ class ModelManager:
|
|||||||
|
|
||||||
embedding_models: list[requester.RuntimeEmbeddingModel]
|
embedding_models: list[requester.RuntimeEmbeddingModel]
|
||||||
|
|
||||||
|
rerank_models: list[requester.RuntimeRerankModel]
|
||||||
|
|
||||||
requester_components: list[engine.Component]
|
requester_components: list[engine.Component]
|
||||||
|
|
||||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
||||||
@@ -32,6 +34,7 @@ class ModelManager:
|
|||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.llm_models = []
|
self.llm_models = []
|
||||||
self.embedding_models = []
|
self.embedding_models = []
|
||||||
|
self.rerank_models = []
|
||||||
self.requester_components = []
|
self.requester_components = []
|
||||||
self.requester_dict = {}
|
self.requester_dict = {}
|
||||||
|
|
||||||
@@ -40,6 +43,13 @@ class ModelManager:
|
|||||||
|
|
||||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
||||||
for component in self.requester_components:
|
for component in self.requester_components:
|
||||||
|
# Skip components that use litellm_provider (they will use litellmchat.py instead)
|
||||||
|
if component.spec.get('litellm_provider'):
|
||||||
|
self.ap.logger.debug(
|
||||||
|
f'Skipping Python class loading for {component.metadata.name} '
|
||||||
|
f'(uses litellm_provider={component.spec.get("litellm_provider")})'
|
||||||
|
)
|
||||||
|
continue
|
||||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||||
|
|
||||||
self.requester_dict = requester_dict
|
self.requester_dict = requester_dict
|
||||||
@@ -64,8 +74,7 @@ class ModelManager:
|
|||||||
|
|
||||||
self.llm_models = []
|
self.llm_models = []
|
||||||
self.embedding_models = []
|
self.embedding_models = []
|
||||||
|
self.rerank_models = []
|
||||||
# Load all providers first
|
|
||||||
self.provider_dict = {}
|
self.provider_dict = {}
|
||||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_model.ModelProvider)
|
sqlalchemy.select(persistence_model.ModelProvider)
|
||||||
@@ -110,6 +119,22 @@ class ModelManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||||
|
|
||||||
|
# Load rerank models
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||||
|
rerank_models = result.all()
|
||||||
|
for rerank_model in rerank_models:
|
||||||
|
try:
|
||||||
|
provider = self.provider_dict.get(rerank_model.provider_uuid)
|
||||||
|
if provider is None:
|
||||||
|
self.ap.logger.warning(
|
||||||
|
f'Provider {rerank_model.provider_uuid} not found for model {rerank_model.uuid}'
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
runtime_rerank_model = await self.load_rerank_model_with_provider(rerank_model, provider)
|
||||||
|
self.rerank_models.append(runtime_rerank_model)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'Failed to load model {rerank_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||||
|
|
||||||
async def sync_new_models_from_space(self):
|
async def sync_new_models_from_space(self):
|
||||||
"""Sync models from Space"""
|
"""Sync models from Space"""
|
||||||
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
||||||
@@ -212,6 +237,26 @@ class ModelManager:
|
|||||||
|
|
||||||
return runtime_embedding_model
|
return runtime_embedding_model
|
||||||
|
|
||||||
|
async def init_temporary_runtime_rerank_model(
|
||||||
|
self,
|
||||||
|
model_info: dict,
|
||||||
|
) -> requester.RuntimeRerankModel:
|
||||||
|
"""Initialize runtime rerank model from dict (for testing)"""
|
||||||
|
provider_info = model_info.get('provider', {})
|
||||||
|
runtime_provider = await self.load_provider(provider_info)
|
||||||
|
|
||||||
|
runtime_rerank_model = requester.RuntimeRerankModel(
|
||||||
|
model_entity=persistence_model.RerankModel(
|
||||||
|
uuid=model_info.get('uuid', ''),
|
||||||
|
name=model_info.get('name', ''),
|
||||||
|
provider_uuid='',
|
||||||
|
extra_args=model_info.get('extra_args', {}),
|
||||||
|
),
|
||||||
|
provider=runtime_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return runtime_rerank_model
|
||||||
|
|
||||||
async def load_provider(
|
async def load_provider(
|
||||||
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
||||||
) -> requester.RuntimeProvider:
|
) -> requester.RuntimeProvider:
|
||||||
@@ -223,13 +268,34 @@ class ModelManager:
|
|||||||
else:
|
else:
|
||||||
provider_entity = provider_info
|
provider_entity = provider_info
|
||||||
|
|
||||||
if provider_entity.requester not in self.requester_dict:
|
# Get requester manifest to check for litellm_provider
|
||||||
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
requester_manifest = self.get_available_requester_manifest_by_name(provider_entity.requester)
|
||||||
|
|
||||||
|
# Build config from base_url
|
||||||
|
config = {'base_url': provider_entity.base_url}
|
||||||
|
|
||||||
|
# Check if requester manifest specifies litellm_provider
|
||||||
|
if requester_manifest and requester_manifest.spec.get('litellm_provider'):
|
||||||
|
# Use unified LiteLLMRequester with provider prefix
|
||||||
|
# Map litellm_provider (YAML spec) to custom_llm_provider (config)
|
||||||
|
config['custom_llm_provider'] = requester_manifest.spec['litellm_provider']
|
||||||
|
requester_inst = litellmchat.LiteLLMRequester(
|
||||||
|
ap=self.ap,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
self.ap.logger.debug(
|
||||||
|
f'Using LiteLLMRequester for {provider_entity.requester} '
|
||||||
|
f'with custom_llm_provider={config["custom_llm_provider"]}'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use original requester class (for backward compatibility)
|
||||||
|
if provider_entity.requester not in self.requester_dict:
|
||||||
|
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
||||||
|
requester_inst = self.requester_dict[provider_entity.requester](
|
||||||
|
ap=self.ap,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
requester_inst = self.requester_dict[provider_entity.requester](
|
|
||||||
ap=self.ap,
|
|
||||||
config={'base_url': provider_entity.base_url},
|
|
||||||
)
|
|
||||||
await requester_inst.initialize()
|
await requester_inst.initialize()
|
||||||
|
|
||||||
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
|
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
|
||||||
@@ -269,6 +335,9 @@ class ModelManager:
|
|||||||
for model in self.embedding_models:
|
for model in self.embedding_models:
|
||||||
if model.provider.provider_entity.uuid == provider_uuid:
|
if model.provider.provider_entity.uuid == provider_uuid:
|
||||||
model.provider = new_runtime_provider
|
model.provider = new_runtime_provider
|
||||||
|
for model in self.rerank_models:
|
||||||
|
if model.provider.provider_entity.uuid == provider_uuid:
|
||||||
|
model.provider = new_runtime_provider
|
||||||
|
|
||||||
# update ref in provider dict
|
# update ref in provider dict
|
||||||
self.provider_dict[provider_uuid] = new_runtime_provider
|
self.provider_dict[provider_uuid] = new_runtime_provider
|
||||||
@@ -305,6 +374,22 @@ class ModelManager:
|
|||||||
|
|
||||||
return runtime_embedding_model
|
return runtime_embedding_model
|
||||||
|
|
||||||
|
async def load_rerank_model_with_provider(
|
||||||
|
self,
|
||||||
|
model_info: persistence_model.RerankModel | sqlalchemy.Row,
|
||||||
|
provider: requester.RuntimeProvider,
|
||||||
|
) -> requester.RuntimeRerankModel:
|
||||||
|
"""Load rerank model with provider info"""
|
||||||
|
if isinstance(model_info, sqlalchemy.Row):
|
||||||
|
model_info = persistence_model.RerankModel(**model_info._mapping)
|
||||||
|
|
||||||
|
runtime_rerank_model = requester.RuntimeRerankModel(
|
||||||
|
model_entity=model_info,
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return runtime_rerank_model
|
||||||
|
|
||||||
async def load_llm_model(self, model_info: dict):
|
async def load_llm_model(self, model_info: dict):
|
||||||
"""Load LLM model from dict (with provider info)"""
|
"""Load LLM model from dict (with provider info)"""
|
||||||
provider_info = model_info.get('provider', {})
|
provider_info = model_info.get('provider', {})
|
||||||
@@ -352,7 +437,6 @@ class ModelManager:
|
|||||||
|
|
||||||
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
||||||
|
|
||||||
@alru_cache(ttl=60 * 5)
|
|
||||||
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
||||||
"""Get LLM model by uuid"""
|
"""Get LLM model by uuid"""
|
||||||
for model in self.llm_models:
|
for model in self.llm_models:
|
||||||
@@ -360,7 +444,6 @@ class ModelManager:
|
|||||||
return model
|
return model
|
||||||
raise ValueError(f'LLM model {uuid} not found')
|
raise ValueError(f'LLM model {uuid} not found')
|
||||||
|
|
||||||
@alru_cache(ttl=60 * 5)
|
|
||||||
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
||||||
"""Get embedding model by uuid"""
|
"""Get embedding model by uuid"""
|
||||||
for model in self.embedding_models:
|
for model in self.embedding_models:
|
||||||
@@ -368,6 +451,13 @@ class ModelManager:
|
|||||||
return model
|
return model
|
||||||
raise ValueError(f'Embedding model {uuid} not found')
|
raise ValueError(f'Embedding model {uuid} not found')
|
||||||
|
|
||||||
|
async def get_rerank_model_by_uuid(self, uuid: str) -> requester.RuntimeRerankModel:
|
||||||
|
"""Get rerank model by uuid"""
|
||||||
|
for model in self.rerank_models:
|
||||||
|
if model.model_entity.uuid == uuid:
|
||||||
|
return model
|
||||||
|
raise ValueError(f'Rerank model {uuid} not found')
|
||||||
|
|
||||||
async def remove_llm_model(self, model_uuid: str):
|
async def remove_llm_model(self, model_uuid: str):
|
||||||
"""Remove LLM model"""
|
"""Remove LLM model"""
|
||||||
for model in self.llm_models:
|
for model in self.llm_models:
|
||||||
@@ -382,6 +472,13 @@ class ModelManager:
|
|||||||
self.embedding_models.remove(model)
|
self.embedding_models.remove(model)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def remove_rerank_model(self, model_uuid: str):
|
||||||
|
"""Remove rerank model"""
|
||||||
|
for model in self.rerank_models:
|
||||||
|
if model.model_entity.uuid == model_uuid:
|
||||||
|
self.rerank_models.remove(model)
|
||||||
|
return
|
||||||
|
|
||||||
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
||||||
"""Get all available requesters"""
|
"""Get all available requesters"""
|
||||||
if model_type != '':
|
if model_type != '':
|
||||||
|
|||||||
@@ -67,8 +67,8 @@ class RuntimeProvider:
|
|||||||
if isinstance(result, tuple):
|
if isinstance(result, tuple):
|
||||||
msg, usage_info = result
|
msg, usage_info = result
|
||||||
if usage_info:
|
if usage_info:
|
||||||
input_tokens = usage_info.get('input_tokens', 0)
|
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||||
output_tokens = usage_info.get('output_tokens', 0)
|
output_tokens = usage_info.get('completion_tokens', 0)
|
||||||
return msg
|
return msg
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
@@ -128,7 +128,6 @@ class RuntimeProvider:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
status = 'success'
|
status = 'success'
|
||||||
error_message = None
|
error_message = None
|
||||||
# Note: Stream doesn't easily provide token counts, set to 0
|
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
|
|
||||||
@@ -143,6 +142,15 @@ class RuntimeProvider:
|
|||||||
remove_think=remove_think,
|
remove_think=remove_think,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
# Extract usage from stream if available (stored by LiteLLM requester)
|
||||||
|
if query:
|
||||||
|
if query.variables is None:
|
||||||
|
query.variables = {}
|
||||||
|
if '_stream_usage' in query.variables:
|
||||||
|
usage_info = query.variables['_stream_usage']
|
||||||
|
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||||
|
output_tokens = usage_info.get('completion_tokens', 0)
|
||||||
|
del query.variables['_stream_usage']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status = 'error'
|
status = 'error'
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
@@ -247,6 +255,40 @@ class RuntimeProvider:
|
|||||||
except Exception as monitor_err:
|
except Exception as monitor_err:
|
||||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}')
|
self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}')
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""Bridge method for invoking rerank with monitoring"""
|
||||||
|
start_time = time.time()
|
||||||
|
status = 'success'
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
extra_args=extra_args,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
status = 'error'
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.requester.ap.logger.debug(
|
||||||
|
f'[Rerank] model={model.model_entity.name} docs={len(documents)} '
|
||||||
|
f'duration={duration_ms}ms status={status}'
|
||||||
|
)
|
||||||
|
except Exception as monitor_err:
|
||||||
|
self.requester.ap.logger.error(f'[Monitoring] Failed to record rerank call: {monitor_err}')
|
||||||
|
|
||||||
|
|
||||||
class RuntimeLLMModel:
|
class RuntimeLLMModel:
|
||||||
"""运行时模型"""
|
"""运行时模型"""
|
||||||
@@ -284,6 +326,24 @@ class RuntimeEmbeddingModel:
|
|||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeRerankModel:
|
||||||
|
"""运行时 Rerank 模型"""
|
||||||
|
|
||||||
|
model_entity: persistence_model.RerankModel
|
||||||
|
"""模型数据"""
|
||||||
|
|
||||||
|
provider: RuntimeProvider
|
||||||
|
"""提供商实例"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_entity: persistence_model.RerankModel,
|
||||||
|
provider: RuntimeProvider,
|
||||||
|
):
|
||||||
|
self.model_entity = model_entity
|
||||||
|
self.provider = provider
|
||||||
|
|
||||||
|
|
||||||
class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||||
"""Provider API请求器"""
|
"""Provider API请求器"""
|
||||||
|
|
||||||
@@ -376,3 +436,23 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
|||||||
或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info)
|
或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""调用 Rerank API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (RuntimeRerankModel): 使用的模型信息
|
||||||
|
query (str): 查询文本
|
||||||
|
documents (typing.List[str]): 待重排序的文档列表
|
||||||
|
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
typing.List[dict]: [{"index": int, "relevance_score": float}, ...]
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('This requester does not support rerank')
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class AI302ChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""302.AI ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.302.ai/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 302.AI
|
zh_Hans: 302.AI
|
||||||
icon: 302ai.png
|
icon: 302ai.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -25,6 +26,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,370 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import json
|
|
||||||
import platform
|
|
||||||
import socket
|
|
||||||
import anthropic
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from .. import errors, requester
|
|
||||||
|
|
||||||
from ....utils import image
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicMessages(requester.ProviderAPIRequester):
|
|
||||||
"""Anthropic Messages API 请求器"""
|
|
||||||
|
|
||||||
client: anthropic.AsyncAnthropic
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.anthropic.com',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
# 兼容 Windows 缺失 TCP_KEEPINTVL 和 TCP_KEEPCNT 的问题
|
|
||||||
if platform.system() == 'Windows':
|
|
||||||
if not hasattr(socket, 'TCP_KEEPINTVL'):
|
|
||||||
socket.TCP_KEEPINTVL = 0
|
|
||||||
if not hasattr(socket, 'TCP_KEEPCNT'):
|
|
||||||
socket.TCP_KEEPCNT = 0
|
|
||||||
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
|
|
||||||
base_url=self.requester_cfg['base_url'],
|
|
||||||
# cast to a valid type because mypy doesn't understand our type narrowing
|
|
||||||
timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']),
|
|
||||||
limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS,
|
|
||||||
follow_redirects=True,
|
|
||||||
trust_env=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.client = anthropic.AsyncAnthropic(
|
|
||||||
api_key='',
|
|
||||||
http_client=httpx_client,
|
|
||||||
base_url=self.requester_cfg['base_url'],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def invoke_llm(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
self.client.api_key = model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = extra_args.copy()
|
|
||||||
args['model'] = model.model_entity.name
|
|
||||||
|
|
||||||
# 处理消息
|
|
||||||
|
|
||||||
# system
|
|
||||||
system_role_message = None
|
|
||||||
|
|
||||||
for i, m in enumerate(messages):
|
|
||||||
if m.role == 'system':
|
|
||||||
system_role_message = m
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
if system_role_message:
|
|
||||||
messages.pop(i)
|
|
||||||
|
|
||||||
if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str):
|
|
||||||
args['system'] = system_role_message.content
|
|
||||||
|
|
||||||
req_messages = []
|
|
||||||
|
|
||||||
for m in messages:
|
|
||||||
if m.role == 'tool':
|
|
||||||
tool_call_id = m.tool_call_id
|
|
||||||
|
|
||||||
req_messages.append(
|
|
||||||
{
|
|
||||||
'role': 'user',
|
|
||||||
'content': [
|
|
||||||
{
|
|
||||||
'type': 'tool_result',
|
|
||||||
'tool_use_id': tool_call_id,
|
|
||||||
'is_error': False,
|
|
||||||
'content': [{'type': 'text', 'text': m.content}],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
msg_dict = m.dict(exclude_none=True)
|
|
||||||
|
|
||||||
if isinstance(m.content, str) and m.content.strip() != '':
|
|
||||||
msg_dict['content'] = [{'type': 'text', 'text': m.content}]
|
|
||||||
elif isinstance(m.content, list):
|
|
||||||
for i, ce in enumerate(m.content):
|
|
||||||
if ce.type == 'image_base64':
|
|
||||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
|
||||||
|
|
||||||
alter_image_ele = {
|
|
||||||
'type': 'image',
|
|
||||||
'source': {
|
|
||||||
'type': 'base64',
|
|
||||||
'media_type': f'image/{image_format}',
|
|
||||||
'data': image_b64,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
msg_dict['content'][i] = alter_image_ele
|
|
||||||
|
|
||||||
if m.tool_calls:
|
|
||||||
for tool_call in m.tool_calls:
|
|
||||||
msg_dict['content'].append(
|
|
||||||
{
|
|
||||||
'type': 'tool_use',
|
|
||||||
'id': tool_call.id,
|
|
||||||
'name': tool_call.function.name,
|
|
||||||
'input': json.loads(tool_call.function.arguments),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
del msg_dict['tool_calls']
|
|
||||||
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
|
|
||||||
args['messages'] = req_messages
|
|
||||||
|
|
||||||
if 'thinking' in args:
|
|
||||||
args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000}
|
|
||||||
|
|
||||||
if funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp = await self.client.messages.create(**args)
|
|
||||||
|
|
||||||
args = {
|
|
||||||
'content': '',
|
|
||||||
'role': resp.role,
|
|
||||||
}
|
|
||||||
assert type(resp) is anthropic.types.message.Message
|
|
||||||
|
|
||||||
for block in resp.content:
|
|
||||||
if not remove_think and block.type == 'thinking':
|
|
||||||
args['content'] = '<think>\n' + block.thinking + '\n</think>\n' + args['content']
|
|
||||||
elif block.type == 'text':
|
|
||||||
args['content'] += block.text
|
|
||||||
elif block.type == 'tool_use':
|
|
||||||
assert type(block) is anthropic.types.tool_use_block.ToolUseBlock
|
|
||||||
tool_call = provider_message.ToolCall(
|
|
||||||
id=block.id,
|
|
||||||
type='function',
|
|
||||||
function=provider_message.FunctionCall(name=block.name, arguments=json.dumps(block.input)),
|
|
||||||
)
|
|
||||||
if 'tool_calls' not in args:
|
|
||||||
args['tool_calls'] = []
|
|
||||||
args['tool_calls'].append(tool_call)
|
|
||||||
|
|
||||||
return provider_message.Message(**args)
|
|
||||||
except anthropic.AuthenticationError as e:
|
|
||||||
raise errors.RequesterError(f'api-key 无效: {e.message}')
|
|
||||||
except anthropic.BadRequestError as e:
|
|
||||||
raise errors.RequesterError(str(e.message))
|
|
||||||
except anthropic.NotFoundError as e:
|
|
||||||
if 'model: ' in str(e):
|
|
||||||
raise errors.RequesterError(f'模型无效: {e.message}')
|
|
||||||
else:
|
|
||||||
raise errors.RequesterError(f'请求地址无效: {e.message}')
|
|
||||||
|
|
||||||
async def invoke_llm_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
self.client.api_key = model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = extra_args.copy()
|
|
||||||
args['model'] = model.model_entity.name
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# 处理消息
|
|
||||||
|
|
||||||
# system
|
|
||||||
system_role_message = None
|
|
||||||
|
|
||||||
for i, m in enumerate(messages):
|
|
||||||
if m.role == 'system':
|
|
||||||
system_role_message = m
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
if system_role_message:
|
|
||||||
messages.pop(i)
|
|
||||||
|
|
||||||
if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str):
|
|
||||||
args['system'] = system_role_message.content
|
|
||||||
|
|
||||||
req_messages = []
|
|
||||||
|
|
||||||
for m in messages:
|
|
||||||
if m.role == 'tool':
|
|
||||||
tool_call_id = m.tool_call_id
|
|
||||||
|
|
||||||
req_messages.append(
|
|
||||||
{
|
|
||||||
'role': 'user',
|
|
||||||
'content': [
|
|
||||||
{
|
|
||||||
'type': 'tool_result',
|
|
||||||
'tool_use_id': tool_call_id,
|
|
||||||
'is_error': False, # 暂时直接写false
|
|
||||||
'content': [
|
|
||||||
{'type': 'text', 'text': m.content}
|
|
||||||
], # 这里要是list包裹,应该是多个返回的情况?type类型好像也可以填其他的,暂时只写text
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
msg_dict = m.dict(exclude_none=True)
|
|
||||||
|
|
||||||
if isinstance(m.content, str) and m.content.strip() != '':
|
|
||||||
msg_dict['content'] = [{'type': 'text', 'text': m.content}]
|
|
||||||
elif isinstance(m.content, list):
|
|
||||||
for i, ce in enumerate(m.content):
|
|
||||||
if ce.type == 'image_base64':
|
|
||||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
|
||||||
|
|
||||||
alter_image_ele = {
|
|
||||||
'type': 'image',
|
|
||||||
'source': {
|
|
||||||
'type': 'base64',
|
|
||||||
'media_type': f'image/{image_format}',
|
|
||||||
'data': image_b64,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
msg_dict['content'][i] = alter_image_ele
|
|
||||||
if isinstance(msg_dict['content'], str) and msg_dict['content'] == '':
|
|
||||||
msg_dict['content'] = [] # 这里不知道为什么会莫名有个空导致content为字符
|
|
||||||
if m.tool_calls:
|
|
||||||
for tool_call in m.tool_calls:
|
|
||||||
msg_dict['content'].append(
|
|
||||||
{
|
|
||||||
'type': 'tool_use',
|
|
||||||
'id': tool_call.id,
|
|
||||||
'name': tool_call.function.name,
|
|
||||||
'input': json.loads(tool_call.function.arguments),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
del msg_dict['tool_calls']
|
|
||||||
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
if 'thinking' in args:
|
|
||||||
args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000}
|
|
||||||
|
|
||||||
args['messages'] = req_messages
|
|
||||||
|
|
||||||
if funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
try:
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
# chunk_idx = 0
|
|
||||||
think_started = False
|
|
||||||
think_ended = False
|
|
||||||
finish_reason = False
|
|
||||||
tool_name = ''
|
|
||||||
tool_id = ''
|
|
||||||
async for chunk in await self.client.messages.create(**args):
|
|
||||||
content = ''
|
|
||||||
tool_call = {'id': None, 'function': {'name': None, 'arguments': None}, 'type': 'function'}
|
|
||||||
if isinstance(
|
|
||||||
chunk, anthropic.types.raw_content_block_start_event.RawContentBlockStartEvent
|
|
||||||
): # 记录开始
|
|
||||||
if chunk.content_block.type == 'tool_use':
|
|
||||||
if chunk.content_block.name is not None:
|
|
||||||
tool_name = chunk.content_block.name
|
|
||||||
if chunk.content_block.id is not None:
|
|
||||||
tool_id = chunk.content_block.id
|
|
||||||
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
tool_call['function']['arguments'] = ''
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
|
|
||||||
if not remove_think:
|
|
||||||
if chunk.content_block.type == 'thinking' and not remove_think:
|
|
||||||
think_started = True
|
|
||||||
elif chunk.content_block.type == 'text' and chunk.index != 0 and not remove_think:
|
|
||||||
think_ended = True
|
|
||||||
continue
|
|
||||||
elif isinstance(chunk, anthropic.types.raw_content_block_delta_event.RawContentBlockDeltaEvent):
|
|
||||||
if chunk.delta.type == 'thinking_delta':
|
|
||||||
if think_started:
|
|
||||||
think_started = False
|
|
||||||
content = '<think>\n' + chunk.delta.thinking
|
|
||||||
elif remove_think:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
content = chunk.delta.thinking
|
|
||||||
elif chunk.delta.type == 'text_delta':
|
|
||||||
if think_ended:
|
|
||||||
think_ended = False
|
|
||||||
content = '\n</think>\n' + chunk.delta.text
|
|
||||||
else:
|
|
||||||
content = chunk.delta.text
|
|
||||||
elif chunk.delta.type == 'input_json_delta':
|
|
||||||
tool_call['function']['arguments'] = chunk.delta.partial_json
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
elif isinstance(chunk, anthropic.types.raw_content_block_stop_event.RawContentBlockStopEvent):
|
|
||||||
continue # 记录raw_content_block结束的
|
|
||||||
|
|
||||||
elif isinstance(chunk, anthropic.types.raw_message_delta_event.RawMessageDeltaEvent):
|
|
||||||
if chunk.delta.stop_reason == 'end_turn':
|
|
||||||
finish_reason = True
|
|
||||||
elif isinstance(chunk, anthropic.types.raw_message_stop_event.RawMessageStopEvent):
|
|
||||||
continue # 这个好像是完全结束
|
|
||||||
else:
|
|
||||||
# print(chunk)
|
|
||||||
self.ap.logger.debug(f'anthropic chunk: {chunk}')
|
|
||||||
continue
|
|
||||||
|
|
||||||
args = {
|
|
||||||
'content': content,
|
|
||||||
'role': role,
|
|
||||||
'is_final': finish_reason,
|
|
||||||
'tool_calls': None if tool_call['id'] is None else [tool_call],
|
|
||||||
}
|
|
||||||
# if chunk_idx == 0:
|
|
||||||
# chunk_idx += 1
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# assert type(chunk) is anthropic.types.message.Chunk
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**args)
|
|
||||||
|
|
||||||
# return llm_entities.Message(**args)
|
|
||||||
except anthropic.AuthenticationError as e:
|
|
||||||
raise errors.RequesterError(f'api-key 无效: {e.message}')
|
|
||||||
except anthropic.BadRequestError as e:
|
|
||||||
raise errors.RequesterError(str(e.message))
|
|
||||||
except anthropic.NotFoundError as e:
|
|
||||||
if 'model: ' in str(e):
|
|
||||||
raise errors.RequesterError(f'模型无效: {e.message}')
|
|
||||||
else:
|
|
||||||
raise errors.RequesterError(f'请求地址无效: {e.message}')
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: Anthropic
|
zh_Hans: Anthropic
|
||||||
icon: anthropic.svg
|
icon: anthropic.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: anthropic
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
5
src/langbot/pkg/provider/modelmgr/requesters/baidu.svg
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#2932E1"/>
|
||||||
|
<text x="30" y="28" font-family="Arial, sans-serif" font-size="10" font-weight="bold" fill="white" text-anchor="middle">Baidu</text>
|
||||||
|
<text x="30" y="40" font-family="Arial, sans-serif" font-size="8" fill="white" text-anchor="middle">ERNIE</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 396 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: baidu-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: Baidu ERNIE
|
||||||
|
zh_Hans: 百度文心一言
|
||||||
|
icon: baidu.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import dashscope
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import modelscopechatcmpl
|
|
||||||
from .. import requester
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
|
|
||||||
class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
|
||||||
"""阿里云百炼大模型平台 ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _closure_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
is_use_dashscope_call = False # 是否使用阿里原生库调用
|
|
||||||
is_enable_multi_model = True # 是否支持多轮对话
|
|
||||||
use_time_num = 0 # 模型已调用次数,防止存在多文件时重复调用
|
|
||||||
use_time_ids = [] # 已调用的ID列表
|
|
||||||
message_id = 0 # 记录消息序号
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
# print(msg)
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
elif me['type'] == 'file_url' and '.' in me.get('file_name', ''):
|
|
||||||
# 1. 视频文件推理
|
|
||||||
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2845871
|
|
||||||
file_type = me.get('file_name').lower().split('.')[-1]
|
|
||||||
if file_type in ['mp4', 'avi', 'mkv', 'mov', 'flv', 'wmv']:
|
|
||||||
me['type'] = 'video_url'
|
|
||||||
me['video_url'] = {'url': me['file_url']}
|
|
||||||
del me['file_url']
|
|
||||||
del me['file_name']
|
|
||||||
use_time_num += 1
|
|
||||||
use_time_ids.append(message_id)
|
|
||||||
is_enable_multi_model = False
|
|
||||||
# 2. 语音文件识别, 无法通过openai的audio字段传递,暂时不支持
|
|
||||||
# https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2979031
|
|
||||||
elif file_type in [
|
|
||||||
'aac',
|
|
||||||
'amr',
|
|
||||||
'aiff',
|
|
||||||
'flac',
|
|
||||||
'm4a',
|
|
||||||
'mp3',
|
|
||||||
'mpeg',
|
|
||||||
'ogg',
|
|
||||||
'opus',
|
|
||||||
'wav',
|
|
||||||
'webm',
|
|
||||||
'wma',
|
|
||||||
]:
|
|
||||||
me['audio'] = me['file_url']
|
|
||||||
me['type'] = 'audio'
|
|
||||||
del me['file_url']
|
|
||||||
del me['type']
|
|
||||||
del me['file_name']
|
|
||||||
is_use_dashscope_call = True
|
|
||||||
use_time_num += 1
|
|
||||||
use_time_ids.append(message_id)
|
|
||||||
is_enable_multi_model = False
|
|
||||||
message_id += 1
|
|
||||||
|
|
||||||
# 使用列表推导式,保留不在 use_time_ids[:-1] 中的元素,仅保留最后一个多媒体消息
|
|
||||||
if not is_enable_multi_model and use_time_num > 1:
|
|
||||||
messages = [msg for idx, msg in enumerate(messages) if idx not in use_time_ids[:-1]]
|
|
||||||
|
|
||||||
if not is_enable_multi_model:
|
|
||||||
messages = [msg for msg in messages if 'resp_message_id' not in msg]
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# 流式处理状态
|
|
||||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
|
|
||||||
if is_use_dashscope_call:
|
|
||||||
response = dashscope.MultiModalConversation.call(
|
|
||||||
# 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx"
|
|
||||||
api_key=use_model.provider.token_mgr.get_token(),
|
|
||||||
model=use_model.model_entity.name,
|
|
||||||
messages=messages,
|
|
||||||
result_format='message',
|
|
||||||
asr_options={
|
|
||||||
# "language": "zh", # 可选,若已知音频的语种,可通过该参数指定待识别语种,以提升识别准确率
|
|
||||||
'enable_lid': True,
|
|
||||||
'enable_itn': False,
|
|
||||||
},
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
content_length_list = []
|
|
||||||
previous_length = 0 # 记录上一次的内容长度
|
|
||||||
for res in response:
|
|
||||||
chunk = res['output']
|
|
||||||
# 解析 chunk 数据
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta_content = choice['message'].content[0]['text']
|
|
||||||
finish_reason = choice['finish_reason']
|
|
||||||
content_length_list.append(len(delta_content))
|
|
||||||
else:
|
|
||||||
delta_content = ''
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查 content_length_list 是否有足够的数据
|
|
||||||
if len(content_length_list) >= 2:
|
|
||||||
now_content = delta_content[previous_length : content_length_list[-1]]
|
|
||||||
previous_length = content_length_list[-1] # 更新上一次的长度
|
|
||||||
else:
|
|
||||||
now_content = delta_content # 第一次循环时直接使用 delta_content
|
|
||||||
previous_length = len(delta_content) # 更新上一次的长度
|
|
||||||
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': now_content if now_content else None,
|
|
||||||
'is_final': bool(finish_reason) and finish_reason != 'null',
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
else:
|
|
||||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
|
||||||
# 解析 chunk 数据
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
|
||||||
finish_reason = getattr(choice, 'finish_reason', None)
|
|
||||||
else:
|
|
||||||
delta = {}
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
|
||||||
if 'role' in delta and delta['role']:
|
|
||||||
role = delta['role']
|
|
||||||
|
|
||||||
# 获取增量内容
|
|
||||||
delta_content = delta.get('content', '')
|
|
||||||
reasoning_content = delta.get('reasoning_content', '')
|
|
||||||
|
|
||||||
# 处理 reasoning_content
|
|
||||||
if reasoning_content:
|
|
||||||
# accumulated_reasoning += reasoning_content
|
|
||||||
# 如果设置了 remove_think,跳过 reasoning_content
|
|
||||||
if remove_think:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
|
||||||
if not thinking_started:
|
|
||||||
thinking_started = True
|
|
||||||
delta_content = '<think>\n' + reasoning_content
|
|
||||||
else:
|
|
||||||
# 继续输出 reasoning_content
|
|
||||||
delta_content = reasoning_content
|
|
||||||
elif thinking_started and not thinking_ended and delta_content:
|
|
||||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
|
||||||
thinking_ended = True
|
|
||||||
delta_content = '\n</think>\n' + delta_content
|
|
||||||
|
|
||||||
# 处理工具调用增量
|
|
||||||
if delta.get('tool_calls'):
|
|
||||||
for tool_call in delta['tool_calls']:
|
|
||||||
if tool_call['id'] != '':
|
|
||||||
tool_id = tool_call['id']
|
|
||||||
if tool_call['function']['name'] is not None:
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
|
|
||||||
if tool_call['type'] is None:
|
|
||||||
tool_call['type'] = 'function'
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
tool_call['function']['arguments'] = (
|
|
||||||
'' if tool_call['function']['arguments'] is None else tool_call['function']['arguments']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': delta.get('tool_calls'),
|
|
||||||
'is_final': bool(finish_reason),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
# return
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 阿里云百炼
|
zh_Hans: 阿里云百炼
|
||||||
icon: bailian.png
|
icon: bailian.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,617 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import typing
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import openai.types.chat.chat_completion as chat_completion_module
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from .. import errors, requester
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
|
||||||
"""OpenAI ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.openai.com/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
self.client = openai.AsyncClient(
|
|
||||||
api_key='',
|
|
||||||
base_url=self.requester_cfg['base_url'].replace(' ', ''),
|
|
||||||
timeout=self.requester_cfg['timeout'],
|
|
||||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _mask_api_key(self, api_key: str | None) -> str:
|
|
||||||
if not api_key:
|
|
||||||
return ''
|
|
||||||
if len(api_key) <= 8:
|
|
||||||
return '****'
|
|
||||||
return f'{api_key[:4]}...{api_key[-4:]}'
|
|
||||||
|
|
||||||
def _infer_model_type(self, model_id: str) -> str:
|
|
||||||
normalized_model_id = (model_id or '').lower()
|
|
||||||
embedding_keywords = (
|
|
||||||
'embedding',
|
|
||||||
'embed',
|
|
||||||
'bge-',
|
|
||||||
'e5-',
|
|
||||||
'm3e',
|
|
||||||
'gte-',
|
|
||||||
'multilingual-e5',
|
|
||||||
'text-embedding',
|
|
||||||
)
|
|
||||||
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
|
|
||||||
|
|
||||||
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
|
|
||||||
normalized_model_id = (model_id or '').lower()
|
|
||||||
abilities: set[str] = set()
|
|
||||||
|
|
||||||
def _flatten(value: typing.Any) -> list[str]:
|
|
||||||
if value is None:
|
|
||||||
return []
|
|
||||||
if isinstance(value, str):
|
|
||||||
return [value.lower()]
|
|
||||||
if isinstance(value, dict):
|
|
||||||
flattened: list[str] = []
|
|
||||||
for nested_value in value.values():
|
|
||||||
flattened.extend(_flatten(nested_value))
|
|
||||||
return flattened
|
|
||||||
if isinstance(value, (list, tuple, set)):
|
|
||||||
flattened: list[str] = []
|
|
||||||
for nested_value in value:
|
|
||||||
flattened.extend(_flatten(nested_value))
|
|
||||||
return flattened
|
|
||||||
return [str(value).lower()]
|
|
||||||
|
|
||||||
capability_tokens = _flatten(item.get('capabilities'))
|
|
||||||
capability_tokens.extend(_flatten(item.get('modalities')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('input_modalities')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('output_modalities')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('supported_generation_methods')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('supported_parameters')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('architecture')))
|
|
||||||
|
|
||||||
combined_tokens = capability_tokens + [normalized_model_id]
|
|
||||||
|
|
||||||
vision_keywords = (
|
|
||||||
'vision',
|
|
||||||
'image',
|
|
||||||
'file',
|
|
||||||
'video',
|
|
||||||
'multimodal',
|
|
||||||
'vl',
|
|
||||||
'ocr',
|
|
||||||
'omni',
|
|
||||||
)
|
|
||||||
function_call_keywords = (
|
|
||||||
'function',
|
|
||||||
'tool',
|
|
||||||
'tools',
|
|
||||||
'tool_choice',
|
|
||||||
'tool_call',
|
|
||||||
'tool-use',
|
|
||||||
'tool_use',
|
|
||||||
)
|
|
||||||
|
|
||||||
if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens):
|
|
||||||
abilities.add('vision')
|
|
||||||
|
|
||||||
if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens):
|
|
||||||
abilities.add('func_call')
|
|
||||||
|
|
||||||
return sorted(abilities)
|
|
||||||
|
|
||||||
def _normalize_modalities(self, value: typing.Any) -> list[str]:
|
|
||||||
normalized: list[str] = []
|
|
||||||
|
|
||||||
def _collect(item: typing.Any):
|
|
||||||
if item is None:
|
|
||||||
return
|
|
||||||
if isinstance(item, str):
|
|
||||||
for part in item.replace('->', ',').replace('+', ',').split(','):
|
|
||||||
token = part.strip().lower()
|
|
||||||
if token and token not in normalized:
|
|
||||||
normalized.append(token)
|
|
||||||
return
|
|
||||||
if isinstance(item, dict):
|
|
||||||
for nested in item.values():
|
|
||||||
_collect(nested)
|
|
||||||
return
|
|
||||||
if isinstance(item, (list, tuple, set)):
|
|
||||||
for nested in item:
|
|
||||||
_collect(nested)
|
|
||||||
return
|
|
||||||
|
|
||||||
_collect(value)
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]:
|
|
||||||
display_name = item.get('name')
|
|
||||||
if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id:
|
|
||||||
display_name = ''
|
|
||||||
|
|
||||||
description = item.get('description')
|
|
||||||
if not isinstance(description, str) or not description.strip():
|
|
||||||
description = ''
|
|
||||||
|
|
||||||
context_length = item.get('context_length')
|
|
||||||
if context_length is None and isinstance(item.get('top_provider'), dict):
|
|
||||||
context_length = item['top_provider'].get('context_length')
|
|
||||||
|
|
||||||
if not isinstance(context_length, int):
|
|
||||||
try:
|
|
||||||
context_length = int(context_length) if context_length is not None else None
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
context_length = None
|
|
||||||
|
|
||||||
input_modalities = self._normalize_modalities(item.get('input_modalities'))
|
|
||||||
output_modalities = self._normalize_modalities(item.get('output_modalities'))
|
|
||||||
|
|
||||||
if isinstance(item.get('architecture'), dict):
|
|
||||||
if not input_modalities:
|
|
||||||
input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities'))
|
|
||||||
if not output_modalities:
|
|
||||||
output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities'))
|
|
||||||
|
|
||||||
owned_by = item.get('owned_by')
|
|
||||||
if not isinstance(owned_by, str) or not owned_by.strip():
|
|
||||||
owned_by = ''
|
|
||||||
|
|
||||||
return {
|
|
||||||
'display_name': display_name or None,
|
|
||||||
'description': description or None,
|
|
||||||
'context_length': context_length,
|
|
||||||
'owned_by': owned_by or None,
|
|
||||||
'input_modalities': input_modalities,
|
|
||||||
'output_modalities': output_modalities,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
|
||||||
headers = {}
|
|
||||||
if api_key:
|
|
||||||
headers['Authorization'] = f'Bearer {api_key}'
|
|
||||||
|
|
||||||
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models'
|
|
||||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
|
||||||
response = await client.get(models_url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
payload = response.json()
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for item in payload.get('data', []):
|
|
||||||
model_id = item.get('id')
|
|
||||||
if not model_id:
|
|
||||||
continue
|
|
||||||
models.append(
|
|
||||||
{
|
|
||||||
'id': model_id,
|
|
||||||
'name': model_id,
|
|
||||||
'type': self._infer_model_type(model_id),
|
|
||||||
'abilities': self._infer_model_abilities(item, model_id),
|
|
||||||
**self._extract_scan_metadata(item, model_id),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
|
||||||
return {
|
|
||||||
'models': models,
|
|
||||||
'debug': {
|
|
||||||
'request': {
|
|
||||||
'method': 'GET',
|
|
||||||
'url': models_url,
|
|
||||||
'headers': {
|
|
||||||
'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'response': payload,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _req(
|
|
||||||
self,
|
|
||||||
args: dict,
|
|
||||||
extra_body: dict = {},
|
|
||||||
) -> chat_completion_module.ChatCompletion:
|
|
||||||
return await self.client.chat.completions.create(**args, extra_body=extra_body)
|
|
||||||
|
|
||||||
async def _req_stream(
|
|
||||||
self,
|
|
||||||
args: dict,
|
|
||||||
extra_body: dict = {},
|
|
||||||
):
|
|
||||||
async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def _make_msg(
|
|
||||||
self,
|
|
||||||
chat_completion: chat_completion_module.ChatCompletion,
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
if not isinstance(chat_completion, chat_completion_module.ChatCompletion):
|
|
||||||
raise TypeError(f'Expected ChatCompletion, got {type(chat_completion).__name__}: {chat_completion[:16]}')
|
|
||||||
|
|
||||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
|
||||||
|
|
||||||
# 确保 role 字段存在且不为 None
|
|
||||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
|
||||||
chatcmpl_message['role'] = 'assistant'
|
|
||||||
|
|
||||||
# 处理思维链
|
|
||||||
content = chatcmpl_message.get('content', '')
|
|
||||||
reasoning_content = chatcmpl_message.get('reasoning_content', None)
|
|
||||||
|
|
||||||
processed_content, _ = await self._process_thinking_content(
|
|
||||||
content=content, reasoning_content=reasoning_content, remove_think=remove_think
|
|
||||||
)
|
|
||||||
|
|
||||||
chatcmpl_message['content'] = processed_content
|
|
||||||
|
|
||||||
# 移除 reasoning_content 字段,避免传递给 Message
|
|
||||||
if 'reasoning_content' in chatcmpl_message:
|
|
||||||
del chatcmpl_message['reasoning_content']
|
|
||||||
|
|
||||||
message = provider_message.Message(**chatcmpl_message)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _process_thinking_content(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
reasoning_content: str = None,
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""处理思维链内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 原始内容
|
|
||||||
reasoning_content: reasoning_content 字段内容
|
|
||||||
remove_think: 是否移除思维链
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(处理后的内容, 提取的思维链内容)
|
|
||||||
"""
|
|
||||||
thinking_content = ''
|
|
||||||
|
|
||||||
# 1. 从 reasoning_content 提取思维链
|
|
||||||
if reasoning_content:
|
|
||||||
thinking_content = reasoning_content
|
|
||||||
|
|
||||||
# 2. 从 content 中提取 <think> 标签内容
|
|
||||||
if content and '<think>' in content and '</think>' in content:
|
|
||||||
import re
|
|
||||||
|
|
||||||
think_pattern = r'<think>(.*?)</think>'
|
|
||||||
think_matches = re.findall(think_pattern, content, re.DOTALL)
|
|
||||||
if think_matches:
|
|
||||||
# 如果已有 reasoning_content,则追加
|
|
||||||
if thinking_content:
|
|
||||||
thinking_content += '\n' + '\n'.join(think_matches)
|
|
||||||
else:
|
|
||||||
thinking_content = '\n'.join(think_matches)
|
|
||||||
# 移除 content 中的 <think> 标签
|
|
||||||
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
|
||||||
|
|
||||||
# 3. 根据 remove_think 参数决定是否保留思维链
|
|
||||||
if remove_think:
|
|
||||||
return content, ''
|
|
||||||
else:
|
|
||||||
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
|
|
||||||
if thinking_content:
|
|
||||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
|
||||||
return content, thinking_content
|
|
||||||
|
|
||||||
async def _closure_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# 流式处理状态
|
|
||||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
tool_id = ''
|
|
||||||
tool_name = ''
|
|
||||||
# accumulated_reasoning = '' # 仅用于判断何时结束思维链
|
|
||||||
|
|
||||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
|
||||||
# 解析 chunk 数据
|
|
||||||
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
|
||||||
|
|
||||||
finish_reason = getattr(choice, 'finish_reason', None)
|
|
||||||
else:
|
|
||||||
delta = {}
|
|
||||||
finish_reason = None
|
|
||||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
|
||||||
if 'role' in delta and delta['role']:
|
|
||||||
role = delta['role']
|
|
||||||
|
|
||||||
# 获取增量内容
|
|
||||||
delta_content = delta.get('content', '')
|
|
||||||
reasoning_content = delta.get('reasoning_content', '')
|
|
||||||
|
|
||||||
# 处理 reasoning_content
|
|
||||||
if reasoning_content:
|
|
||||||
# accumulated_reasoning += reasoning_content
|
|
||||||
# 如果设置了 remove_think,跳过 reasoning_content
|
|
||||||
if remove_think:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
|
||||||
if not thinking_started:
|
|
||||||
thinking_started = True
|
|
||||||
delta_content = '<think>\n' + reasoning_content
|
|
||||||
else:
|
|
||||||
# 继续输出 reasoning_content
|
|
||||||
delta_content = reasoning_content
|
|
||||||
elif thinking_started and not thinking_ended and delta_content:
|
|
||||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
|
||||||
thinking_ended = True
|
|
||||||
delta_content = '\n</think>\n' + delta_content
|
|
||||||
|
|
||||||
# 处理 content 中已有的 <think> 标签(如果需要移除)
|
|
||||||
# if delta_content and remove_think and '<think>' in delta_content:
|
|
||||||
# import re
|
|
||||||
#
|
|
||||||
# # 移除 <think> 标签及其内容
|
|
||||||
# delta_content = re.sub(r'<think>.*?</think>', '', delta_content, flags=re.DOTALL)
|
|
||||||
|
|
||||||
# 处理工具调用增量
|
|
||||||
# delta_tool_calls = None
|
|
||||||
if delta.get('tool_calls'):
|
|
||||||
for tool_call in delta['tool_calls']:
|
|
||||||
if tool_call['id'] and tool_call['function']['name']:
|
|
||||||
tool_id = tool_call['id']
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
else:
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
if tool_call['type'] is None:
|
|
||||||
tool_call['type'] = 'function'
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': delta.get('tool_calls'),
|
|
||||||
'is_final': bool(finish_reason),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
|
|
||||||
async def _closure(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[provider_message.Message, dict]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
|
|
||||||
# 发送请求
|
|
||||||
|
|
||||||
resp = await self._req(args, extra_body=extra_args)
|
|
||||||
# 处理请求结果
|
|
||||||
message = await self._make_msg(resp, remove_think)
|
|
||||||
|
|
||||||
# Extract token usage from response
|
|
||||||
usage_info = {}
|
|
||||||
if hasattr(resp, 'usage') and resp.usage:
|
|
||||||
usage_info['input_tokens'] = resp.usage.prompt_tokens or 0
|
|
||||||
usage_info['output_tokens'] = resp.usage.completion_tokens or 0
|
|
||||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
|
||||||
|
|
||||||
return message, usage_info
|
|
||||||
|
|
||||||
async def invoke_llm(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[provider_message.Message, dict]:
|
|
||||||
"""Invoke LLM and return message with usage info"""
|
|
||||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
|
||||||
for m in messages:
|
|
||||||
msg_dict = m.dict(exclude_none=True)
|
|
||||||
content = msg_dict.get('content')
|
|
||||||
if isinstance(content, list):
|
|
||||||
# 检查 content 列表中是否每个部分都是文本
|
|
||||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
|
||||||
# 将所有文本部分合并为一个字符串
|
|
||||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
|
|
||||||
try:
|
|
||||||
msg, usage_info = await self._closure(
|
|
||||||
query=query,
|
|
||||||
req_messages=req_messages,
|
|
||||||
use_model=model,
|
|
||||||
use_funcs=funcs,
|
|
||||||
extra_args=extra_args,
|
|
||||||
remove_think=remove_think,
|
|
||||||
)
|
|
||||||
return msg, usage_info
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
except openai.BadRequestError as e:
|
|
||||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
|
||||||
if 'context_length_exceeded' in str(e):
|
|
||||||
raise errors.RequesterError(f'上文过长,请重置会话: {error_message}')
|
|
||||||
else:
|
|
||||||
raise errors.RequesterError(f'请求参数错误: {error_message}')
|
|
||||||
except openai.AuthenticationError as e:
|
|
||||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
|
||||||
raise errors.RequesterError(f'无效的 api-key: {error_message}')
|
|
||||||
except openai.NotFoundError as e:
|
|
||||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
|
||||||
raise errors.RequesterError(f'请求路径错误: {error_message}')
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
|
||||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {error_message}')
|
|
||||||
except openai.APIConnectionError as e:
|
|
||||||
error_message = f'连接错误: {str(e)}'
|
|
||||||
raise errors.RequesterError(error_message)
|
|
||||||
except openai.APIError as e:
|
|
||||||
error_message = str(e.message) if hasattr(e, 'message') else str(e)
|
|
||||||
raise errors.RequesterError(f'请求错误: {error_message}')
|
|
||||||
|
|
||||||
async def invoke_embedding(
|
|
||||||
self,
|
|
||||||
model: requester.RuntimeEmbeddingModel,
|
|
||||||
input_text: list[str],
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
) -> tuple[list[list[float]], dict]:
|
|
||||||
"""调用 Embedding API, returns (embeddings, usage_info)"""
|
|
||||||
self.client.api_key = model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {
|
|
||||||
'model': model.model_entity.name,
|
|
||||||
'input': input_text,
|
|
||||||
}
|
|
||||||
|
|
||||||
if model.model_entity.extra_args:
|
|
||||||
args.update(model.model_entity.extra_args)
|
|
||||||
|
|
||||||
args.update(extra_args)
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp = await self.client.embeddings.create(**args)
|
|
||||||
|
|
||||||
# Extract usage info
|
|
||||||
usage_info = {}
|
|
||||||
if hasattr(resp, 'usage') and resp.usage:
|
|
||||||
usage_info['prompt_tokens'] = resp.usage.prompt_tokens or 0
|
|
||||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
|
||||||
|
|
||||||
return [d.embedding for d in resp.data], usage_info
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
except openai.BadRequestError as e:
|
|
||||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
|
||||||
|
|
||||||
async def invoke_llm_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
|
||||||
for m in messages:
|
|
||||||
msg_dict = m.dict(exclude_none=True)
|
|
||||||
content = msg_dict.get('content')
|
|
||||||
if isinstance(content, list):
|
|
||||||
# 检查 content 列表中是否每个部分都是文本
|
|
||||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
|
||||||
# 将所有文本部分合并为一个字符串
|
|
||||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for item in self._closure_stream(
|
|
||||||
query=query,
|
|
||||||
req_messages=req_messages,
|
|
||||||
use_model=model,
|
|
||||||
use_funcs=funcs,
|
|
||||||
extra_args=extra_args,
|
|
||||||
remove_think=remove_think,
|
|
||||||
):
|
|
||||||
yield item
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
except openai.BadRequestError as e:
|
|
||||||
if 'context_length_exceeded' in e.message:
|
|
||||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
|
||||||
else:
|
|
||||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
|
||||||
except openai.AuthenticationError as e:
|
|
||||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
|
||||||
except openai.NotFoundError as e:
|
|
||||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
|
||||||
except openai.APIError as e:
|
|
||||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: OpenAI
|
zh_Hans: OpenAI
|
||||||
icon: openai.svg
|
icon: openai.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -25,6 +26,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 128 128" id="Chroma--Streamline-Svg-Logos" height="128" width="128">
|
||||||
<rect width="24" height="24" rx="5" fill="#7B68EE"/>
|
<desc>
|
||||||
<circle cx="12" cy="12" r="6" fill="#FF6B35"/>
|
Chroma Streamline Icon: https://streamlinehq.com
|
||||||
<circle cx="12" cy="12" r="3" fill="#7B68EE"/>
|
</desc>
|
||||||
<path d="M12 6V18" stroke="#FFF" stroke-width="1.5" stroke-linecap="round"/>
|
<path fill="#ffde2d" d="M84.88839999999999 104.10666666666665c23.0732 0 41.77773333333333 -17.956266666666664 41.77773333333333 -40.10653333333333 0 -22.150266666666667 -18.70453333333333 -40.10653333333333 -41.77773333333333 -40.10653333333333 -23.0732 0 -41.77773333333333 17.956266666666664 -41.77773333333333 40.10653333333333 0 22.150266666666667 18.70453333333333 40.10653333333333 41.77773333333333 40.10653333333333Z" stroke-width="1.3333"></path>
|
||||||
<path d="M6 12H18" stroke="#FFF" stroke-width="1.5" stroke-linecap="round"/>
|
<path fill="#327eff" d="M43.111066666666666 104.10666666666665c23.0732 0 41.77773333333333 -17.956266666666664 41.77773333333333 -40.10653333333333 0 -22.150266666666667 -18.70453333333333 -40.10653333333333 -41.77773333333333 -40.10653333333333C20.037866666666666 23.8936 1.3333333333333333 41.849866666666664 1.3333333333333333 64.00013333333334 1.3333333333333333 86.15039999999999 20.037866666666666 104.10666666666665 43.111066666666666 104.10666666666665Z" stroke-width="1.3333"></path>
|
||||||
|
<path fill="#ff6446" d="M84.88866666666667 64.00013333333334c0 22.150399999999998 -18.704666666666665 40.10626666666666 -41.778 40.10626666666666V64.00013333333334h41.778Zm-41.778 0c0 -22.150266666666667 18.70453333333333 -40.10653333333333 41.778 -40.10653333333333v40.10653333333333H43.11066666666666Z" stroke-width="1.3333"></path>
|
||||||
</svg>
|
</svg>
|
||||||
|
Before Width: | Height: | Size: 413 B After Width: | Height: | Size: 1.5 KiB |
1
src/langbot/pkg/provider/modelmgr/requesters/cohere.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Cohere</title><path clip-rule="evenodd" d="M8.128 14.099c.592 0 1.77-.033 3.398-.703 1.897-.781 5.672-2.2 8.395-3.656 1.905-1.018 2.74-2.366 2.74-4.18A4.56 4.56 0 0018.1 1H7.549A6.55 6.55 0 001 7.55c0 3.617 2.745 6.549 7.128 6.549z" fill="#39594D" fill-rule="evenodd"></path><path clip-rule="evenodd" d="M9.912 18.61a4.387 4.387 0 012.705-4.052l3.323-1.38c3.361-1.394 7.06 1.076 7.06 4.715a5.104 5.104 0 01-5.105 5.104l-3.597-.001a4.386 4.386 0 01-4.386-4.387z" fill="#D18EE2" fill-rule="evenodd"></path><path d="M4.776 14.962A3.775 3.775 0 001 18.738v.489a3.776 3.776 0 007.551 0v-.49a3.775 3.775 0 00-3.775-3.775z" fill="#FF7759"></path></svg>
|
||||||
|
After Width: | Height: | Size: 769 B |
@@ -0,0 +1,32 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: cohere-rerank
|
||||||
|
label:
|
||||||
|
en_US: Cohere
|
||||||
|
zh_Hans: Cohere
|
||||||
|
icon: cohere.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: cohere
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.cohere.com/v2
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./chatcmpl.py
|
||||||
|
attr: OpenAIChatCompletions
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class CompShareChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""CompShare ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.modelverse.cn/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 优云智算
|
zh_Hans: 优云智算
|
||||||
icon: compshare.png
|
icon: compshare.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,67 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
from .. import errors, requester
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""Deepseek ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.deepseek.com',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _closure(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[provider_message.Message, dict]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages
|
|
||||||
|
|
||||||
# deepseek 不支持多模态,把content都转换成纯文字
|
|
||||||
for m in messages:
|
|
||||||
if 'content' in m and isinstance(m['content'], list):
|
|
||||||
m['content'] = ' '.join([c['text'] for c in m['content'] if 'text' in c])
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
|
|
||||||
# 发送请求
|
|
||||||
resp = await self._req(args, extra_body=extra_args)
|
|
||||||
|
|
||||||
# print(resp)
|
|
||||||
|
|
||||||
if resp is None:
|
|
||||||
raise errors.RequesterError('接口返回为空,请确定模型提供商服务是否正常')
|
|
||||||
# 处理请求结果
|
|
||||||
message = await self._make_msg(resp, remove_think)
|
|
||||||
|
|
||||||
# Extract token usage from response
|
|
||||||
usage_info = {}
|
|
||||||
if hasattr(resp, 'usage') and resp.usage:
|
|
||||||
usage_info['input_tokens'] = resp.usage.prompt_tokens or 0
|
|
||||||
usage_info['output_tokens'] = resp.usage.completion_tokens or 0
|
|
||||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
|
||||||
|
|
||||||
return message, usage_info
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: DeepSeek
|
zh_Hans: DeepSeek
|
||||||
icon: deepseek.svg
|
icon: deepseek.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: deepseek
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
4
src/langbot/pkg/provider/modelmgr/requesters/doubao.svg
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#3B82F6"/>
|
||||||
|
<text x="30" y="32" font-family="Arial, sans-serif" font-size="12" font-weight="bold" fill="white" text-anchor="middle">豆包</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 282 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: doubao-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: ByteDance Doubao
|
||||||
|
zh_Hans: 字节豆包
|
||||||
|
icon: doubao.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://ark.cn-beijing.volces.com/api/v3
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from .. import requester
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""Google Gemini API 请求器"""
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://generativelanguage.googleapis.com/v1beta/openai',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
|
||||||
models_url = 'https://generativelanguage.googleapis.com/v1beta/models'
|
|
||||||
params = {'key': api_key} if api_key else {}
|
|
||||||
|
|
||||||
all_models: list[dict[str, typing.Any]] = []
|
|
||||||
next_page_token = ''
|
|
||||||
last_payload: dict[str, typing.Any] = {}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
|
||||||
while True:
|
|
||||||
request_params = dict(params)
|
|
||||||
if next_page_token:
|
|
||||||
request_params['pageToken'] = next_page_token
|
|
||||||
|
|
||||||
response = await client.get(models_url, params=request_params)
|
|
||||||
response.raise_for_status()
|
|
||||||
payload = response.json()
|
|
||||||
last_payload = payload
|
|
||||||
|
|
||||||
for item in payload.get('models', []):
|
|
||||||
model_name = item.get('name', '')
|
|
||||||
model_id = model_name.replace('models/', '', 1)
|
|
||||||
if not model_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
supported_methods = item.get('supportedGenerationMethods', []) or []
|
|
||||||
if 'embedContent' in supported_methods and 'generateContent' not in supported_methods:
|
|
||||||
model_type = 'embedding'
|
|
||||||
else:
|
|
||||||
model_type = 'llm'
|
|
||||||
|
|
||||||
all_models.append(
|
|
||||||
{
|
|
||||||
'id': model_id,
|
|
||||||
'name': model_id,
|
|
||||||
'type': model_type,
|
|
||||||
'abilities': self._infer_model_abilities(item, model_id),
|
|
||||||
'display_name': item.get('displayName') or None,
|
|
||||||
'description': item.get('description') or None,
|
|
||||||
'context_length': item.get('inputTokenLimit'),
|
|
||||||
'input_modalities': self._normalize_modalities(item.get('inputModalities')),
|
|
||||||
'output_modalities': self._normalize_modalities(item.get('outputModalities')),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
next_page_token = payload.get('nextPageToken', '')
|
|
||||||
if not next_page_token:
|
|
||||||
break
|
|
||||||
|
|
||||||
all_models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
|
||||||
return {
|
|
||||||
'models': all_models,
|
|
||||||
'debug': {
|
|
||||||
'request': {
|
|
||||||
'method': 'GET',
|
|
||||||
'url': models_url,
|
|
||||||
'query': {'key': self._mask_api_key(api_key)} if api_key else {},
|
|
||||||
},
|
|
||||||
'response': last_payload,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _closure_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# 流式处理状态
|
|
||||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
tool_id = ''
|
|
||||||
tool_name = ''
|
|
||||||
# accumulated_reasoning = '' # 仅用于判断何时结束思维链
|
|
||||||
|
|
||||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
|
||||||
# 解析 chunk 数据
|
|
||||||
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
|
||||||
|
|
||||||
finish_reason = getattr(choice, 'finish_reason', None)
|
|
||||||
else:
|
|
||||||
delta = {}
|
|
||||||
finish_reason = None
|
|
||||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
|
||||||
if 'role' in delta and delta['role']:
|
|
||||||
role = delta['role']
|
|
||||||
|
|
||||||
# 获取增量内容
|
|
||||||
delta_content = delta.get('content', '')
|
|
||||||
reasoning_content = delta.get('reasoning_content', '')
|
|
||||||
|
|
||||||
# 处理 reasoning_content
|
|
||||||
if reasoning_content:
|
|
||||||
# accumulated_reasoning += reasoning_content
|
|
||||||
# 如果设置了 remove_think,跳过 reasoning_content
|
|
||||||
if remove_think:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
|
||||||
if not thinking_started:
|
|
||||||
thinking_started = True
|
|
||||||
delta_content = '<think>\n' + reasoning_content
|
|
||||||
else:
|
|
||||||
# 继续输出 reasoning_content
|
|
||||||
delta_content = reasoning_content
|
|
||||||
elif thinking_started and not thinking_ended and delta_content:
|
|
||||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
|
||||||
thinking_ended = True
|
|
||||||
delta_content = '\n</think>\n' + delta_content
|
|
||||||
|
|
||||||
# 处理 content 中已有的 <think> 标签(如果需要移除)
|
|
||||||
# if delta_content and remove_think and '<think>' in delta_content:
|
|
||||||
# import re
|
|
||||||
#
|
|
||||||
# # 移除 <think> 标签及其内容
|
|
||||||
# delta_content = re.sub(r'<think>.*?</think>', '', delta_content, flags=re.DOTALL)
|
|
||||||
|
|
||||||
# 处理工具调用增量
|
|
||||||
# delta_tool_calls = None
|
|
||||||
if delta.get('tool_calls'):
|
|
||||||
for tool_call in delta['tool_calls']:
|
|
||||||
if tool_call['id'] == '' and tool_id == '':
|
|
||||||
tool_id = str(uuid.uuid4())
|
|
||||||
if tool_call['function']['name']:
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
if tool_call['type'] is None:
|
|
||||||
tool_call['type'] = 'function'
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': delta.get('tool_calls'),
|
|
||||||
'is_final': bool(finish_reason),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: Google Gemini
|
zh_Hans: Google Gemini
|
||||||
icon: gemini.svg
|
icon: gemini.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: gemini
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from . import ppiochatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class GiteeAIChatCompletions(ppiochatcmpl.PPIOChatCompletions):
|
|
||||||
"""Gitee AI ChatCompletions API 请求器"""
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://ai.gitee.com/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: Gitee AI
|
zh_Hans: Gitee AI
|
||||||
icon: giteeai.svg
|
icon: giteeai.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -25,6 +26,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
4
src/langbot/pkg/provider/modelmgr/requesters/groq.svg
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#F97316"/>
|
||||||
|
<text x="30" y="32" font-family="Arial, sans-serif" font-size="14" font-weight="bold" fill="white" text-anchor="middle">Groq</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 280 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: groq-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: Groq
|
||||||
|
zh_Hans: Groq
|
||||||
|
icon: groq.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: groq
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.groq.com/openai/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
5
src/langbot/pkg/provider/modelmgr/requesters/iflytek.svg
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#0066FF"/>
|
||||||
|
<text x="30" y="28" font-family="Arial, sans-serif" font-size="10" font-weight="bold" fill="white" text-anchor="middle">iFlytek</text>
|
||||||
|
<text x="30" y="40" font-family="Arial, sans-serif" font-size="8" fill="white" text-anchor="middle">Spark</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 398 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: iflytek-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: iFlytek Spark
|
||||||
|
zh_Hans: 讯飞星火
|
||||||
|
icon: iflytek.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://spark-api-open.xf-yun.com/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -1,208 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
from .. import requester
|
|
||||||
import openai.types.chat.chat_completion as chat_completion
|
|
||||||
import re
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
|
|
||||||
|
|
||||||
class JieKouAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""接口 AI ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.jiekou.ai/openai',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
is_think: bool = False
|
|
||||||
|
|
||||||
async def _make_msg(
|
|
||||||
self,
|
|
||||||
chat_completion: chat_completion.ChatCompletion,
|
|
||||||
remove_think: bool,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
|
||||||
# print(chatcmpl_message.keys(), chatcmpl_message.values())
|
|
||||||
|
|
||||||
# 确保 role 字段存在且不为 None
|
|
||||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
|
||||||
chatcmpl_message['role'] = 'assistant'
|
|
||||||
|
|
||||||
reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None
|
|
||||||
|
|
||||||
# deepseek的reasoner模型
|
|
||||||
chatcmpl_message['content'] = await self._process_thinking_content(
|
|
||||||
chatcmpl_message['content'], reasoning_content, remove_think
|
|
||||||
)
|
|
||||||
|
|
||||||
# 移除 reasoning_content 字段,避免传递给 Message
|
|
||||||
if 'reasoning_content' in chatcmpl_message:
|
|
||||||
del chatcmpl_message['reasoning_content']
|
|
||||||
|
|
||||||
message = provider_message.Message(**chatcmpl_message)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _process_thinking_content(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
reasoning_content: str = None,
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""处理思维链内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 原始内容
|
|
||||||
reasoning_content: reasoning_content 字段内容
|
|
||||||
remove_think: 是否移除思维链
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
处理后的内容
|
|
||||||
"""
|
|
||||||
if remove_think:
|
|
||||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
|
|
||||||
else:
|
|
||||||
if reasoning_content is not None:
|
|
||||||
content = '<think>\n' + reasoning_content + '\n</think>\n' + content
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def _make_msg_chunk(
|
|
||||||
self,
|
|
||||||
delta: dict[str, typing.Any],
|
|
||||||
idx: int,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
# 处理流式chunk和完整响应的差异
|
|
||||||
# print(chat_completion.choices[0])
|
|
||||||
|
|
||||||
# 确保 role 字段存在且不为 None
|
|
||||||
if 'role' not in delta or delta['role'] is None:
|
|
||||||
delta['role'] = 'assistant'
|
|
||||||
|
|
||||||
reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
|
|
||||||
|
|
||||||
delta['content'] = '' if delta['content'] is None else delta['content']
|
|
||||||
# print(reasoning_content)
|
|
||||||
|
|
||||||
# deepseek的reasoner模型
|
|
||||||
|
|
||||||
if reasoning_content is not None:
|
|
||||||
delta['content'] += reasoning_content
|
|
||||||
|
|
||||||
message = provider_message.MessageChunk(**delta)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _closure_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
|
||||||
# 解析 chunk 数据
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
|
||||||
finish_reason = getattr(choice, 'finish_reason', None)
|
|
||||||
else:
|
|
||||||
delta = {}
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
|
||||||
if 'role' in delta and delta['role']:
|
|
||||||
role = delta['role']
|
|
||||||
|
|
||||||
# 获取增量内容
|
|
||||||
delta_content = delta.get('content', '')
|
|
||||||
# reasoning_content = delta.get('reasoning_content', '')
|
|
||||||
|
|
||||||
if remove_think:
|
|
||||||
if delta['content'] is not None:
|
|
||||||
if '<think>' in delta['content'] and not thinking_started and not thinking_ended:
|
|
||||||
thinking_started = True
|
|
||||||
continue
|
|
||||||
elif delta['content'] == r'</think>' and not thinking_ended:
|
|
||||||
thinking_ended = True
|
|
||||||
continue
|
|
||||||
elif thinking_ended and delta['content'] == '\n\n' and thinking_started:
|
|
||||||
thinking_started = False
|
|
||||||
continue
|
|
||||||
elif thinking_started and not thinking_ended:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# delta_tool_calls = None
|
|
||||||
if delta.get('tool_calls'):
|
|
||||||
for tool_call in delta['tool_calls']:
|
|
||||||
if tool_call['id'] and tool_call['function']['name']:
|
|
||||||
tool_id = tool_call['id']
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
|
|
||||||
if tool_call['id'] is None:
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
if tool_call['function']['name'] is None:
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
if tool_call['function']['arguments'] is None:
|
|
||||||
tool_call['function']['arguments'] = ''
|
|
||||||
if tool_call['type'] is None:
|
|
||||||
tool_call['type'] = 'function'
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'):
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': delta.get('tool_calls'),
|
|
||||||
'is_final': bool(finish_reason),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 接口 AI
|
zh_Hans: 接口 AI
|
||||||
icon: jiekouai.png
|
icon: jiekouai.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
1
src/langbot/pkg/provider/modelmgr/requesters/jina.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Jina</title><path d="M6.608 21.416a4.608 4.608 0 100-9.217 4.608 4.608 0 000 9.217zM20.894 2.015c.614 0 1.106.492 1.106 1.106v9.002c0 5.13-4.148 9.309-9.217 9.37v-9.355l-.03-9.032c0-.614.491-1.106 1.106-1.106h7.158l-.123.015z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 404 B |
32
src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: jina-rerank
|
||||||
|
label:
|
||||||
|
en_US: Jina
|
||||||
|
zh_Hans: Jina
|
||||||
|
icon: jina.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.jina.ai/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./chatcmpl.py
|
||||||
|
attr: OpenAIChatCompletions
|
||||||
407
src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
Normal file
@@ -0,0 +1,407 @@
|
|||||||
|
"""LiteLLM unified requester for chat, embedding, and rerank."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import acompletion, aembedding, arerank
|
||||||
|
|
||||||
|
from .. import errors, requester
|
||||||
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||||
|
"""LiteLLM unified API requester supporting chat, embedding, and rerank."""
|
||||||
|
|
||||||
|
default_config: dict[str, typing.Any] = {
|
||||||
|
'base_url': '',
|
||||||
|
'timeout': 120,
|
||||||
|
'custom_llm_provider': '',
|
||||||
|
'drop_params': False,
|
||||||
|
'num_retries': 0,
|
||||||
|
'api_version': '',
|
||||||
|
}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize LiteLLM client settings."""
|
||||||
|
# LiteLLM doesn't require explicit client initialization
|
||||||
|
# Configuration is passed per-request via litellm params
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _build_litellm_model_name(self, model_name: str, custom_llm_provider: str | None = None) -> str:
|
||||||
|
"""Build LiteLLM model name with provider prefix if needed."""
|
||||||
|
provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '')
|
||||||
|
if provider:
|
||||||
|
# LiteLLM format: provider/model_name
|
||||||
|
return f'{provider}/{model_name}'
|
||||||
|
# If no custom provider, assume model_name already includes prefix or is OpenAI-compatible
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]:
|
||||||
|
"""Convert LangBot messages to LiteLLM/OpenAI format."""
|
||||||
|
req_messages = []
|
||||||
|
for m in messages:
|
||||||
|
msg_dict = m.dict(exclude_none=True)
|
||||||
|
content = msg_dict.get('content')
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get('type') == 'image_base64':
|
||||||
|
part['image_url'] = {'url': part['image_base64']}
|
||||||
|
part['type'] = 'image_url'
|
||||||
|
del part['image_base64']
|
||||||
|
|
||||||
|
req_messages.append(msg_dict)
|
||||||
|
|
||||||
|
return req_messages
|
||||||
|
|
||||||
|
def _process_thinking_content(self, content: str, reasoning_content: str | None, remove_think: bool) -> str:
|
||||||
|
"""Process thinking/reasoning content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The main content from response
|
||||||
|
reasoning_content: Separate reasoning content from model
|
||||||
|
remove_think: If True, remove thinking markers; if False, preserve them
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed content string
|
||||||
|
"""
|
||||||
|
# Extract and handle thinking tags
|
||||||
|
if content and 'CRETIRE_REASONING_BEGINk' in content and 'CRETIRE_REASONING_ENDk' in content:
|
||||||
|
import re
|
||||||
|
|
||||||
|
think_pattern = r'CRETIRE_REASONING_BEGINk(.*?)CRETIRE_REASONING_ENDk'
|
||||||
|
|
||||||
|
if remove_think:
|
||||||
|
# Remove thinking tags and their content from output
|
||||||
|
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||||
|
# else: preserve thinking content as-is
|
||||||
|
|
||||||
|
# Handle separate reasoning_content field
|
||||||
|
# Currently we don't include reasoning_content in user-facing output regardless of remove_think
|
||||||
|
# because it's typically internal model reasoning, not user-visible thinking
|
||||||
|
return content or ''
|
||||||
|
|
||||||
|
def _extract_usage(self, response) -> dict:
|
||||||
|
"""Extract usage info from LiteLLM response."""
|
||||||
|
usage = response.usage
|
||||||
|
return {
|
||||||
|
'prompt_tokens': usage.prompt_tokens or 0,
|
||||||
|
'completion_tokens': usage.completion_tokens or 0,
|
||||||
|
'total_tokens': usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
|
||||||
|
"""Apply common requester config to args dict."""
|
||||||
|
if self.requester_cfg.get('base_url'):
|
||||||
|
args['api_base'] = self.requester_cfg['base_url']
|
||||||
|
if self.requester_cfg.get('timeout'):
|
||||||
|
args['timeout'] = self.requester_cfg['timeout']
|
||||||
|
if include_retry_params:
|
||||||
|
if self.requester_cfg.get('drop_params'):
|
||||||
|
args['drop_params'] = self.requester_cfg['drop_params']
|
||||||
|
if self.requester_cfg.get('num_retries'):
|
||||||
|
args['num_retries'] = self.requester_cfg['num_retries']
|
||||||
|
if self.requester_cfg.get('api_version'):
|
||||||
|
args['api_version'] = self.requester_cfg['api_version']
|
||||||
|
return args
|
||||||
|
|
||||||
|
def _handle_litellm_error(self, e: Exception) -> None:
|
||||||
|
"""Convert LiteLLM exceptions to RequesterError. Never returns, always raises."""
|
||||||
|
# Check more specific exceptions first (they inherit from base exceptions)
|
||||||
|
if isinstance(e, litellm.ContextWindowExceededError):
|
||||||
|
raise errors.RequesterError(f'上下文长度超限: {str(e)}')
|
||||||
|
if isinstance(e, litellm.BadRequestError):
|
||||||
|
raise errors.RequesterError(f'请求参数错误: {str(e)}')
|
||||||
|
if isinstance(e, litellm.AuthenticationError):
|
||||||
|
raise errors.RequesterError(f'API key 无效: {str(e)}')
|
||||||
|
if isinstance(e, litellm.NotFoundError):
|
||||||
|
raise errors.RequesterError(f'模型或路径无效: {str(e)}')
|
||||||
|
if isinstance(e, litellm.RateLimitError):
|
||||||
|
raise errors.RequesterError(f'请求过于频繁或余额不足: {str(e)}')
|
||||||
|
if isinstance(e, litellm.Timeout):
|
||||||
|
raise errors.RequesterError(f'请求超时: {str(e)}')
|
||||||
|
if isinstance(e, litellm.APIConnectionError):
|
||||||
|
raise errors.RequesterError(f'连接错误: {str(e)}')
|
||||||
|
if isinstance(e, litellm.APIError):
|
||||||
|
raise errors.RequesterError(f'API 错误: {str(e)}')
|
||||||
|
raise errors.RequesterError(f'未知错误: {str(e)}')
|
||||||
|
|
||||||
|
async def _build_completion_args(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
stream: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Build common completion arguments for invoke_llm and invoke_llm_stream."""
|
||||||
|
req_messages = self._convert_messages(messages)
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'messages': req_messages,
|
||||||
|
'api_key': api_key,
|
||||||
|
}
|
||||||
|
if stream:
|
||||||
|
args['stream'] = True
|
||||||
|
args['stream_options'] = {'include_usage': True}
|
||||||
|
self._build_common_args(args)
|
||||||
|
|
||||||
|
# Apply model-level extra_args first, then call-level extra_args
|
||||||
|
if model.model_entity.extra_args:
|
||||||
|
args.update(model.model_entity.extra_args)
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
if funcs:
|
||||||
|
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
||||||
|
if tools:
|
||||||
|
args['tools'] = tools
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
async def invoke_llm(
|
||||||
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
remove_think: bool = False,
|
||||||
|
) -> tuple[provider_message.Message, dict]:
|
||||||
|
"""Invoke LLM and return message with usage info."""
|
||||||
|
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await acompletion(**args)
|
||||||
|
|
||||||
|
message_data = response.choices[0].message.model_dump()
|
||||||
|
if 'role' not in message_data or message_data['role'] is None:
|
||||||
|
message_data['role'] = 'assistant'
|
||||||
|
|
||||||
|
content = message_data.get('content', '')
|
||||||
|
reasoning_content = message_data.get('reasoning_content', None)
|
||||||
|
message_data['content'] = self._process_thinking_content(content, reasoning_content, remove_think)
|
||||||
|
|
||||||
|
if 'reasoning_content' in message_data:
|
||||||
|
del message_data['reasoning_content']
|
||||||
|
|
||||||
|
message = provider_message.Message(**message_data)
|
||||||
|
usage_info = self._extract_usage(response)
|
||||||
|
|
||||||
|
return message, usage_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def invoke_llm_stream(
|
||||||
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
remove_think: bool = False,
|
||||||
|
) -> provider_message.MessageChunk:
|
||||||
|
"""Invoke LLM streaming and yield chunks."""
|
||||||
|
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)
|
||||||
|
|
||||||
|
chunk_idx = 0
|
||||||
|
role = 'assistant'
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await acompletion(**args)
|
||||||
|
async for chunk in response:
|
||||||
|
# Check for usage chunk (final chunk with stream_options include_usage)
|
||||||
|
if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices):
|
||||||
|
usage_info = {
|
||||||
|
'prompt_tokens': chunk.usage.prompt_tokens or 0,
|
||||||
|
'completion_tokens': chunk.usage.completion_tokens or 0,
|
||||||
|
'total_tokens': chunk.usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
if query.variables is None:
|
||||||
|
query.variables = {}
|
||||||
|
query.variables['_stream_usage'] = usage_info
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||||
|
finish_reason = getattr(choice, 'finish_reason', None)
|
||||||
|
|
||||||
|
if 'role' in delta and delta['role']:
|
||||||
|
role = delta['role']
|
||||||
|
|
||||||
|
delta_content = delta.get('content', '')
|
||||||
|
reasoning_content = delta.get('reasoning_content', '')
|
||||||
|
|
||||||
|
# Handle reasoning_content based on remove_think flag
|
||||||
|
if reasoning_content:
|
||||||
|
if remove_think:
|
||||||
|
# Skip reasoning content when remove_think is True
|
||||||
|
chunk_idx += 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Use reasoning_content as the displayed content
|
||||||
|
delta_content = reasoning_content
|
||||||
|
|
||||||
|
if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'):
|
||||||
|
chunk_idx += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_data = {
|
||||||
|
'role': role,
|
||||||
|
'content': delta_content if delta_content else None,
|
||||||
|
'tool_calls': delta.get('tool_calls'),
|
||||||
|
'is_final': bool(finish_reason),
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||||
|
yield provider_message.MessageChunk(**chunk_data)
|
||||||
|
chunk_idx += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def invoke_embedding(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeEmbeddingModel,
|
||||||
|
input_text: list[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> tuple[list[list[float]], dict]:
|
||||||
|
"""Invoke embedding and return vectors with usage info."""
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'input': input_text,
|
||||||
|
'api_key': api_key,
|
||||||
|
}
|
||||||
|
self._build_common_args(args, include_retry_params=False)
|
||||||
|
|
||||||
|
if model.model_entity.extra_args:
|
||||||
|
args.update(model.model_entity.extra_args)
|
||||||
|
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await aembedding(**args)
|
||||||
|
|
||||||
|
embeddings = [d.embedding for d in response.data]
|
||||||
|
usage_info = self._extract_usage(response)
|
||||||
|
|
||||||
|
return embeddings, usage_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""Invoke rerank and return relevance scores."""
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'query': query,
|
||||||
|
'documents': documents,
|
||||||
|
'api_key': api_key,
|
||||||
|
'top_n': min(len(documents), 64),
|
||||||
|
}
|
||||||
|
self._build_common_args(args, include_retry_params=False)
|
||||||
|
|
||||||
|
if model.model_entity.extra_args:
|
||||||
|
args.update(model.model_entity.extra_args)
|
||||||
|
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await arerank(**args)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for r in response.results:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
'index': r.get('index', 0),
|
||||||
|
'relevance_score': r.get('relevance_score', 0.0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
scores = [r['relevance_score'] for r in results]
|
||||||
|
min_score = min(scores)
|
||||||
|
max_score = max(scores)
|
||||||
|
if max_score - min_score > 1e-6:
|
||||||
|
for r in results:
|
||||||
|
r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||||
|
"""Scan models supported by the provider."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
base_url = self.requester_cfg.get('base_url', '').rstrip('/')
|
||||||
|
timeout = self.requester_cfg.get('timeout', 120)
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise errors.RequesterError('Base URL required for model scanning')
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers['Authorization'] = f'Bearer {api_key}'
|
||||||
|
|
||||||
|
models_url = f'{base_url}/models'
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
payload = response.json()
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for item in payload.get('data', []):
|
||||||
|
model_id = item.get('id')
|
||||||
|
if not model_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Infer model type
|
||||||
|
normalized_id = (model_id or '').lower()
|
||||||
|
embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
|
||||||
|
model_type = 'embedding' if any(kw in normalized_id for kw in embedding_keywords) else 'llm'
|
||||||
|
|
||||||
|
models.append(
|
||||||
|
{
|
||||||
|
'id': model_id,
|
||||||
|
'name': model_id,
|
||||||
|
'type': model_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))
|
||||||
|
|
||||||
|
return {'models': models}
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise errors.RequesterError(f'Model scan failed: {e.response.status_code}')
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise errors.RequesterError('Model scan timeout')
|
||||||
|
except Exception as e:
|
||||||
|
raise errors.RequesterError(f'Model scan error: {str(e)}')
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: litellm-chat
|
||||||
|
label:
|
||||||
|
en_US: LiteLLM (Unified)
|
||||||
|
zh_Hans: LiteLLM (统一请求器)
|
||||||
|
icon: litellm.svg
|
||||||
|
spec:
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
- name: custom_llm_provider
|
||||||
|
label:
|
||||||
|
en_US: Custom Provider
|
||||||
|
zh_Hans: 自定义 Provider
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
description:
|
||||||
|
en_US: Force provider type (e.g., anthropic, openai, gemini)
|
||||||
|
zh_Hans: 强制指定 provider 类型(如 anthropic, openai, gemini)
|
||||||
|
- name: drop_params
|
||||||
|
label:
|
||||||
|
en_US: Drop Unsupported Params
|
||||||
|
zh_Hans: 丢弃不支持参数
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
- name: num_retries
|
||||||
|
label:
|
||||||
|
en_US: Number of Retries
|
||||||
|
zh_Hans: 重试次数
|
||||||
|
type: integer
|
||||||
|
required: false
|
||||||
|
default: 0
|
||||||
|
- name: api_version
|
||||||
|
label:
|
||||||
|
en_US: API Version
|
||||||
|
zh_Hans: API 版本
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: unified
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./litellmchat.py
|
||||||
|
attr: LiteLLMRequester
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""LMStudio ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'http://127.0.0.1:1234/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: LM Studio
|
zh_Hans: LM Studio
|
||||||
icon: lmstudio.webp
|
icon: lmstudio.webp
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
4
src/langbot/pkg/provider/modelmgr/requesters/mimo.svg
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#FF6700"/>
|
||||||
|
<text x="30" y="32" font-family="Arial, sans-serif" font-size="18" font-weight="bold" fill="white" text-anchor="middle">MiMo</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 280 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: mimo-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: Xiaomi MiMo
|
||||||
|
zh_Hans: 小米 MiMo
|
||||||
|
icon: mimo.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.xiaomimimo.com/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
4
src/langbot/pkg/provider/modelmgr/requesters/minimax.svg
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#4F46E5"/>
|
||||||
|
<text x="30" y="32" font-family="Arial, sans-serif" font-size="12" font-weight="bold" fill="white" text-anchor="middle">MiniMax</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 283 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: minimax-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: MiniMax
|
||||||
|
zh_Hans: MiniMax
|
||||||
|
icon: minimax.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.minimax.chat/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
5
src/langbot/pkg/provider/modelmgr/requesters/mistral.svg
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#FF6B35"/>
|
||||||
|
<text x="30" y="28" font-family="Arial, sans-serif" font-size="10" font-weight="bold" fill="white" text-anchor="middle">Mistral</text>
|
||||||
|
<text x="30" y="40" font-family="Arial, sans-serif" font-size="8" fill="white" text-anchor="middle">AI</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 395 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: mistral-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: Mistral AI
|
||||||
|
zh_Hans: Mistral AI
|
||||||
|
icon: mistral.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: mistral
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.mistral.ai/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -1,561 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import typing
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import openai.types.chat.chat_completion as chat_completion
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from .. import entities, errors, requester
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
|
|
||||||
class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
|
||||||
"""ModelScope ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api-inference.modelscope.cn/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
self.client = openai.AsyncClient(
|
|
||||||
api_key='',
|
|
||||||
base_url=self.requester_cfg['base_url'],
|
|
||||||
timeout=self.requester_cfg['timeout'],
|
|
||||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _mask_api_key(self, api_key: str | None) -> str:
|
|
||||||
if not api_key:
|
|
||||||
return ''
|
|
||||||
if len(api_key) <= 8:
|
|
||||||
return '****'
|
|
||||||
return f'{api_key[:4]}...{api_key[-4:]}'
|
|
||||||
|
|
||||||
def _infer_model_type(self, model_id: str) -> str:
|
|
||||||
normalized_model_id = (model_id or '').lower()
|
|
||||||
embedding_keywords = (
|
|
||||||
'embedding',
|
|
||||||
'embed',
|
|
||||||
'bge-',
|
|
||||||
'e5-',
|
|
||||||
'm3e',
|
|
||||||
'gte-',
|
|
||||||
'multilingual-e5',
|
|
||||||
'text-embedding',
|
|
||||||
)
|
|
||||||
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
|
|
||||||
|
|
||||||
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
|
|
||||||
normalized_model_id = (model_id or '').lower()
|
|
||||||
abilities: set[str] = set()
|
|
||||||
|
|
||||||
def _flatten(value: typing.Any) -> list[str]:
|
|
||||||
if value is None:
|
|
||||||
return []
|
|
||||||
if isinstance(value, str):
|
|
||||||
return [value.lower()]
|
|
||||||
if isinstance(value, dict):
|
|
||||||
flattened: list[str] = []
|
|
||||||
for nested_value in value.values():
|
|
||||||
flattened.extend(_flatten(nested_value))
|
|
||||||
return flattened
|
|
||||||
if isinstance(value, (list, tuple, set)):
|
|
||||||
flattened: list[str] = []
|
|
||||||
for nested_value in value:
|
|
||||||
flattened.extend(_flatten(nested_value))
|
|
||||||
return flattened
|
|
||||||
return [str(value).lower()]
|
|
||||||
|
|
||||||
capability_tokens = _flatten(item.get('capabilities'))
|
|
||||||
capability_tokens.extend(_flatten(item.get('modalities')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('input_modalities')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('output_modalities')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('supported_generation_methods')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('supported_parameters')))
|
|
||||||
capability_tokens.extend(_flatten(item.get('architecture')))
|
|
||||||
|
|
||||||
combined_tokens = capability_tokens + [normalized_model_id]
|
|
||||||
|
|
||||||
vision_keywords = ('vision', 'image', 'file', 'video', 'multimodal', 'vl', 'ocr', 'omni')
|
|
||||||
function_call_keywords = ('function', 'tool', 'tools', 'tool_choice', 'tool_call', 'tool-use', 'tool_use')
|
|
||||||
|
|
||||||
if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens):
|
|
||||||
abilities.add('vision')
|
|
||||||
|
|
||||||
if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens):
|
|
||||||
abilities.add('func_call')
|
|
||||||
|
|
||||||
return sorted(abilities)
|
|
||||||
|
|
||||||
def _normalize_modalities(self, value: typing.Any) -> list[str]:
|
|
||||||
normalized: list[str] = []
|
|
||||||
|
|
||||||
def _collect(item: typing.Any):
|
|
||||||
if item is None:
|
|
||||||
return
|
|
||||||
if isinstance(item, str):
|
|
||||||
for part in item.replace('->', ',').replace('+', ',').split(','):
|
|
||||||
token = part.strip().lower()
|
|
||||||
if token and token not in normalized:
|
|
||||||
normalized.append(token)
|
|
||||||
return
|
|
||||||
if isinstance(item, dict):
|
|
||||||
for nested in item.values():
|
|
||||||
_collect(nested)
|
|
||||||
return
|
|
||||||
if isinstance(item, (list, tuple, set)):
|
|
||||||
for nested in item:
|
|
||||||
_collect(nested)
|
|
||||||
return
|
|
||||||
|
|
||||||
_collect(value)
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]:
|
|
||||||
display_name = item.get('name')
|
|
||||||
if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id:
|
|
||||||
display_name = ''
|
|
||||||
|
|
||||||
description = item.get('description')
|
|
||||||
if not isinstance(description, str) or not description.strip():
|
|
||||||
description = ''
|
|
||||||
|
|
||||||
context_length = item.get('context_length')
|
|
||||||
if context_length is None and isinstance(item.get('top_provider'), dict):
|
|
||||||
context_length = item['top_provider'].get('context_length')
|
|
||||||
|
|
||||||
if not isinstance(context_length, int):
|
|
||||||
try:
|
|
||||||
context_length = int(context_length) if context_length is not None else None
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
context_length = None
|
|
||||||
|
|
||||||
input_modalities = self._normalize_modalities(item.get('input_modalities'))
|
|
||||||
output_modalities = self._normalize_modalities(item.get('output_modalities'))
|
|
||||||
|
|
||||||
if isinstance(item.get('architecture'), dict):
|
|
||||||
if not input_modalities:
|
|
||||||
input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities'))
|
|
||||||
if not output_modalities:
|
|
||||||
output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities'))
|
|
||||||
|
|
||||||
owned_by = item.get('owned_by')
|
|
||||||
if not isinstance(owned_by, str) or not owned_by.strip():
|
|
||||||
owned_by = ''
|
|
||||||
|
|
||||||
return {
|
|
||||||
'display_name': display_name or None,
|
|
||||||
'description': description or None,
|
|
||||||
'context_length': context_length,
|
|
||||||
'owned_by': owned_by or None,
|
|
||||||
'input_modalities': input_modalities,
|
|
||||||
'output_modalities': output_modalities,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
|
||||||
headers = {}
|
|
||||||
if api_key:
|
|
||||||
headers['Authorization'] = f'Bearer {api_key}'
|
|
||||||
|
|
||||||
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models'
|
|
||||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
|
||||||
response = await client.get(models_url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
payload = response.json()
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for item in payload.get('data', []):
|
|
||||||
model_id = item.get('id')
|
|
||||||
if not model_id:
|
|
||||||
continue
|
|
||||||
models.append(
|
|
||||||
{
|
|
||||||
'id': model_id,
|
|
||||||
'name': model_id,
|
|
||||||
'type': self._infer_model_type(model_id),
|
|
||||||
'abilities': self._infer_model_abilities(item, model_id),
|
|
||||||
**self._extract_scan_metadata(item, model_id),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
|
||||||
return {
|
|
||||||
'models': models,
|
|
||||||
'debug': {
|
|
||||||
'request': {
|
|
||||||
'method': 'GET',
|
|
||||||
'url': models_url,
|
|
||||||
'headers': {
|
|
||||||
'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'response': payload,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _req(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
args: dict,
|
|
||||||
extra_body: dict = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> list[dict[str, typing.Any]]:
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
chunk = None
|
|
||||||
|
|
||||||
pending_content = ''
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
|
|
||||||
resp_gen: openai.AsyncStream = await self.client.chat.completions.create(**args, extra_body=extra_body)
|
|
||||||
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
tool_id = ''
|
|
||||||
tool_name = ''
|
|
||||||
message_delta = {}
|
|
||||||
async for chunk in resp_gen:
|
|
||||||
if not chunk or not chunk.id or not chunk.choices or not chunk.choices[0] or not chunk.choices[0].delta:
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = chunk.choices[0].delta.model_dump() if hasattr(chunk.choices[0], 'delta') else {}
|
|
||||||
reasoning_content = delta.get('reasoning_content')
|
|
||||||
# 处理 reasoning_content
|
|
||||||
if reasoning_content:
|
|
||||||
# accumulated_reasoning += reasoning_content
|
|
||||||
# 如果设置了 remove_think,跳过 reasoning_content
|
|
||||||
if remove_think:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
|
||||||
if not thinking_started:
|
|
||||||
thinking_started = True
|
|
||||||
pending_content += '<think>\n' + reasoning_content
|
|
||||||
else:
|
|
||||||
# 继续输出 reasoning_content
|
|
||||||
pending_content += reasoning_content
|
|
||||||
elif thinking_started and not thinking_ended and delta.get('content'):
|
|
||||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
|
||||||
thinking_ended = True
|
|
||||||
pending_content += '\n</think>\n' + delta.get('content')
|
|
||||||
|
|
||||||
if delta.get('content') is not None:
|
|
||||||
pending_content += delta.get('content')
|
|
||||||
|
|
||||||
if delta.get('tool_calls') is not None:
|
|
||||||
for tool_call in delta.get('tool_calls'):
|
|
||||||
if tool_call['id'] != '':
|
|
||||||
tool_id = tool_call['id']
|
|
||||||
if tool_call['function']['name'] is not None:
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
if tool_call['function']['arguments'] is None:
|
|
||||||
continue
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
tool_call['name'] = tool_name
|
|
||||||
for tc in tool_calls:
|
|
||||||
if tc['index'] == tool_call['index']:
|
|
||||||
tc['function']['arguments'] += tool_call['function']['arguments']
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
if chunk.choices[0].finish_reason is not None:
|
|
||||||
break
|
|
||||||
message_delta['content'] = pending_content
|
|
||||||
message_delta['role'] = 'assistant'
|
|
||||||
|
|
||||||
message_delta['tool_calls'] = tool_calls if tool_calls else None
|
|
||||||
return [message_delta]
|
|
||||||
|
|
||||||
async def _make_msg(
|
|
||||||
self,
|
|
||||||
chat_completion: list[dict[str, typing.Any]],
|
|
||||||
) -> provider_message.Message:
|
|
||||||
chatcmpl_message = chat_completion[0]
|
|
||||||
|
|
||||||
# 确保 role 字段存在且不为 None
|
|
||||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
|
||||||
chatcmpl_message['role'] = 'assistant'
|
|
||||||
|
|
||||||
message = provider_message.Message(**chatcmpl_message)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _closure(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[provider_message.Message, dict]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
|
|
||||||
# 发送请求
|
|
||||||
resp = await self._req(query, args, extra_body=extra_args, remove_think=remove_think)
|
|
||||||
|
|
||||||
# 处理请求结果
|
|
||||||
message = await self._make_msg(resp)
|
|
||||||
|
|
||||||
# ModelScope uses streaming, usage info not available
|
|
||||||
usage_info = {}
|
|
||||||
|
|
||||||
return message, usage_info
|
|
||||||
|
|
||||||
async def _req_stream(
|
|
||||||
self,
|
|
||||||
args: dict,
|
|
||||||
extra_body: dict = {},
|
|
||||||
) -> chat_completion.ChatCompletion:
|
|
||||||
async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def _closure_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# 流式处理状态
|
|
||||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
# accumulated_reasoning = '' # 仅用于判断何时结束思维链
|
|
||||||
|
|
||||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
|
||||||
# 解析 chunk 数据
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
|
||||||
finish_reason = getattr(choice, 'finish_reason', None)
|
|
||||||
else:
|
|
||||||
delta = {}
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
|
||||||
if 'role' in delta and delta['role']:
|
|
||||||
role = delta['role']
|
|
||||||
|
|
||||||
# 获取增量内容
|
|
||||||
delta_content = delta.get('content', '')
|
|
||||||
reasoning_content = delta.get('reasoning_content', '')
|
|
||||||
|
|
||||||
# 处理 reasoning_content
|
|
||||||
if reasoning_content:
|
|
||||||
# accumulated_reasoning += reasoning_content
|
|
||||||
# 如果设置了 remove_think,跳过 reasoning_content
|
|
||||||
if remove_think:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
|
||||||
if not thinking_started:
|
|
||||||
thinking_started = True
|
|
||||||
delta_content = '<think>\n' + reasoning_content
|
|
||||||
else:
|
|
||||||
# 继续输出 reasoning_content
|
|
||||||
delta_content = reasoning_content
|
|
||||||
elif thinking_started and not thinking_ended and delta_content:
|
|
||||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
|
||||||
thinking_ended = True
|
|
||||||
delta_content = '\n</think>\n' + delta_content
|
|
||||||
|
|
||||||
# 处理 content 中已有的 <think> 标签(如果需要移除)
|
|
||||||
# if delta_content and remove_think and '<think>' in delta_content:
|
|
||||||
# import re
|
|
||||||
#
|
|
||||||
# # 移除 <think> 标签及其内容
|
|
||||||
# delta_content = re.sub(r'<think>.*?</think>', '', delta_content, flags=re.DOTALL)
|
|
||||||
|
|
||||||
# 处理工具调用增量
|
|
||||||
if delta.get('tool_calls'):
|
|
||||||
for tool_call in delta['tool_calls']:
|
|
||||||
if tool_call['id'] != '':
|
|
||||||
tool_id = tool_call['id']
|
|
||||||
if tool_call['function']['name'] is not None:
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
|
|
||||||
if tool_call['type'] is None:
|
|
||||||
tool_call['type'] = 'function'
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
tool_call['function']['arguments'] = (
|
|
||||||
'' if tool_call['function']['arguments'] is None else tool_call['function']['arguments']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': delta.get('tool_calls'),
|
|
||||||
'is_final': bool(finish_reason),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
# return
|
|
||||||
|
|
||||||
async def invoke_llm(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: entities.LLMModelInfo,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
|
||||||
for m in messages:
|
|
||||||
msg_dict = m.dict(exclude_none=True)
|
|
||||||
content = msg_dict.get('content')
|
|
||||||
if isinstance(content, list):
|
|
||||||
# 检查 content 列表中是否每个部分都是文本
|
|
||||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
|
||||||
# 将所有文本部分合并为一个字符串
|
|
||||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await self._closure(
|
|
||||||
query=query,
|
|
||||||
req_messages=req_messages,
|
|
||||||
use_model=model,
|
|
||||||
use_funcs=funcs,
|
|
||||||
extra_args=extra_args,
|
|
||||||
remove_think=remove_think,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
except openai.BadRequestError as e:
|
|
||||||
if 'context_length_exceeded' in e.message:
|
|
||||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
|
||||||
else:
|
|
||||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
|
||||||
except openai.AuthenticationError as e:
|
|
||||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
|
||||||
except openai.NotFoundError as e:
|
|
||||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
|
||||||
except openai.APIError as e:
|
|
||||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
|
||||||
|
|
||||||
async def invoke_llm_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
|
||||||
for m in messages:
|
|
||||||
msg_dict = m.dict(exclude_none=True)
|
|
||||||
content = msg_dict.get('content')
|
|
||||||
if isinstance(content, list):
|
|
||||||
# 检查 content 列表中是否每个部分都是文本
|
|
||||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
|
||||||
# 将所有文本部分合并为一个字符串
|
|
||||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for item in self._closure_stream(
|
|
||||||
query=query,
|
|
||||||
req_messages=req_messages,
|
|
||||||
use_model=model,
|
|
||||||
use_funcs=funcs,
|
|
||||||
extra_args=extra_args,
|
|
||||||
remove_think=remove_think,
|
|
||||||
):
|
|
||||||
yield item
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
except openai.BadRequestError as e:
|
|
||||||
if 'context_length_exceeded' in e.message:
|
|
||||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
|
||||||
else:
|
|
||||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
|
||||||
except openai.AuthenticationError as e:
|
|
||||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
|
||||||
except openai.NotFoundError as e:
|
|
||||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
|
||||||
except openai.APIError as e:
|
|
||||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 魔搭社区
|
zh_Hans: 魔搭社区
|
||||||
icon: modelscope.svg
|
icon: modelscope.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -31,6 +32,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,67 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
from .. import requester
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
|
|
||||||
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""Moonshot ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.moonshot.cn/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _closure(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[provider_message.Message, dict]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages
|
|
||||||
|
|
||||||
# deepseek 不支持多模态,把content都转换成纯文字
|
|
||||||
for m in messages:
|
|
||||||
if 'content' in m and isinstance(m['content'], list):
|
|
||||||
m['content'] = ' '.join([c['text'] for c in m['content']])
|
|
||||||
|
|
||||||
# 删除空的,不知道干嘛的,直接删了。
|
|
||||||
# messages = [m for m in messages if m["content"].strip() != "" and ('tool_calls' not in m or not m['tool_calls'])]
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
|
|
||||||
# 发送请求
|
|
||||||
resp = await self._req(args, extra_body=extra_args)
|
|
||||||
|
|
||||||
# 处理请求结果
|
|
||||||
message = await self._make_msg(resp, remove_think)
|
|
||||||
|
|
||||||
# Extract token usage from response
|
|
||||||
usage_info = {}
|
|
||||||
if hasattr(resp, 'usage') and resp.usage:
|
|
||||||
usage_info['input_tokens'] = resp.usage.prompt_tokens or 0
|
|
||||||
usage_info['output_tokens'] = resp.usage.completion_tokens or 0
|
|
||||||
usage_info['total_tokens'] = resp.usage.total_tokens or 0
|
|
||||||
|
|
||||||
return message, usage_info
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 月之暗面
|
zh_Hans: 月之暗面
|
||||||
icon: moonshot.png
|
icon: moonshot.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class NewAPIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""New API ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'http://localhost:3000/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: New API
|
zh_Hans: New API
|
||||||
icon: newapi.png
|
icon: newapi.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
@@ -1,314 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import typing
|
|
||||||
from typing import Union, Mapping, Any, AsyncIterator
|
|
||||||
import uuid
|
|
||||||
import json
|
|
||||||
|
|
||||||
import ollama
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from .. import errors, requester
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
|
|
||||||
REQUESTER_NAME: str = 'ollama-chat'
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaChatCompletions(requester.ProviderAPIRequester):
|
|
||||||
"""Ollama平台 ChatCompletion API请求器"""
|
|
||||||
|
|
||||||
client: ollama.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'http://127.0.0.1:11434',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url']
|
|
||||||
self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout'])
|
|
||||||
|
|
||||||
def _infer_model_type(self, model_id: str) -> str:
|
|
||||||
normalized_model_id = (model_id or '').lower()
|
|
||||||
embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
|
|
||||||
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
|
|
||||||
|
|
||||||
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
|
|
||||||
normalized_model_id = (model_id or '').lower()
|
|
||||||
abilities: set[str] = set()
|
|
||||||
details = item.get('details', {}) or {}
|
|
||||||
families = details.get('families', []) or []
|
|
||||||
tokens = [normalized_model_id, str(details.get('family', '')).lower()]
|
|
||||||
tokens.extend(str(family).lower() for family in families)
|
|
||||||
|
|
||||||
if any(keyword in token for token in tokens for keyword in ('vision', 'vl', 'omni', 'llava', 'ocr')):
|
|
||||||
abilities.add('vision')
|
|
||||||
if any(keyword in token for token in tokens for keyword in ('tool', 'function')):
|
|
||||||
abilities.add('func_call')
|
|
||||||
return sorted(abilities)
|
|
||||||
|
|
||||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
|
||||||
del api_key
|
|
||||||
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/api/tags'
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
|
|
||||||
response = await client.get(models_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
payload = response.json()
|
|
||||||
|
|
||||||
models: list[dict[str, typing.Any]] = []
|
|
||||||
for item in payload.get('models', []):
|
|
||||||
model_id = item.get('model') or item.get('name')
|
|
||||||
if not model_id:
|
|
||||||
continue
|
|
||||||
models.append(
|
|
||||||
{
|
|
||||||
'id': model_id,
|
|
||||||
'name': item.get('name', model_id),
|
|
||||||
'type': self._infer_model_type(model_id),
|
|
||||||
'abilities': self._infer_model_abilities(item, model_id),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
|
|
||||||
return {
|
|
||||||
'models': models,
|
|
||||||
'debug': {
|
|
||||||
'request': {
|
|
||||||
'method': 'GET',
|
|
||||||
'url': models_url,
|
|
||||||
},
|
|
||||||
'response': payload,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _req(
|
|
||||||
self,
|
|
||||||
args: dict,
|
|
||||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
|
||||||
return await self.client.chat(**args)
|
|
||||||
|
|
||||||
async def _closure(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
args = extra_args.copy()
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
messages: list[dict] = req_messages.copy()
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
text_content: list = []
|
|
||||||
image_urls: list = []
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'text':
|
|
||||||
text_content.append(me['text'])
|
|
||||||
elif me['type'] == 'image_base64':
|
|
||||||
image_urls.append(me['image_base64'])
|
|
||||||
|
|
||||||
msg['content'] = '\n'.join(text_content)
|
|
||||||
msg['images'] = [url.split(',')[1] for url in image_urls]
|
|
||||||
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
|
||||||
for tool_call in msg['tool_calls']:
|
|
||||||
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
|
|
||||||
args['messages'] = messages
|
|
||||||
|
|
||||||
args['tools'] = []
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
resp = await self._req(args)
|
|
||||||
message: provider_message.Message = await self._make_msg(resp)
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _make_msg(self, chat_completions: ollama.ChatResponse) -> provider_message.Message:
|
|
||||||
message: ollama.Message = chat_completions.message
|
|
||||||
if message is None:
|
|
||||||
raise ValueError("chat_completions must contain a 'message' field")
|
|
||||||
|
|
||||||
ret_msg: provider_message.Message = None
|
|
||||||
|
|
||||||
if message.content is not None:
|
|
||||||
ret_msg = provider_message.Message(role='assistant', content=message.content)
|
|
||||||
if message.tool_calls is not None and len(message.tool_calls) > 0:
|
|
||||||
tool_calls: list[provider_message.ToolCall] = []
|
|
||||||
|
|
||||||
for tool_call in message.tool_calls:
|
|
||||||
tool_calls.append(
|
|
||||||
provider_message.ToolCall(
|
|
||||||
id=uuid.uuid4().hex,
|
|
||||||
type='function',
|
|
||||||
function=provider_message.FunctionCall(
|
|
||||||
name=tool_call.function.name,
|
|
||||||
arguments=json.dumps(tool_call.function.arguments),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
ret_msg.tool_calls = tool_calls
|
|
||||||
|
|
||||||
return ret_msg
|
|
||||||
|
|
||||||
async def _prepare_messages(
|
|
||||||
self,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Prepare messages for Ollama API request."""
|
|
||||||
req_messages: list = []
|
|
||||||
for m in messages:
|
|
||||||
msg_dict: dict = m.dict(exclude_none=True)
|
|
||||||
content: Any = msg_dict.get('content')
|
|
||||||
if isinstance(content, list):
|
|
||||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
|
||||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
|
||||||
req_messages.append(msg_dict)
|
|
||||||
return req_messages
|
|
||||||
|
|
||||||
async def invoke_llm(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
req_messages = await self._prepare_messages(messages)
|
|
||||||
try:
|
|
||||||
return await self._closure(
|
|
||||||
query=query,
|
|
||||||
req_messages=req_messages,
|
|
||||||
use_model=model,
|
|
||||||
use_funcs=funcs,
|
|
||||||
extra_args=extra_args,
|
|
||||||
remove_think=remove_think,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
|
|
||||||
async def invoke_llm_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
model: requester.RuntimeLLMModel,
|
|
||||||
messages: typing.List[provider_message.Message],
|
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
req_messages = await self._prepare_messages(messages)
|
|
||||||
|
|
||||||
try:
|
|
||||||
args = extra_args.copy()
|
|
||||||
args['model'] = model.model_entity.name
|
|
||||||
|
|
||||||
# Process messages for Ollama format
|
|
||||||
msgs: list[dict] = req_messages.copy()
|
|
||||||
for msg in msgs:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
text_content: list = []
|
|
||||||
image_urls: list = []
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'text':
|
|
||||||
text_content.append(me['text'])
|
|
||||||
elif me['type'] == 'image_base64':
|
|
||||||
image_urls.append(me['image_base64'])
|
|
||||||
msg['content'] = '\n'.join(text_content)
|
|
||||||
msg['images'] = [url.split(',')[1] for url in image_urls]
|
|
||||||
if 'tool_calls' in msg:
|
|
||||||
for tool_call in msg['tool_calls']:
|
|
||||||
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
|
|
||||||
args['messages'] = msgs
|
|
||||||
|
|
||||||
args['tools'] = []
|
|
||||||
if funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant'
|
|
||||||
|
|
||||||
async for chunk in await self.client.chat(**args):
|
|
||||||
message: ollama.Message = chunk.message
|
|
||||||
done = chunk.done
|
|
||||||
|
|
||||||
delta_content = message.content or ''
|
|
||||||
reasoning_content = getattr(message, 'thinking', '') or ''
|
|
||||||
|
|
||||||
# Handle reasoning/thinking content
|
|
||||||
if reasoning_content:
|
|
||||||
if remove_think:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not thinking_started:
|
|
||||||
thinking_started = True
|
|
||||||
delta_content = '<think>\n' + reasoning_content
|
|
||||||
else:
|
|
||||||
delta_content = reasoning_content
|
|
||||||
elif thinking_started and not thinking_ended and delta_content:
|
|
||||||
thinking_ended = True
|
|
||||||
delta_content = '\n</think>\n' + delta_content
|
|
||||||
|
|
||||||
# Handle tool calls
|
|
||||||
tool_calls_data = None
|
|
||||||
if message.tool_calls:
|
|
||||||
tool_calls_data = []
|
|
||||||
for tc in message.tool_calls:
|
|
||||||
tool_calls_data.append(
|
|
||||||
{
|
|
||||||
'id': uuid.uuid4().hex,
|
|
||||||
'type': 'function',
|
|
||||||
'function': {
|
|
||||||
'name': tc.function.name,
|
|
||||||
'arguments': json.dumps(tc.function.arguments),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Skip empty first chunk
|
|
||||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not tool_calls_data:
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': tool_calls_data,
|
|
||||||
'is_final': bool(done),
|
|
||||||
}
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise errors.RequesterError('请求超时')
|
|
||||||
|
|
||||||
async def invoke_embedding(
|
|
||||||
self,
|
|
||||||
model: requester.RuntimeEmbeddingModel,
|
|
||||||
input_text: list[str],
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
) -> list[list[float]]:
|
|
||||||
return (
|
|
||||||
await self.client.embed(
|
|
||||||
model=model.model_entity.name,
|
|
||||||
input=input_text,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
).embeddings
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: Ollama
|
zh_Hans: Ollama
|
||||||
icon: ollama.svg
|
icon: ollama.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: ollama
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import modelscopechatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
|
|
||||||
"""OpenRouter ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://openrouter.ai/api/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
|
||||||
original_base_url = self.requester_cfg.get('base_url', '')
|
|
||||||
self.requester_cfg['base_url'] = 'https://openrouter.ai/api/v1'
|
|
||||||
try:
|
|
||||||
return await super().scan_models(api_key)
|
|
||||||
finally:
|
|
||||||
self.requester_cfg['base_url'] = original_base_url
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: OpenRouter
|
zh_Hans: OpenRouter
|
||||||
icon: openrouter.svg
|
icon: openrouter.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -25,6 +26,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,208 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
from .. import requester
|
|
||||||
import openai.types.chat.chat_completion as chat_completion
|
|
||||||
import re
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
||||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
||||||
|
|
||||||
|
|
||||||
class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""欧派云 ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.ppinfra.com/v3/openai',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
is_think: bool = False
|
|
||||||
|
|
||||||
async def _make_msg(
|
|
||||||
self,
|
|
||||||
chat_completion: chat_completion.ChatCompletion,
|
|
||||||
remove_think: bool,
|
|
||||||
) -> provider_message.Message:
|
|
||||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
|
||||||
# print(chatcmpl_message.keys(), chatcmpl_message.values())
|
|
||||||
|
|
||||||
# 确保 role 字段存在且不为 None
|
|
||||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
|
||||||
chatcmpl_message['role'] = 'assistant'
|
|
||||||
|
|
||||||
reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None
|
|
||||||
|
|
||||||
# deepseek的reasoner模型
|
|
||||||
chatcmpl_message['content'] = await self._process_thinking_content(
|
|
||||||
chatcmpl_message['content'], reasoning_content, remove_think
|
|
||||||
)
|
|
||||||
|
|
||||||
# 移除 reasoning_content 字段,避免传递给 Message
|
|
||||||
if 'reasoning_content' in chatcmpl_message:
|
|
||||||
del chatcmpl_message['reasoning_content']
|
|
||||||
|
|
||||||
message = provider_message.Message(**chatcmpl_message)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _process_thinking_content(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
reasoning_content: str = None,
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""处理思维链内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 原始内容
|
|
||||||
reasoning_content: reasoning_content 字段内容
|
|
||||||
remove_think: 是否移除思维链
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
处理后的内容
|
|
||||||
"""
|
|
||||||
if remove_think:
|
|
||||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
|
|
||||||
else:
|
|
||||||
if reasoning_content is not None:
|
|
||||||
content = '<think>\n' + reasoning_content + '\n</think>\n' + content
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def _make_msg_chunk(
|
|
||||||
self,
|
|
||||||
delta: dict[str, typing.Any],
|
|
||||||
idx: int,
|
|
||||||
) -> provider_message.MessageChunk:
|
|
||||||
# 处理流式chunk和完整响应的差异
|
|
||||||
# print(chat_completion.choices[0])
|
|
||||||
|
|
||||||
# 确保 role 字段存在且不为 None
|
|
||||||
if 'role' not in delta or delta['role'] is None:
|
|
||||||
delta['role'] = 'assistant'
|
|
||||||
|
|
||||||
reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
|
|
||||||
|
|
||||||
delta['content'] = '' if delta['content'] is None else delta['content']
|
|
||||||
# print(reasoning_content)
|
|
||||||
|
|
||||||
# deepseek的reasoner模型
|
|
||||||
|
|
||||||
if reasoning_content is not None:
|
|
||||||
delta['content'] += reasoning_content
|
|
||||||
|
|
||||||
message = provider_message.MessageChunk(**delta)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def _closure_stream(
|
|
||||||
self,
|
|
||||||
query: pipeline_query.Query,
|
|
||||||
req_messages: list[dict],
|
|
||||||
use_model: requester.RuntimeLLMModel,
|
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
|
||||||
extra_args: dict[str, typing.Any] = {},
|
|
||||||
remove_think: bool = False,
|
|
||||||
) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
|
||||||
self.client.api_key = use_model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {}
|
|
||||||
args['model'] = use_model.model_entity.name
|
|
||||||
|
|
||||||
if use_funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
# 设置此次请求中的messages
|
|
||||||
messages = req_messages.copy()
|
|
||||||
|
|
||||||
# 检查vision
|
|
||||||
for msg in messages:
|
|
||||||
if 'content' in msg and isinstance(msg['content'], list):
|
|
||||||
for me in msg['content']:
|
|
||||||
if me['type'] == 'image_base64':
|
|
||||||
me['image_url'] = {'url': me['image_base64']}
|
|
||||||
me['type'] = 'image_url'
|
|
||||||
del me['image_base64']
|
|
||||||
|
|
||||||
args['messages'] = messages
|
|
||||||
args['stream'] = True
|
|
||||||
|
|
||||||
# tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
|
||||||
chunk_idx = 0
|
|
||||||
thinking_started = False
|
|
||||||
thinking_ended = False
|
|
||||||
role = 'assistant' # 默认角色
|
|
||||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
|
||||||
# 解析 chunk 数据
|
|
||||||
if hasattr(chunk, 'choices') and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
|
||||||
finish_reason = getattr(choice, 'finish_reason', None)
|
|
||||||
else:
|
|
||||||
delta = {}
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
|
||||||
if 'role' in delta and delta['role']:
|
|
||||||
role = delta['role']
|
|
||||||
|
|
||||||
# 获取增量内容
|
|
||||||
delta_content = delta.get('content', '')
|
|
||||||
# reasoning_content = delta.get('reasoning_content', '')
|
|
||||||
|
|
||||||
if remove_think:
|
|
||||||
if delta['content'] is not None:
|
|
||||||
if '<think>' in delta['content'] and not thinking_started and not thinking_ended:
|
|
||||||
thinking_started = True
|
|
||||||
continue
|
|
||||||
elif delta['content'] == r'</think>' and not thinking_ended:
|
|
||||||
thinking_ended = True
|
|
||||||
continue
|
|
||||||
elif thinking_ended and delta['content'] == '\n\n' and thinking_started:
|
|
||||||
thinking_started = False
|
|
||||||
continue
|
|
||||||
elif thinking_started and not thinking_ended:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# delta_tool_calls = None
|
|
||||||
if delta.get('tool_calls'):
|
|
||||||
for tool_call in delta['tool_calls']:
|
|
||||||
if tool_call['id'] and tool_call['function']['name']:
|
|
||||||
tool_id = tool_call['id']
|
|
||||||
tool_name = tool_call['function']['name']
|
|
||||||
|
|
||||||
if tool_call['id'] is None:
|
|
||||||
tool_call['id'] = tool_id
|
|
||||||
if tool_call['function']['name'] is None:
|
|
||||||
tool_call['function']['name'] = tool_name
|
|
||||||
if tool_call['function']['arguments'] is None:
|
|
||||||
tool_call['function']['arguments'] = ''
|
|
||||||
if tool_call['type'] is None:
|
|
||||||
tool_call['type'] = 'function'
|
|
||||||
|
|
||||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
|
||||||
if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'):
|
|
||||||
chunk_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 构建 MessageChunk - 只包含增量内容
|
|
||||||
chunk_data = {
|
|
||||||
'role': role,
|
|
||||||
'content': delta_content if delta_content else None,
|
|
||||||
'tool_calls': delta.get('tool_calls'),
|
|
||||||
'is_final': bool(finish_reason),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除 None 值
|
|
||||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
|
||||||
|
|
||||||
yield provider_message.MessageChunk(**chunk_data)
|
|
||||||
chunk_idx += 1
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 派欧云
|
zh_Hans: 派欧云
|
||||||
icon: ppio.svg
|
icon: ppio.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class QHAIGCChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""启航 AI ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.qhaigc.com/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 启航 AI
|
zh_Hans: 启航 AI
|
||||||
icon: qhaigc.png
|
icon: qhaigc.png
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
@@ -1,8 +1,17 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg id="_图层_1" data-name="图层 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 334.84 76.22">
|
||||||
<rect width="24" height="24" rx="5" fill="#1E3A5F"/>
|
<defs>
|
||||||
<path d="M6 12C6 8.68629 8.68629 6 12 6C15.3137 6 18 8.68629 18 12" stroke="#4FC3F7" stroke-width="2" stroke-linecap="round"/>
|
<style>
|
||||||
<path d="M18 12C18 15.3137 15.3137 18 12 18C8.68629 18 6 15.3137 6 12" stroke="#81D4FA" stroke-width="2" stroke-linecap="round"/>
|
.cls-1 {
|
||||||
<circle cx="12" cy="12" r="2" fill="#4FC3F7"/>
|
fill: currentColor;
|
||||||
<circle cx="6" cy="12" r="1.5" fill="#81D4FA"/>
|
}
|
||||||
<circle cx="18" cy="12" r="1.5" fill="#4FC3F7"/>
|
</style>
|
||||||
</svg>
|
</defs>
|
||||||
|
<path class="cls-1" d="M308.56,23.63c-5.04,0-9.73,1.43-13.73,3.88V1.08l-12.56,4.61v70h12.56v-3.35c4,2.46,8.71,3.88,13.73,3.88,14.49,0,26.29-11.79,26.29-26.29s-11.79-26.29-26.29-26.29h0ZM308.56,63.88c-6.87,0-12.57-4.98-13.73-11.51v-4.91c1.16-6.54,6.88-11.51,13.73-11.51,7.7,0,13.96,6.26,13.96,13.96s-6.26,13.96-13.96,13.96Z"></path>
|
||||||
|
<path class="cls-1" d="M255.54,5.69v21.83c-4-2.46-8.71-3.88-13.73-3.88-14.49,0-26.29,11.79-26.29,26.29s11.79,26.29,26.29,26.29c5.04,0,9.73-1.43,13.73-3.88v3.35h12.56V1.08l-12.56,4.61ZM241.81,63.88c-7.7,0-13.96-6.26-13.96-13.96s6.26-13.96,13.96-13.96c6.87,0,12.57,4.98,13.73,11.51v4.91c-1.16,6.54-6.88,11.51-13.73,11.51Z"></path>
|
||||||
|
<polygon class="cls-1" points="195.35 52.2 186.65 61.17 200.64 75.62 209.32 75.62 218.01 75.62 195.35 52.2"></polygon>
|
||||||
|
<path class="cls-1" d="M167.14,4.59c.65,3.99.68,8.04.03,12.15-.03.17.16.3.31.21,3.82-2.21,7.82-3.69,12.01-4.33.12-.02.19-.13.17-.23-.68-4.13-.61-8.18-.03-12.16.02-.17-.16-.3-.31-.2-4.01,2.31-8.01,3.81-12.01,4.34-.12.01-.19.12-.17.23h0Z"></path>
|
||||||
|
<path class="cls-1" d="M198.75,24.09l-19.07,19.72v-25.57c-4.49.67-8.7,2.11-12.56,4.57v52.83h12.56v-13.87l3.78-3.9.02.02,8.68-8.97-.02-.02,23.98-24.8h-17.37Z"></path>
|
||||||
|
<path class="cls-1" d="M145.03,57.86c-2.56,4.45-7.17,7.2-12.13,7.2-5.96,0-11.3-3.96-13.32-9.85h38.87l.08-.42c.29-1.5.42-3.06.42-4.65,0-14.37-11.69-26.06-26.06-26.06s-26.06,11.69-26.06,26.06,11.69,26.06,26.06,26.06c9.63,0,18.43-5.28,22.98-13.77l.26-.49-11.1-4.08h-.01ZM132.88,35.19h.03c5.96,0,11.3,3.96,13.32,9.85h-26.67c2.02-5.89,7.36-9.85,13.32-9.85Z"></path>
|
||||||
|
<path class="cls-1" d="M75.92,65.07c-5.96,0-11.29-3.96-13.32-9.85h38.87l.08-.42c.29-1.5.42-3.06.42-4.65,0-14.37-11.69-26.06-26.06-26.06s-26.06,11.69-26.06,26.06,11.69,26.06,26.06,26.06c9.63,0,18.43-5.28,22.98-13.77l.26-.49h0l-11.1-4.08c-2.56,4.45-7.17,7.2-12.13,7.2h-.01ZM75.92,35.19h.03c5.96,0,11.29,3.96,13.32,9.85h-26.67c2.03-5.89,7.36-9.85,13.32-9.85Z"></path>
|
||||||
|
<path class="cls-1" d="M30.43,45.58l-10.2-1.91c-3.03-.56-4.98-2.25-4.98-4.33,0-1.5,1.61-4.35,7.68-4.35,5.53,0,9.36,3.5,10.25,6.26l10.9-4-.14-.42c-1.17-3.54-3.5-6.58-6.94-9.04-3.49-2.49-8.04-3.69-13.88-3.69s-10.98,1.5-14.78,4.34c-3.88,2.91-5.84,6.76-5.84,11.46,0,7.98,4.72,12.77,14.42,14.64l9.9,1.81c3.05.61,4.94,2.27,4.94,4.33,0,2.61-3.58,4.44-8.7,4.44-5.79,0-9.9-3.72-11.85-7.14L0,62.1l.14.39c1.3,3.8,3.89,7.07,7.7,9.71,3.78,2.6,8.65,3.95,14.51,3.98l.25.03c6.87,0,12.55-1.57,16.43-4.53,3.98-3.05,6-6.99,6-11.74,0-3.73-1.14-6.7-3.6-9.33-2.27-2.42-5.98-4.11-10.98-5.02h-.02Z"></path>
|
||||||
|
</svg>
|
||||||
|
Before Width: | Height: | Size: 569 B After Width: | Height: | Size: 2.7 KiB |
@@ -1,32 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
import openai.types.chat.chat_completion as chat_completion
|
|
||||||
|
|
||||||
|
|
||||||
class ShengSuanYunChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""胜算云(ModelSpot.AI) ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://router.shengsuanyun.com/api/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _req(
|
|
||||||
self,
|
|
||||||
args: dict,
|
|
||||||
extra_body: dict = {},
|
|
||||||
) -> chat_completion.ChatCompletion:
|
|
||||||
return await self.client.chat.completions.create(
|
|
||||||
**args,
|
|
||||||
extra_body=extra_body,
|
|
||||||
extra_headers={
|
|
||||||
'HTTP-Referer': 'https://langbot.app',
|
|
||||||
'X-Title': 'LangBot',
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 胜算云
|
zh_Hans: 胜算云
|
||||||
icon: shengsuanyun.svg
|
icon: shengsuanyun.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""SiliconFlow ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.siliconflow.cn/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 硅基流动
|
zh_Hans: 硅基流动
|
||||||
icon: siliconflow.svg
|
icon: siliconflow.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -25,6 +26,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class LangBotSpaceChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""LangBot Space ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.langbot.cloud/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: Space
|
zh_Hans: Space
|
||||||
icon: space.webp
|
icon: space.webp
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
5
src/langbot/pkg/provider/modelmgr/requesters/tencent.svg
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#0052D9"/>
|
||||||
|
<text x="30" y="28" font-family="Arial, sans-serif" font-size="10" font-weight="bold" fill="white" text-anchor="middle">Tencent</text>
|
||||||
|
<text x="30" y="40" font-family="Arial, sans-serif" font-size="8" fill="white" text-anchor="middle">Hunyuan</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 400 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: tencent-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: Tencent Hunyuan
|
||||||
|
zh_Hans: 腾讯混元
|
||||||
|
icon: tencent.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://hunyuan.tencentcloudapi.com/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#8B5CF6"/>
|
||||||
|
<text x="30" y="28" font-family="Arial, sans-serif" font-size="10" font-weight="bold" fill="white" text-anchor="middle">Together</text>
|
||||||
|
<text x="30" y="40" font-family="Arial, sans-serif" font-size="8" fill="white" text-anchor="middle">AI</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 396 B |
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: together-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: Together AI
|
||||||
|
zh_Hans: Together AI
|
||||||
|
icon: together.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: together_ai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.together.xyz/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 小马算力
|
zh_Hans: 小马算力
|
||||||
icon: tokenpony.svg
|
icon: tokenpony.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class TokenPonyChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""TokenPony ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.tokenpony.cn/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""火山方舟大模型平台 ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 火山方舟
|
zh_Hans: 火山方舟
|
||||||
icon: volcark.svg
|
icon: volcark.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Voyage</title><path d="M5.407 0v.066a.974.974 0 00-.048.245c-.011.11-.016.208-.016.295 0 .339.043.715.128 1.13.097.405.274.912.531 1.524l7.125 16.366L20.011 3.39c.161-.404.333-.846.515-1.327.182-.48.273-.966.273-1.458a1.406 1.406 0 00-.096-.54V0H24v.066c-.204.207-.45.578-.74 1.114-.29.535-.606 1.195-.949 1.982L13.095 24h-1.287L3.075 3.965c-.204-.47-.418-.923-.644-1.36-.214-.437-.418-.83-.61-1.18-.194-.36-.365-.66-.515-.9A5.666 5.666 0 001 .064V0h4.407z" fill="#012E33"></path></svg>
|
||||||
|
After Width: | Height: | Size: 610 B |
@@ -0,0 +1,32 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: voyageai-rerank
|
||||||
|
label:
|
||||||
|
en_US: Voyage AI
|
||||||
|
zh_Hans: Voyage AI
|
||||||
|
icon: voyageai.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.voyageai.com/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./chatcmpl.py
|
||||||
|
attr: OpenAIChatCompletions
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""xAI ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://api.x.ai/v1',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: xAI
|
zh_Hans: xAI
|
||||||
icon: xai.svg
|
icon: xai.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
5
src/langbot/pkg/provider/modelmgr/requesters/yi.svg
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="60" height="50" viewBox="0 0 60 50" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="60" height="50" rx="8" fill="#10B981"/>
|
||||||
|
<text x="30" y="28" font-family="Arial, sans-serif" font-size="10" font-weight="bold" fill="white" text-anchor="middle">01.AI</text>
|
||||||
|
<text x="30" y="40" font-family="Arial, sans-serif" font-size="8" fill="white" text-anchor="middle">Yi</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 393 B |
30
src/langbot/pkg/provider/modelmgr/requesters/yichatcmpl.yaml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: yi-chat-completions
|
||||||
|
label:
|
||||||
|
en_US: 01.AI Yi
|
||||||
|
zh_Hans: 零一万物
|
||||||
|
icon: yi.svg
|
||||||
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.lingyiwanwu.com/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from . import chatcmpl
|
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|
||||||
"""智谱AI ChatCompletion API 请求器"""
|
|
||||||
|
|
||||||
client: openai.AsyncClient
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
|
||||||
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
|
|
||||||
'timeout': 120,
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ metadata:
|
|||||||
zh_Hans: 智谱 AI
|
zh_Hans: 智谱 AI
|
||||||
icon: zhipuai.svg
|
icon: zhipuai.svg
|
||||||
spec:
|
spec:
|
||||||
|
litellm_provider: openai
|
||||||
config:
|
config:
|
||||||
- name: base_url
|
- name: base_url
|
||||||
label:
|
label:
|
||||||
@@ -24,6 +25,8 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -172,6 +172,45 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
if result:
|
if result:
|
||||||
all_results.extend(result)
|
all_results.extend(result)
|
||||||
|
|
||||||
|
# Rerank step: re-score results using a rerank model if configured
|
||||||
|
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||||
|
rerank_model_uuid = local_agent_config.get('rerank-model', '')
|
||||||
|
if rerank_model_uuid == '__none__':
|
||||||
|
rerank_model_uuid = ''
|
||||||
|
self.ap.logger.info(
|
||||||
|
f'Rerank config: model_uuid={rerank_model_uuid!r}, '
|
||||||
|
f'results={len(all_results)}, '
|
||||||
|
f'local_agent_keys={list(local_agent_config.keys())}'
|
||||||
|
)
|
||||||
|
if all_results and rerank_model_uuid:
|
||||||
|
try:
|
||||||
|
rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid)
|
||||||
|
rerank_top_k = int(local_agent_config.get('rerank-top-k', 5))
|
||||||
|
|
||||||
|
doc_texts = []
|
||||||
|
for entry in all_results:
|
||||||
|
text = ' '.join(c.text for c in entry.content if c.type == 'text' and c.text)
|
||||||
|
doc_texts.append(text)
|
||||||
|
|
||||||
|
doc_texts_capped = doc_texts[:64]
|
||||||
|
scores = await rerank_model.provider.invoke_rerank(
|
||||||
|
model=rerank_model,
|
||||||
|
query=user_message_text,
|
||||||
|
documents=doc_texts_capped,
|
||||||
|
)
|
||||||
|
|
||||||
|
scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True)
|
||||||
|
top_indices = [s['index'] for s in scored[:rerank_top_k] if s['index'] < len(all_results)]
|
||||||
|
all_results = [all_results[i] for i in top_indices]
|
||||||
|
|
||||||
|
self.ap.logger.info(
|
||||||
|
f'Rerank complete: {len(doc_texts)} docs reranked -> top {len(all_results)} kept (top_k={rerank_top_k})'
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
self.ap.logger.warning(f'Rerank model {rerank_model_uuid} not found, skipping rerank')
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Rerank failed, using original order: {e}')
|
||||||
|
|
||||||
final_user_message_text = ''
|
final_user_message_text = ''
|
||||||
|
|
||||||
if all_results:
|
if all_results:
|
||||||
|
|||||||
@@ -57,41 +57,6 @@ class ToolManager:
|
|||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list:
|
|
||||||
"""为anthropic生成函数列表
|
|
||||||
|
|
||||||
e.g.
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "get_stock_price",
|
|
||||||
"description": "Get the current stock price for a given ticker symbol.",
|
|
||||||
"input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"ticker": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["ticker"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
|
|
||||||
for function in use_funcs:
|
|
||||||
function_schema = {
|
|
||||||
'name': function.name,
|
|
||||||
'description': function.description,
|
|
||||||
'input_schema': function.parameters,
|
|
||||||
}
|
|
||||||
tools.append(function_schema)
|
|
||||||
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
||||||
"""执行函数调用"""
|
"""执行函数调用"""
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,9 @@
|
|||||||
"content": "You are a helpful assistant."
|
"content": "You are a helpful assistant."
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"knowledge-bases": []
|
"knowledge-bases": [],
|
||||||
|
"rerank-model": "",
|
||||||
|
"rerank-top-k": 5
|
||||||
},
|
},
|
||||||
"dify-service-api": {
|
"dify-service-api": {
|
||||||
"base-url": "https://api.dify.ai/v1",
|
"base-url": "https://api.dify.ai/v1",
|
||||||
|
|||||||
@@ -104,6 +104,34 @@ stages:
|
|||||||
field: __system.is_wizard
|
field: __system.is_wizard
|
||||||
operator: neq
|
operator: neq
|
||||||
value: true
|
value: true
|
||||||
|
- name: rerank-model
|
||||||
|
label:
|
||||||
|
en_US: Rerank Model
|
||||||
|
zh_Hans: 重排序模型
|
||||||
|
description:
|
||||||
|
en_US: Optional rerank model to improve retrieval quality by re-scoring retrieved chunks
|
||||||
|
zh_Hans: 可选的重排序模型,通过重新评分检索结果来提升检索质量
|
||||||
|
type: rerank-model-selector
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
show_if:
|
||||||
|
field: knowledge-bases
|
||||||
|
operator: neq
|
||||||
|
value: []
|
||||||
|
- name: rerank-top-k
|
||||||
|
label:
|
||||||
|
en_US: Rerank Top K
|
||||||
|
zh_Hans: 重排序保留数量
|
||||||
|
description:
|
||||||
|
en_US: Number of top results to keep after reranking
|
||||||
|
zh_Hans: 重排序后保留的最相关结果数量
|
||||||
|
type: integer
|
||||||
|
required: false
|
||||||
|
default: 5
|
||||||
|
show_if:
|
||||||
|
field: rerank-model
|
||||||
|
operator: neq
|
||||||
|
value: ''
|
||||||
- name: dify-service-api
|
- name: dify-service-api
|
||||||
label:
|
label:
|
||||||
en_US: Dify Service API
|
en_US: Dify Service API
|
||||||
|
|||||||
1
tests/unit_tests/provider/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Provider requester tests"""
|
||||||
633
tests/unit_tests/provider/test_litellmchat.py
Normal file
@@ -0,0 +1,633 @@
|
|||||||
|
"""
|
||||||
|
Tests for LiteLLMRequester - unified requester for chat, embedding, and rerank.
|
||||||
|
|
||||||
|
These tests verify:
|
||||||
|
- Parameter building and LiteLLM API calls
|
||||||
|
- Response processing and usage extraction
|
||||||
|
- Error handling and exception translation
|
||||||
|
- Model name building with provider prefix
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langbot.pkg.provider.modelmgr.requesters import litellmchat
|
||||||
|
from langbot.pkg.provider.modelmgr import errors
|
||||||
|
|
||||||
|
|
||||||
|
class MockRuntimeModel:
|
||||||
|
"""Mock RuntimeLLMModel for testing"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = 'gpt-4o', api_key: str = 'test-key'):
|
||||||
|
self.model_entity = Mock()
|
||||||
|
self.model_entity.name = model_name
|
||||||
|
self.model_entity.extra_args = {}
|
||||||
|
self.provider = Mock()
|
||||||
|
self.provider.token_mgr = Mock()
|
||||||
|
self.provider.token_mgr.get_token = Mock(return_value=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class MockRuntimeEmbeddingModel:
|
||||||
|
"""Mock RuntimeEmbeddingModel for testing"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = 'text-embedding-3-small', api_key: str = 'test-key'):
|
||||||
|
self.model_entity = Mock()
|
||||||
|
self.model_entity.name = model_name
|
||||||
|
self.model_entity.extra_args = {}
|
||||||
|
self.provider = Mock()
|
||||||
|
self.provider.token_mgr = Mock()
|
||||||
|
self.provider.token_mgr.get_token = Mock(return_value=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class MockRuntimeRerankModel:
|
||||||
|
"""Mock RuntimeRerankModel for testing"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = 'cohere/rerank-english-v3.0', api_key: str = 'test-key'):
|
||||||
|
self.model_entity = Mock()
|
||||||
|
self.model_entity.name = model_name
|
||||||
|
self.model_entity.extra_args = {}
|
||||||
|
self.provider = Mock()
|
||||||
|
self.provider.token_mgr = Mock()
|
||||||
|
self.provider.token_mgr.get_token = Mock(return_value=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildLiteLLMModelName:
|
||||||
|
"""Test _build_litellm_model_name method"""
|
||||||
|
|
||||||
|
def test_no_provider_prefix(self):
|
||||||
|
"""Test model name without provider prefix"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': ''})
|
||||||
|
result = requester._build_litellm_model_name('gpt-4o')
|
||||||
|
assert result == 'gpt-4o'
|
||||||
|
|
||||||
|
def test_with_provider_prefix(self):
|
||||||
|
"""Test model name with provider prefix"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||||
|
result = requester._build_litellm_model_name('gpt-4o')
|
||||||
|
assert result == 'openai/gpt-4o'
|
||||||
|
|
||||||
|
def test_override_provider(self):
|
||||||
|
"""Test override provider via parameter"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||||
|
result = requester._build_litellm_model_name('claude-3', custom_llm_provider='anthropic')
|
||||||
|
assert result == 'anthropic/claude-3'
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractUsage:
|
||||||
|
"""Test _extract_usage method"""
|
||||||
|
|
||||||
|
def test_extract_usage_with_data(self):
|
||||||
|
"""Test extraction with valid usage data"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
response = Mock()
|
||||||
|
response.usage = Mock()
|
||||||
|
response.usage.prompt_tokens = 100
|
||||||
|
response.usage.completion_tokens = 50
|
||||||
|
response.usage.total_tokens = 150
|
||||||
|
|
||||||
|
result = requester._extract_usage(response)
|
||||||
|
|
||||||
|
assert result['prompt_tokens'] == 100
|
||||||
|
assert result['completion_tokens'] == 50
|
||||||
|
assert result['total_tokens'] == 150
|
||||||
|
|
||||||
|
def test_extract_usage_with_zero_values(self):
|
||||||
|
"""Test extraction when values are 0"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
response = Mock()
|
||||||
|
response.usage = Mock()
|
||||||
|
response.usage.prompt_tokens = 0
|
||||||
|
response.usage.completion_tokens = 0
|
||||||
|
response.usage.total_tokens = 0
|
||||||
|
|
||||||
|
result = requester._extract_usage(response)
|
||||||
|
|
||||||
|
assert result['prompt_tokens'] == 0
|
||||||
|
assert result['completion_tokens'] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessThinkingContent:
|
||||||
|
"""Test _process_thinking_content method"""
|
||||||
|
|
||||||
|
def test_no_thinking_markers(self):
|
||||||
|
"""Test content without thinking markers"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
result = requester._process_thinking_content('Hello world', None, remove_think=True)
|
||||||
|
assert result == 'Hello world'
|
||||||
|
|
||||||
|
def test_remove_thinking_markers(self):
|
||||||
|
"""Test removing thinking markers when remove_think=True"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'
|
||||||
|
result = requester._process_thinking_content(content, None, remove_think=True)
|
||||||
|
assert result == 'The answer is 42.'
|
||||||
|
|
||||||
|
def test_preserve_thinking_markers(self):
|
||||||
|
"""Test preserving thinking markers when remove_think=False"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'
|
||||||
|
result = requester._process_thinking_content(content, None, remove_think=False)
|
||||||
|
assert 'CRETIRE_REASONING_BEGINk' in result
|
||||||
|
assert 'The answer is 42.' in result
|
||||||
|
|
||||||
|
def test_empty_content(self):
|
||||||
|
"""Test empty content"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
result = requester._process_thinking_content('', None, remove_think=True)
|
||||||
|
assert result == ''
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildCommonArgs:
|
||||||
|
"""Test _build_common_args method"""
|
||||||
|
|
||||||
|
def test_build_args_with_all_params(self):
|
||||||
|
"""Test building args with all config params"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
'drop_params': True,
|
||||||
|
'num_retries': 3,
|
||||||
|
'api_version': '2024-01-01',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
requester._build_common_args(args)
|
||||||
|
|
||||||
|
assert args['api_base'] == 'https://api.openai.com/v1'
|
||||||
|
assert args['timeout'] == 60
|
||||||
|
assert args['drop_params'] == True
|
||||||
|
assert args['num_retries'] == 3
|
||||||
|
assert args['api_version'] == '2024-01-01'
|
||||||
|
|
||||||
|
def test_build_args_without_retry_params(self):
|
||||||
|
"""Test building args without retry params for embedding/rerank"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
'num_retries': 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
requester._build_common_args(args, include_retry_params=False)
|
||||||
|
|
||||||
|
assert args['api_base'] == 'https://api.openai.com/v1'
|
||||||
|
assert args['timeout'] == 60
|
||||||
|
assert 'num_retries' not in args
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleLiteLLMError:
|
||||||
|
"""Test _handle_litellm_error method"""
|
||||||
|
|
||||||
|
def test_bad_request_error(self):
|
||||||
|
"""Test BadRequestError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
# Create proper LiteLLM exception with required args
|
||||||
|
error = litellm.BadRequestError(message='test error', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '请求参数错误' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_authentication_error(self):
|
||||||
|
"""Test AuthenticationError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert 'API key 无效' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_rate_limit_error(self):
|
||||||
|
"""Test RateLimitError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.RateLimitError(message='rate limited', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '请求过于频繁' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_timeout_error(self):
|
||||||
|
"""Test Timeout translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.Timeout(message='timeout', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '请求超时' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_context_window_error(self):
|
||||||
|
"""Test ContextWindowExceededError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.ContextWindowExceededError(message='context too long', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '上下文长度超限' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_unknown_error(self):
|
||||||
|
"""Test unknown error translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(Exception('unknown'))
|
||||||
|
|
||||||
|
assert '未知错误' in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokeLLM:
|
||||||
|
"""Test invoke_llm method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_llm_basic(self):
|
||||||
|
"""Test basic LLM invocation"""
|
||||||
|
mock_ap = Mock()
|
||||||
|
mock_ap.tool_mgr = Mock()
|
||||||
|
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=mock_ap,
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock LiteLLM response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock()]
|
||||||
|
mock_response.choices[0].message = Mock()
|
||||||
|
mock_response.choices[0].message.model_dump = Mock(
|
||||||
|
return_value={
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': 'Hello! How can I help you?',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.usage = Mock()
|
||||||
|
mock_response.usage.prompt_tokens = 10
|
||||||
|
mock_response.usage.completion_tokens = 20
|
||||||
|
mock_response.usage.total_tokens = 30
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='Hello')]
|
||||||
|
|
||||||
|
# Patch acompletion at the import location
|
||||||
|
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
result_msg, usage = await requester.invoke_llm(
|
||||||
|
query=None,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_msg.role == 'assistant'
|
||||||
|
assert result_msg.content == 'Hello! How can I help you?'
|
||||||
|
assert usage['prompt_tokens'] == 10
|
||||||
|
assert usage['completion_tokens'] == 20
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_llm_with_tools(self):
|
||||||
|
"""Test LLM invocation with function calling"""
|
||||||
|
mock_ap = Mock()
|
||||||
|
mock_ap.tool_mgr = Mock()
|
||||||
|
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
|
||||||
|
return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
|
||||||
|
)
|
||||||
|
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
|
||||||
|
|
||||||
|
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock()]
|
||||||
|
mock_response.choices[0].message = Mock()
|
||||||
|
mock_response.choices[0].message.model_dump = Mock(
|
||||||
|
return_value={
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': None,
|
||||||
|
'tool_calls': [
|
||||||
|
{'id': 'call_123', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{}'}}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.usage = Mock()
|
||||||
|
mock_response.usage.prompt_tokens = 15
|
||||||
|
mock_response.usage.completion_tokens = 10
|
||||||
|
mock_response.usage.total_tokens = 25
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='What is the weather?')]
|
||||||
|
# Create proper LLMTool with all required fields
|
||||||
|
funcs = [Mock(spec=resource_tool.LLMTool)]
|
||||||
|
funcs[0].name = 'get_weather'
|
||||||
|
funcs[0].description = 'Get weather'
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
result_msg, usage = await requester.invoke_llm(
|
||||||
|
query=None,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
funcs=funcs,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_msg.tool_calls is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_llm_error_handling(self):
|
||||||
|
"""Test LLM invocation error handling"""
|
||||||
|
mock_ap = Mock()
|
||||||
|
mock_ap.tool_mgr = Mock()
|
||||||
|
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
|
||||||
|
|
||||||
|
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='Hello')]
|
||||||
|
|
||||||
|
error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, side_effect=error):
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
await requester.invoke_llm(
|
||||||
|
query=None,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert 'API key 无效' in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokeEmbedding:
|
||||||
|
"""Test invoke_embedding method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_embedding_basic(self):
|
||||||
|
"""Test basic embedding invocation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MockRuntimeEmbeddingModel('text-embedding-3-small', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock LiteLLM embedding response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [
|
||||||
|
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||||
|
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||||
|
]
|
||||||
|
mock_response.usage = Mock()
|
||||||
|
mock_response.usage.prompt_tokens = 20
|
||||||
|
mock_response.usage.completion_tokens = 0
|
||||||
|
mock_response.usage.total_tokens = 20
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'aembedding', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
embeddings, usage = await requester.invoke_embedding(
|
||||||
|
model=model,
|
||||||
|
input_text=['Hello', 'World'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(embeddings) == 2
|
||||||
|
assert embeddings[0] == [0.1, 0.2, 0.3]
|
||||||
|
assert embeddings[1] == [0.4, 0.5, 0.6]
|
||||||
|
assert usage['prompt_tokens'] == 20
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokeRerank:
|
||||||
|
"""Test invoke_rerank method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_rerank_basic(self):
|
||||||
|
"""Test basic rerank invocation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.cohere.ai',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock LiteLLM rerank response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.results = [
|
||||||
|
{'index': 0, 'relevance_score': 0.95},
|
||||||
|
{'index': 1, 'relevance_score': 0.3},
|
||||||
|
{'index': 2, 'relevance_score': 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
results = await requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query='What is the capital of France?',
|
||||||
|
documents=['Paris is the capital.', 'London is a city.', 'France is in Europe.'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
# Scores should be normalized
|
||||||
|
assert results[0]['index'] == 0
|
||||||
|
assert results[0]['relevance_score'] >= 0 and results[0]['relevance_score'] <= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_rerank_normalization(self):
|
||||||
|
"""Test rerank score normalization"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock response with varying scores
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.results = [
|
||||||
|
{'index': 0, 'relevance_score': 0.9},
|
||||||
|
{'index': 1, 'relevance_score': 0.1},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
results = await requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query='test query',
|
||||||
|
documents=['doc1', 'doc2'],
|
||||||
|
)
|
||||||
|
|
||||||
|
# After normalization: 0.9 -> 1.0, 0.1 -> 0.0
|
||||||
|
assert results[0]['relevance_score'] == 1.0
|
||||||
|
assert results[1]['relevance_score'] == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_rerank_single_document(self):
|
||||||
|
"""Test rerank with single document (no normalization needed)"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.results = [
|
||||||
|
{'index': 0, 'relevance_score': 0.5},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
results = await requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query='test query',
|
||||||
|
documents=['doc1'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
# Single score stays as is (min==max, no normalization)
|
||||||
|
assert results[0]['relevance_score'] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertMessages:
|
||||||
|
"""Test _convert_messages method"""
|
||||||
|
|
||||||
|
def test_convert_simple_message(self):
|
||||||
|
"""Test converting simple text message"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='Hello')]
|
||||||
|
result = requester._convert_messages(messages)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]['role'] == 'user'
|
||||||
|
assert result[0]['content'] == 'Hello'
|
||||||
|
|
||||||
|
def test_convert_message_with_image_base64(self):
|
||||||
|
"""Test converting message with image_base64 content"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
provider_message.Message(
|
||||||
|
role='user',
|
||||||
|
content=[
|
||||||
|
{'type': 'text', 'text': 'What is in this image?'},
|
||||||
|
{'type': 'image_base64', 'image_base64': 'data:image/png;base64,abc123'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = requester._convert_messages(messages)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
content = result[0]['content']
|
||||||
|
assert isinstance(content, list)
|
||||||
|
# Check image_base64 converted to image_url
|
||||||
|
image_part = [p for p in content if p.get('type') == 'image_url'][0]
|
||||||
|
assert 'image_url' in image_part
|
||||||
|
assert image_part['image_url']['url'] == 'data:image/png;base64,abc123'
|
||||||
|
|
||||||
|
def test_convert_message_with_multiple_text_parts(self):
|
||||||
|
"""Test converting message with multiple text parts (LiteLLM handles this)"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
provider_message.Message(
|
||||||
|
role='user',
|
||||||
|
content=[
|
||||||
|
{'type': 'text', 'text': 'Hello'},
|
||||||
|
{'type': 'text', 'text': 'World'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = requester._convert_messages(messages)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
# LiteLLM handles multiple text parts, we pass them through
|
||||||
|
assert isinstance(result[0]['content'], list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanModels:
|
||||||
|
"""Test scan_models method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_models_basic(self):
|
||||||
|
"""Test basic model scanning"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock httpx response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json = Mock(
|
||||||
|
return_value={
|
||||||
|
'data': [
|
||||||
|
{'id': 'gpt-4o'},
|
||||||
|
{'id': 'text-embedding-3-small'},
|
||||||
|
{'id': 'gpt-3.5-turbo'},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client:
|
||||||
|
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
|
||||||
|
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
result = await requester.scan_models(api_key='test-key')
|
||||||
|
|
||||||
|
assert 'models' in result
|
||||||
|
assert len(result['models']) == 3
|
||||||
|
# Check LLM models are first
|
||||||
|
assert result['models'][0]['type'] == 'llm'
|
||||||
|
# Check embedding model is detected
|
||||||
|
embedding_models = [m for m in result['models'] if m['type'] == 'embedding']
|
||||||
|
assert len(embedding_models) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_models_no_base_url(self):
|
||||||
|
"""Test scan_models without base_url raises error"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': '',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
await requester.scan_models()
|
||||||
|
|
||||||
|
assert 'Base URL required' in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
@@ -240,6 +240,9 @@ export default function DynamicFormComponent({
|
|||||||
case 'embedding-model-selector':
|
case 'embedding-model-selector':
|
||||||
fieldSchema = z.string();
|
fieldSchema = z.string();
|
||||||
break;
|
break;
|
||||||
|
case 'rerank-model-selector':
|
||||||
|
fieldSchema = z.string();
|
||||||
|
break;
|
||||||
case 'knowledge-base-selector':
|
case 'knowledge-base-selector':
|
||||||
fieldSchema = z.string();
|
fieldSchema = z.string();
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import {
|
|||||||
Bot,
|
Bot,
|
||||||
KnowledgeBase,
|
KnowledgeBase,
|
||||||
EmbeddingModel,
|
EmbeddingModel,
|
||||||
|
RerankModel,
|
||||||
PluginTool,
|
PluginTool,
|
||||||
} from '@/app/infra/entities/api';
|
} from '@/app/infra/entities/api';
|
||||||
import { toast } from 'sonner';
|
import { toast } from 'sonner';
|
||||||
@@ -74,6 +75,7 @@ export default function DynamicFormItemComponent({
|
|||||||
}) {
|
}) {
|
||||||
const [llmModels, setLlmModels] = useState<LLMModel[]>([]);
|
const [llmModels, setLlmModels] = useState<LLMModel[]>([]);
|
||||||
const [embeddingModels, setEmbeddingModels] = useState<EmbeddingModel[]>([]);
|
const [embeddingModels, setEmbeddingModels] = useState<EmbeddingModel[]>([]);
|
||||||
|
const [rerankModels, setRerankModels] = useState<RerankModel[]>([]);
|
||||||
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBase[]>([]);
|
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBase[]>([]);
|
||||||
const [bots, setBots] = useState<Bot[]>([]);
|
const [bots, setBots] = useState<Bot[]>([]);
|
||||||
const [tools, setTools] = useState<PluginTool[]>([]);
|
const [tools, setTools] = useState<PluginTool[]>([]);
|
||||||
@@ -180,6 +182,19 @@ export default function DynamicFormItemComponent({
|
|||||||
}
|
}
|
||||||
}, [config.type]);
|
}, [config.type]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (config.type === DynamicFormItemType.RERANK_MODEL_SELECTOR) {
|
||||||
|
httpClient
|
||||||
|
.getProviderRerankModels()
|
||||||
|
.then((resp) => {
|
||||||
|
setRerankModels(resp.models);
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
toast.error('Failed to load rerank models: ' + err.msg);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [config.type]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (config.type === DynamicFormItemType.MODEL_FALLBACK_SELECTOR) {
|
if (config.type === DynamicFormItemType.MODEL_FALLBACK_SELECTOR) {
|
||||||
fetchLlmModels();
|
fetchLlmModels();
|
||||||
@@ -585,6 +600,45 @@ export default function DynamicFormItemComponent({
|
|||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
case DynamicFormItemType.RERANK_MODEL_SELECTOR:
|
||||||
|
const groupedRerankModels = rerankModels.reduce(
|
||||||
|
(acc, model) => {
|
||||||
|
const providerName = model.provider?.name || 'Unknown';
|
||||||
|
if (!acc[providerName]) acc[providerName] = [];
|
||||||
|
acc[providerName].push(model);
|
||||||
|
return acc;
|
||||||
|
},
|
||||||
|
{} as Record<string, RerankModel[]>,
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-md">
|
||||||
|
<Select
|
||||||
|
value={field.value || '__none__'}
|
||||||
|
onValueChange={(v) => field.onChange(v === '__none__' ? '' : v)}
|
||||||
|
>
|
||||||
|
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
|
||||||
|
<SelectValue placeholder={t('models.rerank')} />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
<SelectItem value="__none__">{t('common.none')}</SelectItem>
|
||||||
|
{Object.entries(groupedRerankModels).map(
|
||||||
|
([providerName, models]) => (
|
||||||
|
<SelectGroup key={providerName}>
|
||||||
|
<SelectLabel>{providerName}</SelectLabel>
|
||||||
|
{models.map((model) => (
|
||||||
|
<SelectItem key={model.uuid} value={model.uuid}>
|
||||||
|
{model.name}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
),
|
||||||
|
)}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
case DynamicFormItemType.MODEL_FALLBACK_SELECTOR: {
|
case DynamicFormItemType.MODEL_FALLBACK_SELECTOR: {
|
||||||
// Separate space models from regular models
|
// Separate space models from regular models
|
||||||
const fbSpaceModels = llmModels.filter(
|
const fbSpaceModels = llmModels.filter(
|
||||||
|
|||||||
@@ -147,15 +147,17 @@ export default function ModelsDialog({
|
|||||||
setLoadingProviders((prev) => new Set(prev).add(providerUuid));
|
setLoadingProviders((prev) => new Set(prev).add(providerUuid));
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
const [llmResp, embeddingResp] = await Promise.all([
|
const [llmResp, embeddingResp, rerankResp] = await Promise.all([
|
||||||
httpClient.getProviderLLMModels(providerUuid),
|
httpClient.getProviderLLMModels(providerUuid),
|
||||||
httpClient.getProviderEmbeddingModels(providerUuid),
|
httpClient.getProviderEmbeddingModels(providerUuid),
|
||||||
|
httpClient.getProviderRerankModels(providerUuid),
|
||||||
]);
|
]);
|
||||||
setProviderModels((prev) => ({
|
setProviderModels((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
[providerUuid]: {
|
[providerUuid]: {
|
||||||
llm: llmResp.models,
|
llm: llmResp.models,
|
||||||
embedding: embeddingResp.models,
|
embedding: embeddingResp.models,
|
||||||
|
rerank: rerankResp.models,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -247,12 +249,18 @@ export default function ModelsDialog({
|
|||||||
abilities,
|
abilities,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.createProviderEmbeddingModel({
|
await httpClient.createProviderEmbeddingModel({
|
||||||
name,
|
name,
|
||||||
provider_uuid: providerUuid,
|
provider_uuid: providerUuid,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
|
} else {
|
||||||
|
await httpClient.createProviderRerankModel({
|
||||||
|
name,
|
||||||
|
provider_uuid: providerUuid,
|
||||||
|
extra_args: extraArgsObj,
|
||||||
|
} as never);
|
||||||
}
|
}
|
||||||
setAddModelPopoverOpen(null);
|
setAddModelPopoverOpen(null);
|
||||||
loadProviderModels(providerUuid, true);
|
loadProviderModels(providerUuid, true);
|
||||||
@@ -341,12 +349,18 @@ export default function ModelsDialog({
|
|||||||
abilities,
|
abilities,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.updateProviderEmbeddingModel(modelId, {
|
await httpClient.updateProviderEmbeddingModel(modelId, {
|
||||||
name,
|
name,
|
||||||
provider_uuid: providerUuid,
|
provider_uuid: providerUuid,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
|
} else {
|
||||||
|
await httpClient.updateProviderRerankModel(modelId, {
|
||||||
|
name,
|
||||||
|
provider_uuid: providerUuid,
|
||||||
|
extra_args: extraArgsObj,
|
||||||
|
} as never);
|
||||||
}
|
}
|
||||||
setEditModelPopoverOpen(null);
|
setEditModelPopoverOpen(null);
|
||||||
loadProviderModels(providerUuid, true);
|
loadProviderModels(providerUuid, true);
|
||||||
@@ -366,8 +380,10 @@ export default function ModelsDialog({
|
|||||||
try {
|
try {
|
||||||
if (modelType === 'llm') {
|
if (modelType === 'llm') {
|
||||||
await httpClient.deleteProviderLLMModel(modelId);
|
await httpClient.deleteProviderLLMModel(modelId);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.deleteProviderEmbeddingModel(modelId);
|
await httpClient.deleteProviderEmbeddingModel(modelId);
|
||||||
|
} else {
|
||||||
|
await httpClient.deleteProviderRerankModel(modelId);
|
||||||
}
|
}
|
||||||
toast.success(t('models.deleteSuccess'));
|
toast.success(t('models.deleteSuccess'));
|
||||||
loadProviderModels(providerUuid, true);
|
loadProviderModels(providerUuid, true);
|
||||||
@@ -407,7 +423,7 @@ export default function ModelsDialog({
|
|||||||
abilities,
|
abilities,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.testEmbeddingModel('_', {
|
await httpClient.testEmbeddingModel('_', {
|
||||||
uuid: '',
|
uuid: '',
|
||||||
name,
|
name,
|
||||||
@@ -415,6 +431,14 @@ export default function ModelsDialog({
|
|||||||
provider: providerData,
|
provider: providerData,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
|
} else {
|
||||||
|
await httpClient.testRerankModel('_', {
|
||||||
|
uuid: '',
|
||||||
|
name,
|
||||||
|
provider_uuid: '',
|
||||||
|
provider: providerData,
|
||||||
|
extra_args: extraArgsObj,
|
||||||
|
} as never);
|
||||||
}
|
}
|
||||||
const duration = Date.now() - startTime;
|
const duration = Date.now() - startTime;
|
||||||
setTestResult({ success: true, duration });
|
setTestResult({ success: true, duration });
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState, useRef, useCallback } from 'react';
|
||||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||||
|
|
||||||
import { zodResolver } from '@hookform/resolvers/zod';
|
import { zodResolver } from '@hookform/resolvers/zod';
|
||||||
@@ -16,19 +16,12 @@ import {
|
|||||||
FormMessage,
|
FormMessage,
|
||||||
} from '@/components/ui/form';
|
} from '@/components/ui/form';
|
||||||
import { Input } from '@/components/ui/input';
|
import { Input } from '@/components/ui/input';
|
||||||
import {
|
|
||||||
Select,
|
|
||||||
SelectContent,
|
|
||||||
SelectGroup,
|
|
||||||
SelectItem,
|
|
||||||
SelectLabel,
|
|
||||||
SelectTrigger,
|
|
||||||
SelectValue,
|
|
||||||
} from '@/components/ui/select';
|
|
||||||
import { DialogFooter } from '@/components/ui/dialog';
|
import { DialogFooter } from '@/components/ui/dialog';
|
||||||
import { toast } from 'sonner';
|
import { toast } from 'sonner';
|
||||||
import { extractI18nObject } from '@/i18n/I18nProvider';
|
import { extractI18nObject } from '@/i18n/I18nProvider';
|
||||||
import { CustomApiError } from '@/app/infra/entities/common';
|
import { CustomApiError } from '@/app/infra/entities/common';
|
||||||
|
import { cn } from '@/lib/utils';
|
||||||
|
import { Check, ChevronDown, Search } from 'lucide-react';
|
||||||
|
|
||||||
const getFormSchema = (t: (key: string) => string) =>
|
const getFormSchema = (t: (key: string) => string) =>
|
||||||
z.object({
|
z.object({
|
||||||
@@ -71,6 +64,10 @@ export default function ProviderForm({
|
|||||||
description: string;
|
description: string;
|
||||||
}[]
|
}[]
|
||||||
>([]);
|
>([]);
|
||||||
|
const [searchQuery, setSearchQuery] = useState('');
|
||||||
|
const [isOpen, setIsOpen] = useState(false);
|
||||||
|
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||||
|
const searchInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
loadRequesters();
|
loadRequesters();
|
||||||
@@ -79,6 +76,54 @@ export default function ProviderForm({
|
|||||||
}
|
}
|
||||||
}, [providerId]);
|
}, [providerId]);
|
||||||
|
|
||||||
|
// Close dropdown when clicking outside
|
||||||
|
useEffect(() => {
|
||||||
|
function handleClickOutside(event: MouseEvent) {
|
||||||
|
if (
|
||||||
|
dropdownRef.current &&
|
||||||
|
!dropdownRef.current.contains(event.target as Node)
|
||||||
|
) {
|
||||||
|
setIsOpen(false);
|
||||||
|
setSearchQuery('');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
document.addEventListener('mousedown', handleClickOutside);
|
||||||
|
return () => document.removeEventListener('mousedown', handleClickOutside);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Focus search input when dropdown opens
|
||||||
|
useEffect(() => {
|
||||||
|
if (isOpen && searchInputRef.current) {
|
||||||
|
searchInputRef.current.focus();
|
||||||
|
}
|
||||||
|
}, [isOpen]);
|
||||||
|
|
||||||
|
// Filter requesters based on search query
|
||||||
|
const filteredRequesters = requesterList.filter(
|
||||||
|
(r) =>
|
||||||
|
r.label.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
||||||
|
r.value.toLowerCase().includes(searchQuery.toLowerCase()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Group filtered requesters by category
|
||||||
|
const groupedRequesters = {
|
||||||
|
builtin: filteredRequesters.filter((r) => r.category === 'builtin'),
|
||||||
|
manufacturer: filteredRequesters.filter(
|
||||||
|
(r) => r.category === 'manufacturer',
|
||||||
|
),
|
||||||
|
maas: filteredRequesters.filter((r) => r.category === 'maas'),
|
||||||
|
'self-hosted': filteredRequesters.filter(
|
||||||
|
(r) => r.category === 'self-hosted',
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
const categoryLabels: Record<string, string> = {
|
||||||
|
builtin: t('models.builtin'),
|
||||||
|
manufacturer: t('models.modelManufacturer'),
|
||||||
|
maas: t('models.aggregationPlatform'),
|
||||||
|
'self-hosted': t('models.selfDeployed'),
|
||||||
|
};
|
||||||
|
|
||||||
async function loadRequesters() {
|
async function loadRequesters() {
|
||||||
const resp = await httpClient.getProviderRequesters();
|
const resp = await httpClient.getProviderRequesters();
|
||||||
setRequesterList(
|
setRequesterList(
|
||||||
@@ -165,17 +210,16 @@ export default function ProviderForm({
|
|||||||
{t('models.requester')}
|
{t('models.requester')}
|
||||||
<span className="text-red-500">*</span>
|
<span className="text-red-500">*</span>
|
||||||
</FormLabel>
|
</FormLabel>
|
||||||
<Select
|
<div ref={dropdownRef} className="relative">
|
||||||
onValueChange={(v) => {
|
{/* Trigger button */}
|
||||||
field.onChange(v);
|
<button
|
||||||
const req = requesterList.find((r) => r.value === v);
|
type="button"
|
||||||
if (req && (!providerId || !form.getValues('base_url'))) {
|
onClick={() => setIsOpen(!isOpen)}
|
||||||
form.setValue('base_url', req.defaultUrl);
|
className={cn(
|
||||||
}
|
'flex h-10 w-full items-center justify-between rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50',
|
||||||
}}
|
isOpen && 'ring-2 ring-ring ring-offset-2',
|
||||||
value={field.value}
|
)}
|
||||||
>
|
>
|
||||||
<SelectTrigger className="bg-background">
|
|
||||||
{selectedRequester ? (
|
{selectedRequester ? (
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<img
|
<img
|
||||||
@@ -188,90 +232,102 @@ export default function ProviderForm({
|
|||||||
<span>{selectedRequester.label}</span>
|
<span>{selectedRequester.label}</span>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<SelectValue placeholder={t('models.selectRequester')} />
|
<span className="text-muted-foreground">
|
||||||
|
{t('models.selectRequester')}
|
||||||
|
</span>
|
||||||
)}
|
)}
|
||||||
</SelectTrigger>
|
<ChevronDown
|
||||||
<SelectContent>
|
className={cn(
|
||||||
<SelectGroup>
|
'h-4 w-4 opacity-50 transition-transform',
|
||||||
<SelectLabel>{t('models.builtin')}</SelectLabel>
|
isOpen && 'rotate-180',
|
||||||
{requesterList
|
)}
|
||||||
.filter((r) => r.category === 'builtin')
|
/>
|
||||||
.map((r) => (
|
</button>
|
||||||
<SelectItem key={r.value} value={r.value}>
|
|
||||||
<div className="flex items-center gap-2">
|
{/* Dropdown */}
|
||||||
<img
|
{isOpen && (
|
||||||
src={httpClient.getProviderRequesterIconURL(
|
<div className="absolute z-50 mt-1 w-full rounded-md border bg-popover text-popover-foreground shadow-md animate-in fade-in-0 zoom-in-95">
|
||||||
r.value,
|
{/* Search input */}
|
||||||
)}
|
<div className="flex items-center border-b px-3">
|
||||||
alt={r.label}
|
<Search className="mr-2 h-4 w-4 shrink-0 opacity-50" />
|
||||||
className="h-5 w-5 rounded"
|
<input
|
||||||
/>
|
ref={searchInputRef}
|
||||||
<span>{r.label}</span>
|
type="text"
|
||||||
</div>
|
placeholder={
|
||||||
</SelectItem>
|
t('models.searchProviders') || 'Search providers...'
|
||||||
))}
|
}
|
||||||
</SelectGroup>
|
value={searchQuery}
|
||||||
<SelectGroup>
|
onChange={(e) => setSearchQuery(e.target.value)}
|
||||||
<SelectLabel>{t('models.modelManufacturer')}</SelectLabel>
|
className="flex h-10 w-full rounded-md bg-transparent py-3 text-sm outline-none placeholder:text-muted-foreground"
|
||||||
{requesterList
|
/>
|
||||||
.filter((r) => r.category === 'manufacturer')
|
</div>
|
||||||
.map((r) => (
|
|
||||||
<SelectItem key={r.value} value={r.value}>
|
{/* Options list */}
|
||||||
<div className="flex items-center gap-2">
|
<div className="max-h-[300px] overflow-y-auto p-1">
|
||||||
<img
|
{Object.entries(groupedRequesters).map(
|
||||||
src={httpClient.getProviderRequesterIconURL(
|
([category, items]) => {
|
||||||
r.value,
|
if (items.length === 0) return null;
|
||||||
)}
|
return (
|
||||||
alt={r.label}
|
<div key={category}>
|
||||||
className="h-5 w-5 rounded"
|
<div className="py-1.5 px-2 text-xs font-semibold text-muted-foreground">
|
||||||
/>
|
{categoryLabels[category]}
|
||||||
<span>{r.label}</span>
|
</div>
|
||||||
</div>
|
{items.map((r) => (
|
||||||
</SelectItem>
|
<button
|
||||||
))}
|
key={r.value}
|
||||||
</SelectGroup>
|
type="button"
|
||||||
<SelectGroup>
|
onClick={() => {
|
||||||
<SelectLabel>
|
field.onChange(r.value);
|
||||||
{t('models.aggregationPlatform')}
|
const req = requesterList.find(
|
||||||
</SelectLabel>
|
(req) => req.value === r.value,
|
||||||
{requesterList
|
);
|
||||||
.filter((r) => r.category === 'maas')
|
if (
|
||||||
.map((r) => (
|
req &&
|
||||||
<SelectItem key={r.value} value={r.value}>
|
(!providerId ||
|
||||||
<div className="flex items-center gap-2">
|
!form.getValues('base_url'))
|
||||||
<img
|
) {
|
||||||
src={httpClient.getProviderRequesterIconURL(
|
form.setValue(
|
||||||
r.value,
|
'base_url',
|
||||||
)}
|
req.defaultUrl,
|
||||||
alt={r.label}
|
);
|
||||||
className="h-5 w-5 rounded"
|
}
|
||||||
/>
|
setIsOpen(false);
|
||||||
<span>{r.label}</span>
|
setSearchQuery('');
|
||||||
</div>
|
}}
|
||||||
</SelectItem>
|
className={cn(
|
||||||
))}
|
'flex w-full items-center gap-2 rounded-sm px-2 py-1.5 text-sm outline-none hover:bg-accent hover:text-accent-foreground cursor-pointer',
|
||||||
</SelectGroup>
|
field.value === r.value &&
|
||||||
<SelectGroup>
|
'bg-accent text-accent-foreground',
|
||||||
<SelectLabel>{t('models.selfDeployed')}</SelectLabel>
|
)}
|
||||||
{requesterList
|
>
|
||||||
.filter((r) => r.category === 'self-hosted')
|
<img
|
||||||
.map((r) => (
|
src={httpClient.getProviderRequesterIconURL(
|
||||||
<SelectItem key={r.value} value={r.value}>
|
r.value,
|
||||||
<div className="flex items-center gap-2">
|
)}
|
||||||
<img
|
alt={r.label}
|
||||||
src={httpClient.getProviderRequesterIconURL(
|
className="h-5 w-5 rounded"
|
||||||
r.value,
|
/>
|
||||||
)}
|
<span className="flex-1 text-left">
|
||||||
alt={r.label}
|
{r.label}
|
||||||
className="h-5 w-5 rounded"
|
</span>
|
||||||
/>
|
{field.value === r.value && (
|
||||||
<span>{r.label}</span>
|
<Check className="h-4 w-4" />
|
||||||
</div>
|
)}
|
||||||
</SelectItem>
|
</button>
|
||||||
))}
|
))}
|
||||||
</SelectGroup>
|
</div>
|
||||||
</SelectContent>
|
);
|
||||||
</Select>
|
},
|
||||||
|
)}
|
||||||
|
{filteredRequesters.length === 0 && (
|
||||||
|
<div className="py-6 text-center text-sm text-muted-foreground">
|
||||||
|
No results found.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
<FormMessage />
|
<FormMessage />
|
||||||
{selectedRequester?.description && (
|
{selectedRequester?.description && (
|
||||||
<p className="text-sm text-muted-foreground">
|
<p className="text-sm text-muted-foreground">
|
||||||
|
|||||||