mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
114 lines
3.3 KiB
Python
114 lines
3.3 KiB
Python
import re
|
|
import inspect
|
|
|
|
|
|
def get_func_schema(function: callable) -> dict:
|
|
"""
|
|
Return the data schema of a function.
|
|
{
|
|
"function": function,
|
|
"description": "function description",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"parameter_a": {
|
|
"type": "str",
|
|
"description": "parameter_a description"
|
|
},
|
|
"parameter_b": {
|
|
"type": "int",
|
|
"description": "parameter_b description"
|
|
},
|
|
"parameter_c": {
|
|
"type": "str",
|
|
"description": "parameter_c description",
|
|
"enum": ["a", "b", "c"]
|
|
},
|
|
},
|
|
"required": ["parameter_a", "parameter_b"]
|
|
}
|
|
}
|
|
"""
|
|
func_doc = function.__doc__
|
|
# Google Style Docstring
|
|
if func_doc is None:
|
|
raise Exception('Function {} has no docstring.'.format(function.__name__))
|
|
func_doc = func_doc.strip().replace(' ', '').replace('\t', '')
|
|
# extract doc of args from docstring
|
|
doc_spt = func_doc.split('\n\n')
|
|
desc = doc_spt[0]
|
|
args = doc_spt[1] if len(doc_spt) > 1 else ''
|
|
# returns = doc_spt[2] if len(doc_spt) > 2 else ""
|
|
|
|
# extract args
|
|
# delete the first line of args
|
|
arg_lines = args.split('\n')[1:]
|
|
# arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
|
|
args_doc = {}
|
|
for arg_line in arg_lines:
|
|
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
|
|
if len(doc_tuple) == 0:
|
|
continue
|
|
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
|
|
|
|
# extract returns
|
|
# return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
|
|
|
|
params = enumerate(inspect.signature(function).parameters.values())
|
|
parameters = {
|
|
'type': 'object',
|
|
'required': [],
|
|
'properties': {},
|
|
}
|
|
|
|
for i, param in params:
|
|
# 排除 self, query
|
|
if param.name in ['self', 'query']:
|
|
continue
|
|
|
|
param_type = param.annotation.__name__
|
|
|
|
type_name_mapping = {
|
|
'str': 'string',
|
|
'int': 'integer',
|
|
'float': 'number',
|
|
'bool': 'boolean',
|
|
'list': 'array',
|
|
'dict': 'object',
|
|
}
|
|
|
|
if param_type in type_name_mapping:
|
|
param_type = type_name_mapping[param_type]
|
|
|
|
parameters['properties'][param.name] = {
|
|
'type': param_type,
|
|
'description': args_doc[param.name],
|
|
}
|
|
|
|
# add schema for array
|
|
if param_type == 'array':
|
|
# extract type of array, the int of list[int]
|
|
# use re
|
|
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
|
|
|
|
array_type = 'string'
|
|
|
|
if len(array_type_tuple) > 0:
|
|
array_type = array_type_tuple[0]
|
|
|
|
if array_type in type_name_mapping:
|
|
array_type = type_name_mapping[array_type]
|
|
|
|
parameters['properties'][param.name]['items'] = {
|
|
'type': array_type,
|
|
}
|
|
|
|
if param.default is inspect.Parameter.empty:
|
|
parameters['required'].append(param.name)
|
|
|
|
return {
|
|
'function': function,
|
|
'description': desc,
|
|
'parameters': parameters,
|
|
}
|