Source code for camel.environments.single_step

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import random
from typing import Any, Dict, List, Optional, Tuple, Union

from camel.datasets import BaseGenerator, DataPoint, StaticDataset
from camel.logger import get_logger
from camel.verifiers.base import (
    BaseVerifier,
    VerificationResult,
)

from .models import Action, Observation, StepResult

logger = get_logger(__name__)


class SingleStepEnv:
    r"""A lightweight environment for single-step RL with LLMs as policy.

    This environment models a single interaction between an LLM-based agent
    and a problem drawn from a dataset (such as a question-answering or math
    problem), where the agent produces one response and receives feedback.

    Core Flow:
        - A question is sampled from a (possibly infinitely long) dataset.
        - The LLM generates a single-step response (the action).
        - The response is verified against the ground truth.
        - A reward is computed based on correctness and optional custom
          logic.

    Key Features:
        - Batched evaluation with per-sample state tracking.
        - Async setup and teardown for verifiers and related resources.
        - Supports deterministic sampling via local RNG (optional seed).
        - Extensible reward computation via subclassing.
    """

    PLACEHOLDER_OBS = Observation(
        question="Episode ended. This is just a placeholder."
    )
    ACCURACY_REWARD = 10

    def __init__(
        self,
        dataset: Union[StaticDataset, BaseGenerator],
        verifier: BaseVerifier,
        **kwargs,
    ) -> None:
        r"""Initialize the SingleStepEnv.

        Args:
            dataset (Union[StaticDataset, BaseGenerator]): Dataset to
                sample problems from.
            verifier (BaseVerifier): Verifier used to evaluate LLM
                responses against ground-truth answers.
            **kwargs: Optional metadata or configuration values.

        Notes:
            This class assumes all interactions are single-step: one
            question, one LLM response, one reward.
        """
        self.dataset = dataset
        self.verifier = verifier
        self._metadata = kwargs

        # State tracking
        self._is_setup: bool = False
        self._states: List[DataPoint] = []
        self._states_done: List[bool] = []
        self.current_batch_size: int = 0
    async def setup(self) -> None:
        r"""Set up the environment by initializing the verifier.

        This method ensures that the environment is ready for interaction.
        It sets up necessary components, including the verifier.

        Raises:
            Exception: If setup fails due to an internal error.
        """
        if self._is_setup:
            logger.warning("Environment has already been set up")
            return
        try:
            await self.verifier.setup()
            self._is_setup = True
            logger.info('Environment setup completed successfully')
        except Exception as e:
            logger.error(f'Failed to setup environment: {e}')
            raise
    async def close(self) -> None:
        r"""Clean up and close all resources used by the environment.

        This method shuts down the verifier, resets the internal state, and
        ensures that the environment is properly closed.

        Raises:
            Exception: If an error occurs while closing the environment.
        """
        if not self._is_setup:
            logger.warning(
                "Not closing environment - has not been set up yet."
            )
            return
        try:
            self._is_setup = False
            await self.verifier.cleanup()
            self._states = []
            self._states_done = []
            self.current_batch_size = 0
            logger.info('Environment closed successfully')
        except Exception as e:
            logger.error(f'Failed to close environment: {e}')
            raise
    async def reset(
        self, batch_size: int = 1, seed: Optional[int] = None
    ) -> Union[Observation, List[Observation]]:
        r"""Resets the environment and starts a new episode.

        This method samples a new batch of data points from the dataset and
        returns the corresponding initial observations. If a seed is
        provided, a local random number generator is initialized for
        deterministic sampling. The global random state is not affected.

        Args:
            batch_size (int): Number of data points to sample.
                (default: :obj:`1`)
            seed (Optional[int]): Seed for deterministic sampling. If None,
                sampling is non-deterministic. (default: :obj:`None`)

        Returns:
            Observation or List[Observation]: Initial observation(s) for
                the episode.

        Raises:
            RuntimeError: If called before all previous states are
                processed.
            ValueError: If batch size exceeds dataset size.
            TypeError: If the dataset is of an unsupported type.
        """
        if batch_size <= 0:
            raise ValueError("Batch size must be positive")

        if not self._is_setup:
            logger.warning(
                "reset() called on un-setup environment. Setting up..."
            )
            await self.setup()

        if self._batch_started() and not self._batch_done():
            logger.error(
                "Reset called before all states were processed. "
                "Call step on remaining states first."
            )
            raise RuntimeError(
                "reset() called before all states in batch were processed."
            )

        if seed is not None:
            rng = random.Random(seed)
        else:
            rng = random.Random()

        if isinstance(self.dataset, StaticDataset):
            dataset_len = len(self.dataset)
            if batch_size > dataset_len:
                raise ValueError(
                    f"Batch size {batch_size} is too large for dataset "
                    f"of size {dataset_len}"
                )
            start_idx = rng.randint(0, dataset_len - batch_size)
            idx_slice = slice(start_idx, start_idx + batch_size)
            val = self.dataset[idx_slice]
            self._states = [val] if isinstance(val, DataPoint) else val
            self.current_batch_size = len(self._states)
            self._states_done = [False] * self.current_batch_size

            observations = [
                Observation(question=sample.question, context={}, metadata={})
                for sample in self._states
            ]
            return observations[0] if batch_size == 1 else observations

        elif isinstance(self.dataset, BaseGenerator):
            raise NotImplementedError(
                "Reset not yet implemented for BaseGenerator datasets."
            )
        else:
            raise TypeError(f"Unsupported dataset type: {type(self.dataset)}")
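
    # Usage sketch (not part of the API): with a StaticDataset, a fixed
    # seed makes reset deterministic, since a fresh local RNG is created
    # per call and picks the same contiguous slice each time:
    #
    #     obs_list = await env.reset(batch_size=4, seed=7)
    #     # -> list of 4 Observations; a later reset(batch_size=4, seed=7)
    #     #    yields the same questions, without touching global random.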
    async def step(
        self, action: Union[Action, List[Action]]
    ) -> Union[
        Tuple[Observation, float, bool, Dict[str, Any]],
        List[Tuple[Observation, float, bool, Dict[str, Any]]],
    ]:
        r"""Process actions for a subset of states and update their
        finished status.

        Args:
            action (Union[Action, List[Action]]): A single action (for
                batch_size=1, or a micro-batch of size 1) or a list of
                actions (for batch_size>=2). For batch_size>=2, each action
                must carry an index indicating which state it corresponds
                to.

        Returns:
            A tuple ``(observation, reward, done, info)`` for a single
            processed state, or a list of such tuples, each produced via
            ``StepResult.as_tuple()``.

        Raises:
            RuntimeError: If the environment isn't set up or the episode
                has ended.
            ValueError: If indices are invalid, duplicated, or correspond
                to already finished states.
        """
        if not self._is_setup:
            raise RuntimeError("Environment not set up. Call setup() first.")
        if self._batch_done():
            raise RuntimeError(
                "Episodes have ended for batch. Call reset() first."
            )
        if not self._states:
            raise RuntimeError("No current observation. Call reset() first.")

        # Normalize actions into a list for uniform processing
        if self.current_batch_size == 1:
            if isinstance(action, list):
                if len(action) != 1 or not isinstance(action[0], Action):
                    raise ValueError(
                        "For batch_size=1, expect a single Action or a "
                        "list containing exactly one Action"
                    )
            elif not isinstance(action, Action):
                raise ValueError(
                    "For batch_size=1, expect a single Action or a "
                    "list containing exactly one Action"
                )
            if isinstance(action, Action):
                actions = [action]
            else:
                actions = action
            if actions[0].index is None:
                actions[0].index = 0
            if actions[0].index != 0:
                raise ValueError("For batch_size=1, index must be None or 0")
        else:  # batch_size >= 2
            if isinstance(action, Action):
                if action.index is None:
                    raise ValueError(
                        "For batch_size>=2, each Action must have an index"
                    )
                if not isinstance(action.index, int):
                    raise ValueError("Index must be an integer")
                actions = [action]
            elif isinstance(action, list):
                if not action:  # Empty list
                    raise ValueError("Action list cannot be empty")
                actions = action
                for act in actions:
                    if not isinstance(act, Action):
                        raise ValueError(
                            "All elements in list must be Action objects"
                        )
                    if act.index is None:
                        raise ValueError(
                            "For batch_size>=2, each Action must have an "
                            "index"
                        )
                    if not isinstance(act.index, int):
                        raise ValueError("Index must be an integer")
            else:
                raise ValueError(
                    "For batch_size>=2, expect an Action or list of Actions"
                )

        # Validate indices
        indices: List[int] = []
        for act in actions:
            assert act.index is not None
            indices.append(act.index)
        if len(set(indices)) != len(indices):
            raise ValueError("Duplicate state indices in actions.")
        for idx in indices:
            if idx < 0 or idx >= len(self._states):
                raise ValueError(f"Invalid state index {idx}.")
            if self._states_done[idx]:
                raise ValueError(f"State at index {idx} is already finished.")

        num_actions = len(actions)
        if self.current_batch_size % num_actions != 0:
            logger.warning(
                f"Number of actions ({num_actions}) is not a divisor of "
                f"total batch size ({self.current_batch_size})"
            )

        proposed_solutions = [act.llm_response for act in actions]
        ground_truths: List[str] = []
        for idx in indices:
            ground_truths.append(self._states[idx].final_answer)

        verification_results = await self.verifier.verify_batch(
            solutions=proposed_solutions,
            ground_truths=ground_truths,  # type: ignore [arg-type]
            raise_on_error=True,
        )

        total_rewards, rewards_dicts = await self._compute_reward_batch(
            proposed_solutions, verification_results
        )

        # TODO: Batch this
        step_results = []
        for i, action in enumerate(actions):
            assert action.index is not None
            idx = action.index
            step_result = StepResult(
                observation=self.PLACEHOLDER_OBS,
                reward=total_rewards[i],
                rewards_dict=rewards_dicts[i],
                done=True,
                info={
                    "proposed_solution": proposed_solutions[i],
                    "verification_result": verification_results[i],
                    "state": self._states[idx],
                },
            )
            step_results.append(step_result.as_tuple())
            self._states_done[idx] = True

        return step_results[0] if len(step_results) == 1 else step_results
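
    # Usage sketch: stepping two states of a batch with explicitly indexed
    # actions. The `llm_response` strings would normally come from an LLM
    # policy; the literals here are placeholders.
    #
    #     actions = [
    #         Action(index=0, llm_response="42"),
    #         Action(index=1, llm_response="3.14"),
    #     ]
    #     results = await env.step(actions)
    #     # -> list of (observation, reward, done, info) tuples; `done` is
    #     #    always True, since episodes are single-step.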
    async def _compute_reward_batch(
        self,
        proposed_solutions: List[str],
        verification_results: List[VerificationResult],
    ) -> Tuple[List[float], List[Dict[str, float]]]:
        r"""Compute rewards for a batch of proposed solutions based on
        verification results.

        Args:
            proposed_solutions (List[str]): List of LLM-generated responses
                to evaluate.
            verification_results (List[VerificationResult]): List of
                verification outcomes for each solution.

        Returns:
            Tuple containing:
                - List of total rewards for each solution.
                - List of reward component dictionaries for each solution.
        """
        if len(proposed_solutions) != len(verification_results):
            raise ValueError(
                f"Length mismatch: {len(proposed_solutions)} solutions vs "
                f"{len(verification_results)} verification results"
            )

        total_rewards = []
        rewards_dicts = []
        for solution, verification_result in zip(
            proposed_solutions, verification_results
        ):
            rewards: Dict[str, float] = {}
            rewards["correctness"] = (
                self.ACCURACY_REWARD if verification_result.status else 0.0
            )
            further_rewards = await self._compute_custom_reward(
                solution, verification_result
            )
            rewards = {**rewards, **further_rewards}
            total_reward = sum(rewards.values())
            total_rewards.append(total_reward)
            rewards_dicts.append(rewards)

        return total_rewards, rewards_dicts

    async def _compute_custom_reward(
        self, proposed_solution: str, verification_result: VerificationResult
    ) -> Dict[str, float]:
        r"""Compute additional custom reward components for a single
        solution.

        To be overridden by subclasses for domain-specific rewards.

        Args:
            proposed_solution (str): The LLM-generated response.
            verification_result (VerificationResult): The verification
                outcome.

        Returns:
            Dict[str, float]: Dictionary of custom reward components.
        """
        return {}

    def _batch_done(self) -> bool:
        r"""Check if all states in the current batch are done.

        Returns:
            bool: True if all states are marked as done, False otherwise.
        """
        return all(self._states_done)

    def _batch_started(self) -> bool:
        r"""Check if any state in the current batch is done.

        Returns:
            bool: True if at least one state is marked as done, False
                otherwise.
        """
        return any(self._states_done)

    @property
    def metadata(self) -> Dict[str, Any]:
        r"""Retrieve the metadata of the environment.

        This provides additional parameters and configuration details.

        Returns:
            Dict[str, Any]: A copy of the environment's metadata.
        """
        return self._metadata.copy()
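

# Example (sketch): the class docstring advertises extensible reward
# computation via subclassing. This minimal illustration overrides
# `_compute_custom_reward` with a made-up brevity penalty; the component is
# merged into the rewards dict alongside "correctness" by
# `_compute_reward_batch`. The subclass and its shaping term are
# illustrative only, not part of the library.
class _ExampleBrevityEnv(SingleStepEnv):
    async def _compute_custom_reward(
        self, proposed_solution: str, verification_result: VerificationResult
    ) -> Dict[str, float]:
        # Hypothetical shaping term: small penalty per character, capped at
        # -1.0 so it can never outweigh the correctness reward.
        return {"brevity": max(-1.0, -0.01 * len(proposed_solution))}


# Typical lifecycle (sketch; `dataset` and `verifier` stand for concrete
# StaticDataset / BaseVerifier instances, whose construction is not shown):
#
#     env = SingleStepEnv(dataset, verifier)
#     await env.setup()
#     obs = await env.reset(batch_size=1, seed=42)
#     action = Action(llm_response=my_llm(obs.question))  # your policy
#     observation, reward, done, info = await env.step(action)
#     await env.close()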