Source code for camel.benchmarks.apibench

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, Literal, Optional

import tree_sitter_python as tspython
from tqdm import tqdm
from tree_sitter import Language, Parser

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.utils import download_github_subdirectory

logger = logging.getLogger(__name__)


# Mapping of dataset names to file names.
# The 'Oracle' retriever is used here, meaning the full
# API documentation is included in the prompt.
dataset_mapping = {
    "huggingface": {
        "api": "huggingface_api.jsonl",
        "eval": "huggingface_eval.json",
        "train": "huggingface_train.json",
        "questions": "questions_huggingface_oracle.jsonl",
    },
    "tensorflowhub": {
        "api": "tensorflowhub_api.jsonl",
        "eval": "tensorflow_eval.json",
        "train": "tensorflow_train.json",
        "questions": "questions_tensorflowhub_oracle.jsonl",
    },
    "torchhub": {
        "api": "torchhub_api.jsonl",
        "eval": "torchhub_eval.json",
        "train": "torchhub_train.json",
        "questions": "questions_torchhub_oracle.jsonl",
    },
}
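
# Illustrative sketch (not part of the original module): how a dataset name
# resolves to its backing files via `dataset_mapping`. Note that `load` below
# reads the 'questions' file from `data_dir/<dataset_name>/` and the other
# files from `data_dir` directly.
def _example_dataset_files(dataset_name: str = "torchhub"):
    for label, file_name in dataset_mapping[dataset_name].items():
        print(f"{label}: {file_name}")
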


# This function is migrated from the original repo:
# https://github.com/ShishirPatil/gorilla
def encode_question(question: str, dataset_name: str) -> str:
    r"""Encode multiple prompt instructions into a single string."""
    if dataset_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and \
should include one of {Classification, Semantic Segmentation, \
Object Detection, Audio Separation, Video Classification, \
Text-to-Speech}."
    elif dataset_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature \
Extraction, Multimodal Text-to-Image, Multimodal \
Image-to-Text, Multimodal Text-to-Video, \
Multimodal Visual Question Answering, Multimodal Document \
Question Answer, Multimodal Graph Machine Learning, \
Computer Vision Depth Estimation, Computer Vision Image \
Classification, Computer Vision Object Detection, \
Computer Vision Image Segmentation, Computer Vision \
Image-to-Image, Computer Vision Unconditional \
Image Generation, Computer Vision Video Classification, \
Computer Vision Zero-Shot Image Classification, \
Natural Language Processing Text Classification, \
Natural Language Processing Token Classification, \
Natural Language Processing Table Question Answering, \
Natural Language Processing Question Answering, \
Natural Language Processing Zero-Shot Classification, \
Natural Language Processing Translation, Natural Language \
Processing Summarization, Natural Language Processing \
Conversational, Natural Language Processing Text \
Generation, Natural Language Processing Fill-Mask, \
Natural Language Processing Text2Text Generation, \
Natural Language Processing Sentence Similarity, \
Audio Text-to-Speech, Audio Automatic Speech Recognition, \
Audio Audio-to-Audio, Audio Audio Classification, \
Audio Voice Activity Detection, Tabular Tabular \
Classification, Tabular Tabular Regression, \
Reinforcement Learning Reinforcement Learning, \
Reinforcement Learning Robotics}"
    elif dataset_name == "tensorflowhub":
        domains = "1. $DOMAIN is inferred from the task description \
and should include one of {text-sequence-alignment, \
text-embedding, text-language-model, text-preprocessing, \
text-classification, text-generation, text-question-answering, \
text-retrieval-question-answering, text-segmentation, \
text-to-mel, image-classification, image-feature-vector, \
image-object-detection, image-segmentation, \
image-generator, image-pose-detection, image-rnn-agent, \
image-augmentation, image-classifier, image-style-transfer, \
image-aesthetic-quality, image-depth-estimation, \
image-super-resolution, image-deblurring, image-extrapolation, \
image-text-recognition, image-dehazing, image-deraining, \
image-enhancement, image-classification-logits, \
image-frame-interpolation, image-text-detection, image-denoising, \
image-others, video-classification, video-feature-extraction, \
video-generation, video-audio-text, video-text, \
audio-embedding, audio-event-classification, audio-command-detection, \
audio-paralinguists-classification, audio-speech-to-text, \
audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        # `domains` would otherwise be unbound below
        raise ValueError(
            f"Error: dataset name '{dataset_name}' is not supported."
        )

    prompt = (
        question
        + "\nWrite a python program in 1 to 2 lines to call API in "
        + dataset_name
        + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
<<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
<<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE. \
Here are the requirements:\n"
        + domains
        + "\n2. The $API_CALL should have only 1 line of code \
that calls api.\n3. The $API_PROVIDER should be the \
programming framework used.\n4. $EXPLANATION should be \
a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
Do not repeat the format in your answer."
    )
    return prompt
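
# Usage sketch for `encode_question` (illustrative; the question text is a
# made-up example, not taken from the APIBench data). The returned prompt
# embeds the question, the dataset-specific domain list, and the
# <<<domain>>>/<<<api_call>>>/... answer-format requirements.
def _example_encode_question():
    question = "I want an API to detect objects in street-scene images."
    prompt = encode_question(question, "torchhub")
    print(prompt)
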
class APIBenchBenchmark(BaseBenchmark):
    r"""APIBench Benchmark adapted from `Gorilla: Large Language Model
    Connected with Massive APIs`
    <https://huggingface.co/datasets/gorilla-llm/APIBench>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    # TODO: Integrate retriever (pending)
    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBench benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibench", data_dir, save_to, processes)

    def download(self):
        r"""Download the APIBench dataset."""
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gorilla-llm/APIBench",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )
        repo = "ShishirPatil/gorilla"
        subdir = "/gorilla/eval/eval-data/questions"
        data_dir = self.data_dir

        download_github_subdirectory(repo, subdir, data_dir)

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBench Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool, optional): Whether to force download the
                data. (default: :obj:`False`)
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        def load_json_lines(file_path: Path):
            r"""Helper function to load JSON lines from a file."""
            try:
                with open(file_path, "r") as f:
                    return [json.loads(line) for line in f]
            except FileNotFoundError:
                raise FileNotFoundError(f"File not found: {file_path}")
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e}"
                )

        dataset_path = self.data_dir / dataset_name
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset directory does not exist: {dataset_path}"
            )

        for label in ['api', 'eval', 'questions']:
            file_name = dataset_mapping[dataset_name][label]
            file_path = (
                dataset_path / file_name
                if label == 'questions'
                else self.data_dir / file_name
            )

            # All three files are stored as JSON lines
            data = load_json_lines(file_path)
            if label == 'eval':
                # Extract 'api_data' specifically for the eval label
                data = [item['api_data'] for item in data]
            self._data[label] = data

        ast_database = []
        for data in self._data['api']:
            ast_tree = ast_parse(data['api_call'])
            ast_database.append(ast_tree)
        self._data['ast'] = ast_database
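
    # After `load(dataset_name)`, `self._data` holds four parallel
    # collections:
    #   'api'       - API reference entries (dicts with 'domain',
    #                 'api_call', ...)
    #   'eval'      - the 'api_data' field of each eval record
    #   'questions' - dicts with 'text' and 'question_id'
    #   'ast'       - one parsed tree-sitter tree per 'api' entry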

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            dataset_name (Literal["huggingface", "tensorflowhub",
                "torchhub"]): The dataset to run the benchmark on.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")

        logger.info(f"Running APIBench benchmark on {dataset_name}.")
        self.load(dataset_name)
        datas = self._data['questions']

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        with open(self.save_to, "w") as f:
            for question in tqdm(datas, desc="Running"):
                prompt = encode_question(question["text"], dataset_name)
                try:
                    # Generate response
                    responses = agent.step(prompt)
                    response = responses.msgs[0].content
                    api_database = self._data['api']
                    qa_pairs = self._data['eval']
                    ast_database = self._data['ast']
                    question_id = question['question_id']

                    # Evaluate response
                    error, correct, hallucination = evaluate_response(
                        response,
                        question_id,
                        dataset_name,
                        api_database,
                        qa_pairs,
                        ast_database,
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": response,
                            "correct": correct,
                            "hallucination": hallucination,
                            "error": str(error) if error else None,
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Error in processing task: {question}: {e}"
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": None,
                            "correct": False,
                            "hallucination": False,
                            "error": str(e),
                        }
                    )

                agent.reset()

                json_str = json.dumps(
                    self._results[-1], indent=2, ensure_ascii=False
                )
                f.write(json_str + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["correct"] for r in self.results)
        hallucination = sum(r["hallucination"] for r in self.results)

        return {
            "total": total,
            "correct": correct,
            "hallucination": hallucination,
            "accuracy": correct / total if total else "N/A",
            "hallucination rate": hallucination / total if total else "N/A",
        }
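
# End-to-end usage sketch (illustrative): running the benchmark on a small
# random subset. `ChatAgent()` is constructed with library defaults here,
# and the paths are placeholders; any configured camel ChatAgent works.
def _example_run_benchmark():
    agent = ChatAgent()
    benchmark = APIBenchBenchmark(data_dir="datasets", save_to="results.jsonl")
    benchmark.download()
    result = benchmark.run(agent, "torchhub", randomize=True, subset=10)
    print(result["accuracy"], result["hallucination rate"])
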
# The code below is modified from the evaluators in the original repo:
# https://github.com/ShishirPatil/gorilla


# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
    node_stack = []
    sub_tree_sexp_list = []
    depth = 1
    node_stack.append([root_node, depth])
    while len(node_stack) != 0:
        cur_node, cur_depth = node_stack.pop()
        if cur_node.child_count > 0:
            sub_tree_sexp_list.append(
                [
                    str(cur_node),
                    cur_depth,
                    cur_node,
                    cur_node.children[0].text,
                ]
            )
        else:
            sub_tree_sexp_list.append(
                [str(cur_node), cur_depth, cur_node, None]
            )
        for child_node in cur_node.children:
            if len(child_node.children) != 0:
                depth = cur_depth + 1
                node_stack.append([child_node, depth])
    return sub_tree_sexp_list
# Parse the program into AST trees
def ast_parse(candidate):
    PY_LANGUAGE = Language(tspython.language())
    parser = Parser(PY_LANGUAGE)
    candidate_tree = parser.parse(bytes(candidate, "utf8")).root_node
    return candidate_tree
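
# Illustrative sketch: parsing a one-line API call and enumerating its
# subtrees. The fourth field of each entry is the text of the node's first
# child (used by `ast_check` to match an API name), or None for leaf nodes.
def _example_parse_and_enumerate():
    tree = ast_parse("model = torch.hub.load('pytorch/vision', 'resnet18')")
    for _sexp, depth, node, first_child_text in get_all_sub_trees(tree):
        print(depth, node.type, first_child_text)
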
# Get all the arguments in the ast tree
def get_args(node, dataset_name):
    if node.child_count == 0:
        return []
    args_list = []
    if dataset_name == "huggingface":
        for child in node.children[0].children[0].children[1].children:
            if "=" in child.text.decode():
                args_list.append(child.children[2].text)
            elif (
                child.text.decode() != "("
                and child.text.decode() != ")"
                and child.text.decode() != ","
            ):
                args_list.append(child.text)
    elif dataset_name == "tensorflowhub":
        for child in node.children[0].children[0].children[1].children:
            if (
                'model=' in child.text.decode()
                or 'model =' in child.text.decode()
            ):
                args_list.append(child.children[2].text)
            elif (
                child.text.decode() != "("
                and child.text.decode() != ")"
                and child.text.decode() != ","
            ):
                args_list.append(child.text)
    elif dataset_name == "torchhub":
        for child in node.children[0].children[0].children[1].children:
            if (
                "repo_or_dir" in child.text.decode()
                or "model" in child.text.decode()
            ):
                args_list.append(child.children[2].text)
    return args_list
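
# Illustrative sketch of `get_args` on a torchhub-style call. Given the
# root node of a parsed one-line call, the torchhub branch collects the
# values of the `repo_or_dir` and `model` keyword arguments as raw bytes.
def _example_get_args():
    tree = ast_parse(
        "torch.hub.load(repo_or_dir='pytorch/vision', model='resnet18')"
    )
    print([arg.decode() for arg in get_args(tree, "torchhub")])
    # -> ["'pytorch/vision'", "'resnet18'"]
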
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
    for idx, base_tree in enumerate(base_tree_list):
        if base_tree.children[0].children[0].child_count == 0:
            continue
        api_name = base_tree.children[0].children[0].children[0].text
        for candidate_tree in candidate_subtree_list:
            if candidate_tree[3] == api_name:
                break
        else:
            # No candidate subtree starts with this API name; move on to
            # the next reference tree instead of reusing a stale candidate
            continue
        # Now we have a sub-tree
        candidate_tree = candidate_tree[2]
        args_list = get_args(base_tree, dataset_name)
        if len(args_list) == 0:
            continue
        ast_match = True
        for arg in args_list:
            if (
                arg.decode().lstrip("'").rstrip("'")
                not in candidate_tree.text.decode()
            ):
                ast_match = False
                break
        if ast_match:
            return idx
    return -1
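
# Illustrative sketch of the matching pipeline used by `evaluate_response`:
# the candidate call is parsed, its subtrees are enumerated, and `ast_check`
# returns the index of the reference tree whose API name and arguments all
# appear in the candidate (or -1 for no match). The reference call is made up.
def _example_ast_check():
    reference = ast_parse(
        "torch.hub.load(repo_or_dir='pytorch/vision', model='resnet18')"
    )
    candidate = ast_parse(
        "model = torch.hub.load('pytorch/vision', 'resnet18')"
    )
    subtrees = get_all_sub_trees(candidate)
    print(ast_check(subtrees, [reference], "torchhub"))  # prints 0 (match)
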
def evaluate_response(
    response, question_id, dataset_name, api_database, qa_pairs, ast_database
):
    try:
        # Index the "api_call" domain
        output = response.split("api_call")
        if len(output) == 1:
            api_call = output[0]
        else:
            # Parse the output
            output = output[1].split("api_provider")[0]
            if ":" not in output:
                start = 0
            else:
                start = output.index(":")
            if ")" not in output:
                end = -2
            else:
                end = output.rindex(")")
            api_call = output[start + 2 : end + 1]

        try:
            ast_tree = ast_parse(api_call)
        except Exception as parse_error:
            logger.warning(
                f"Error parsing api_call: {api_call}, error: {parse_error}"
            )
            return parse_error, False, False

        # Search for a subtree
        ast_subtree_list = get_all_sub_trees(ast_tree)
        # Check which ast tree is matching
        database_index = ast_check(
            ast_subtree_list, ast_database, dataset_name
        )
        # We cannot index this ast in our database:
        # the response hallucinated an API
        if database_index == -1:
            return None, False, True
        # We index our reference api_call
        ref_api_call = api_database[database_index]
        # Check for functionality
        if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
            correct = True
            hallucination = False
        else:
            return None, False, False
    except Exception as e:
        logger.warning(f'Error parsing response: {response}, error: {e}')
        return e, False, False

    return None, correct, hallucination
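
# Illustrative sketch of `evaluate_response` on a hypothetical, well-formed
# model response. `api_database`, `qa_pairs`, and `ast_database` are tiny
# stand-ins shaped like the real loaded data; `question_id` is 1-based.
def _example_evaluate_response():
    api_database = [
        {
            'domain': 'Object Detection',
            'api_call': "torch.hub.load(repo_or_dir='pytorch/vision', "
            "model='resnet18')",
        }
    ]
    qa_pairs = [{'domain': 'Object Detection'}]
    ast_database = [ast_parse(api_database[0]['api_call'])]
    response = (
        "<<<domain>>>: Object Detection, "
        "<<<api_call>>>: model = torch.hub.load('pytorch/vision', "
        "'resnet18'), <<<api_provider>>>: PyTorch"
    )
    error, correct, hallucination = evaluate_response(
        response, 1, "torchhub", api_database, qa_pairs, ast_database
    )
    print(error, correct, hallucination)  # -> None True False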