mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 08:16:03 +00:00
feat: 基本架构
This commit is contained in:
0
pkg/openai/__init__.py
Normal file
0
pkg/openai/__init__.py
Normal file
30
pkg/openai/manager.py
Normal file
30
pkg/openai/manager.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import openai
|
||||
|
||||
inst = None
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
api_key = ''
|
||||
api_params = {}
|
||||
|
||||
def __init__(self, api_key: str, api_params: dict):
|
||||
self.api_key = api_key
|
||||
self.api_params = api_params
|
||||
|
||||
openai.api_key = self.api_key
|
||||
|
||||
global inst
|
||||
inst = self
|
||||
|
||||
def request_completion(self, prompt, stop):
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
**self.api_params
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def get_inst() -> OpenAIInteract:
|
||||
global inst
|
||||
return inst
|
||||
52
pkg/openai/session.py
Normal file
52
pkg/openai/session.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import time
|
||||
|
||||
import pkg.openai.manager
|
||||
|
||||
|
||||
session = {}
|
||||
|
||||
|
||||
# 通用的OpenAI API交互session
|
||||
class Session:
|
||||
name = ''
|
||||
|
||||
prompt = ''
|
||||
|
||||
user_name = 'You'
|
||||
bot_name = 'Bot'
|
||||
|
||||
create_timestamp = 0
|
||||
|
||||
last_interact_timestamp = 0
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.create_timestamp = int(time.time())
|
||||
|
||||
global session
|
||||
session[name] = self
|
||||
|
||||
# 请求回复
|
||||
# 这个函数是阻塞的
|
||||
def append(self, text: str) -> str:
|
||||
self.prompt += self.user_name + ':' + text + '\n'+self.bot_name+':'
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 向API请求补全
|
||||
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name+':')
|
||||
|
||||
# 处理回复
|
||||
res_test = response["choices"][0]["text"]
|
||||
res_ans = res_test
|
||||
|
||||
# 去除开头可能的提示
|
||||
res_ans_spt = res_test.split("\n\n")
|
||||
if len(res_ans_spt) > 1:
|
||||
del (res_ans_spt[0])
|
||||
res_ans = '\n\n'.join(res_ans_spt)
|
||||
|
||||
self.prompt += "\n" + self.bot_name + ":{}".format(res_ans)
|
||||
return res_ans
|
||||
|
||||
def persistence(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user