encode_question

def encode_question(question: str, dataset_name: str):

Encode multiple prompt instructions into a single string.

APIBenchBenchmark

class APIBenchBenchmark(BaseBenchmark):

APIBench Benchmark adapted from "Gorilla: Large Language Model Connected with Massive APIs" <https://huggingface.co/datasets/gorilla-llm/APIBench>.

Parameters:

  • data_dir (str): The directory to save the data.
  • save_to (str): The file to save the results.
  • processes (int, optional): The number of processes to use. (default: :obj:`1`)

init

def __init__(
    self,
    data_dir: str,
    save_to: str,
    processes: int = 1
):

Initialize the APIBench benchmark.

Parameters:

  • data_dir (str): The directory to save the data.
  • save_to (str): The file to save the results.
  • processes (int, optional): The number of processes to use for parallel processing. (default: :obj:`1`)

download

def download(self):

Download the APIBench dataset.

load

def load(self, dataset_name: str, force_download: bool = False):

Load the APIBench benchmark dataset.

Parameters:

  • dataset_name (str): Name of the specific dataset to be loaded.
  • force_download (bool, optional): Whether to force download the data. (default: :obj:`False`)

run

def run(
    self,
    agent: ChatAgent,
    dataset_name: Literal['huggingface', 'tensorflowhub', 'torchhub'],
    randomize: bool = False,
    subset: Optional[int] = None
):

Run the benchmark.

Parameters:

  • agent (ChatAgent): The agent to run the benchmark.
  • dataset_name (Literal["huggingface", "tensorflowhub", "torchhub"]): The dataset on which to run the benchmark.
  • randomize (bool, optional): Whether to randomize the data. (default: :obj:`False`)
  • subset (Optional[int], optional): The subset of data to run. (default: :obj:`None`)

get_all_sub_trees

def get_all_sub_trees(root_node):

ast_parse

def ast_parse(candidate):

get_args

def get_args(node, dataset_name):

ast_check

def ast_check(candidate_subtree_list, base_tree_list, dataset_name):

evaluate_response

def evaluate_response(
    response,
    question_id,
    dataset_name,
    api_database,
    qa_pairs,
    ast_database
):