feat: 对 claude api 的基本支持

This commit is contained in:
RockChinQ
2024-03-17 12:44:45 -04:00
parent 550a131685
commit 1dae7bd655
5 changed files with 92 additions and 4 deletions

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import typing
import anthropic
from .. import api, entities, errors
from .. import api, entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("anthropic-messages")
class AnthropicMessages(api.LLMAPIRequester):
"""Anthropic Messages API 请求器"""
client: anthropic.AsyncAnthropic
async def initialize(self):
self.client = anthropic.AsyncAnthropic(
api_key="",
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
timeout=self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout'],
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
self.client.api_key = query.use_model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages
] + [m.dict(exclude_none=True) for m in query.messages]
# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
# 检查是否有 role=system 的消息,若有,改为 role=user并在后面加一个 role=assistant 的消息
system_role_index = []
for i, m in enumerate(req_messages):
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"
if system_role_index:
for i in system_role_index[::-1]:
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
args["messages"] = req_messages
resp = await self.client.messages.create(**args)
yield llm_entities.Message(
content=resp.content[0].text,
role=resp.role
)

View File

@@ -9,8 +9,6 @@ import openai
import openai.types.chat.chat_completion as chat_completion
import httpx
from pkg.provider.entities import Message
from .. import api, entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
@@ -127,7 +125,7 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, None]:
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try:
async for msg in self._request(query):
yield msg

View File

@@ -6,7 +6,7 @@ from . import entities
from ...core import app
from . import token, api
from .apis import chatcmpl
from .apis import chatcmpl, anthropicmsgs
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"