Source code for camel.benchmarks.base

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

from camel.agents import ChatAgent

logger = logging.getLogger(__name__)


[docs] class BaseBenchmark(ABC): r"""Base class for benchmarks. Attributes: name (str): Name of the benchmark. data_dir (str): Path to the data directory. save_to (str): Path to save the results. processes (int): Number of processes to use for parallel processing. :(default: :obj:`1`) """ def __init__( self, name: str, data_dir: str, save_to: str, processes: int = 1 ): r"""Initialize the benchmark. Args: name (str): Name of the benchmark. data_dir (str): Path to the data directory. save_to (str): Path to save the results. processes (int): Number of processes to use for parallel processing. :(default: :obj:`1`) """ self.name = name self.data_dir = Path(data_dir) self.processes = processes self.save_to = save_to if not self.data_dir.exists(): logger.info( f"Data directory {data_dir} does not exist. Creating it." ) self.data_dir.mkdir(parents=True, exist_ok=True) if not self.data_dir.is_dir(): raise NotADirectoryError( f"Data directory {data_dir} is not a directory" ) self._data: Dict[str, List[Dict[str, Any]]] = dict() self._results: List[Dict[str, Any]] = []
[docs] @abstractmethod def download(self) -> "BaseBenchmark": r"""Download the benchmark data. Returns: BaseBenchmark: The benchmark instance. """ pass
[docs] @abstractmethod def load(self, force_download: bool = False) -> "BaseBenchmark": r"""Load the benchmark data. Args: force_download (bool): Whether to force download the data. Returns: BaseBenchmark: The benchmark instance. """ pass
@property def train(self) -> List[Dict[str, Any]]: r"""Get the training data. Returns: List[Dict[str, Any]]: The training data. """ if not self._data: logger.info("Data not loaded. Loading data.") self.load() return self._data["train"] @property def valid(self) -> List[Dict[str, Any]]: r"""Get the validation data. Returns: List[Dict[str, Any]]: The validation data. """ if not self._data: logger.info("Data not loaded. Loading data.") self.load() return self._data["valid"] @property def test(self) -> List[Dict[str, Any]]: r"""Get the test data. Returns: List[Dict[str, Any]]: The test data. """ if not self._data: logger.info("Data not loaded. Loading data.") self.load() return self._data["test"]
[docs] @abstractmethod def run( self, agent: ChatAgent, on: Literal["train", "valid", "test"], randomize: bool = False, subset: Optional[int] = None, *args, **kwargs, ) -> "BaseBenchmark": r"""Run the benchmark. Args: agent (ChatAgent): The chat agent. on (str): The data split to run the benchmark on. randomize (bool): Whether to randomize the data. subset (int): The subset of the data to run the benchmark on. Returns: BaseBenchmark: The benchmark instance. """ pass
@property def results(self) -> List[Dict[str, Any]]: r"""Get the results. Returns: List[Dict[str, Any]]: The results. """ return self._results