Source code for camel.data_collector.base

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from uuid import UUID

from typing_extensions import Self

from camel.agents import ChatAgent


class CollectorData:
    r"""A single message captured by a data collector.

    Holds the id of the message, the name of the agent it belongs to, its
    role, and its text content and/or the function (tool) call it carries.
    """

    # Roles accepted by ``__init__``; kept in sync with the ``Literal`` in the
    # signature so the runtime check and the type hint agree.
    _SUPPORTED_ROLES = ("user", "assistant", "system", "tool")

    def __init__(
        self,
        id: UUID,
        name: str,
        role: Literal["user", "assistant", "system", "tool"],
        message: Optional[str] = None,
        function_call: Optional[Dict[str, Any]] = None,
    ) -> None:
        r"""Create a data item storing information about a message.

        Used by the data collector.

        Args:
            id (UUID): The id of the message.
            name (str): The name of the agent.
            role (Literal["user", "assistant", "system", "tool"]): The role
                of the message.
            message (Optional[str], optional): The message.
                (default: :obj:`None`)
            function_call (Optional[Dict[str, Any]], optional): The function
                call. (default: :obj:`None`)

        Raises:
            ValueError: If the role is not supported.
            ValueError: If the role is system and function call is provided.
            ValueError: If neither message nor function call is provided.
        """
        if role not in self._SUPPORTED_ROLES:
            raise ValueError(f"Role {role} not supported")
        if role == "system" and function_call:
            raise ValueError("System role cannot have function call")
        # Truthiness checks are deliberate: an empty string or empty call
        # payload counts as "not provided".
        if not message and not function_call:
            raise ValueError(
                "Either message or function call must be provided"
            )

        self.id = id
        self.name = name
        self.role = role
        self.message = message
        self.function_call = function_call

    def __repr__(self) -> str:
        r"""Return a debug representation showing every stored field."""
        return (
            f"{type(self).__name__}(id={self.id!r}, name={self.name!r}, "
            f"role={self.role!r}, message={self.message!r}, "
            f"function_call={self.function_call!r})"
        )

    @classmethod
    def from_context(cls, name: str, context: Dict[str, Any]) -> "CollectorData":
        r"""Create a data item from a context message.

        A fresh random id is assigned. The message text is read from the
        required ``"content"`` key, and the optional ``"tool_calls"`` key
        becomes the function call. Called on a subclass, this returns an
        instance of that subclass.

        Args:
            name (str): The name of the agent.
            context (Dict[str, Any]): The context message; must contain
                ``"role"`` and ``"content"`` keys.

        Returns:
            CollectorData: The created data item.

        Raises:
            KeyError: If ``"role"`` or ``"content"`` is missing.
            ValueError: Propagated from ``__init__`` validation.
        """
        return cls(
            id=uuid.uuid4(),
            name=name,
            role=context["role"],
            message=context["content"],
            function_call=context.get("tool_calls", None),
        )


class BaseDataCollector(ABC):
    r"""Base class for data collectors.

    A collector keeps a list of registered agents (see ``record``), a message
    ``history`` (appended to by ``step``), and a recording flag toggled by
    ``start`` / ``stop``. Subclasses implement ``convert`` and
    ``llm_convert`` to turn the collected data into an output format.
    """

    def __init__(self) -> None:
        r"""Create a data collector."""
        # Messages appended via ``step``.
        self.history: List[CollectorData] = []
        # Backing field for the read-only ``recording`` property.
        self._recording = False
        # Registered agents as (name, agent) pairs, in registration order.
        self.agents: List[Tuple[str, ChatAgent]] = []
        # Not read or written elsewhere in this base class; presumably
        # populated by subclasses — TODO confirm.
        self.data: List[Dict[str, Any]] = []

    def step(
        self,
        role: Literal["user", "assistant", "system", "tool"],
        name: Optional[str] = None,
        message: Optional[str] = None,
        function_call: Optional[Dict[str, Any]] = None,
    ) -> Self:
        r"""Record a message.

        Note that the message is appended regardless of the ``recording``
        flag.

        Args:
            role (Literal["user", "assistant", "system", "tool"]): The role
                of the message.
            name (Optional[str], optional): The name of the agent; falls
                back to ``role`` when not given. (default: :obj:`None`)
            message (Optional[str], optional): The message to record.
                (default: :obj:`None`)
            function_call (Optional[Dict[str, Any]], optional): The function
                call to record. (default: :obj:`None`)

        Returns:
            Self: The data collector.

        Raises:
            ValueError: Propagated from ``CollectorData`` validation.
        """
        name = name or role
        self.history.append(
            CollectorData(
                id=uuid.uuid4(),
                name=name,
                role=role,
                message=message,
                function_call=function_call,
            )
        )
        return self

    def record(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
    ) -> Self:
        r"""Record agents.

        Each agent is registered under its ``role_name``. An agent with an
        empty ``role_name`` gets ``"<ClassName>_<index>"``, where ``index``
        is the number of agents registered before it.

        Registration is all-or-nothing. Names for the whole batch are
        validated before any agent is added, so a duplicate never leaves
        the collector partially updated.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]): The agent(s) to inject.

        Returns:
            Self: The data collector.

        Raises:
            ValueError: If a name is already registered or appears twice in
                the given batch.
        """
        agents = agent if isinstance(agent, list) else [agent]

        taken = {existing for existing, _ in self.agents}
        offset = len(self.agents)
        pending: List[Tuple[str, ChatAgent]] = []
        for i, a in enumerate(agents):
            # Same fallback index the agent would get if added one by one.
            name = a.role_name or f"{a.__class__.__name__}_{offset + i}"
            if name in taken:
                raise ValueError(f"Name {name} already exists")
            taken.add(name)
            pending.append((name, a))

        self.agents.extend(pending)
        return self

    def start(self) -> Self:
        r"""Start recording."""
        self._recording = True
        return self

    def stop(self) -> Self:
        r"""Stop recording."""
        self._recording = False
        return self

    @property
    def recording(self) -> bool:
        r"""Whether the collector is recording."""
        return self._recording

    def reset(self, reset_agents: bool = True) -> None:
        r"""Reset the collector.

        Clears ``history``. The agent registrations and ``data`` are kept.

        Args:
            reset_agents (bool, optional): Whether to also call ``reset()``
                on every registered agent. Defaults to True.
        """
        self.history = []
        if reset_agents:
            for _, agent in self.agents:
                agent.reset()

    @abstractmethod
    def convert(self) -> Any:
        r"""Convert the collected data."""
        pass

    @abstractmethod
    def llm_convert(self, converter: Any, prompt: Optional[str] = None) -> Any:
        r"""Convert the collected data with the help of ``converter``.

        Args:
            converter (Any): The converter to use (subclass-defined).
            prompt (Optional[str], optional): An optional prompt passed to
                the conversion. (default: :obj:`None`)
        """
        pass

    def get_agent_history(self, name: str) -> List[CollectorData]:
        r"""Get the message history of an agent.

        When messages were recorded via ``step``, those with a matching name
        are returned. Otherwise the history is rebuilt from the registered
        agent's memory context.

        Args:
            name (str): The name of the agent.

        Returns:
            List[CollectorData]: The message history of the agent. The list
            is empty if nothing matches.
        """
        if self.history:
            return [msg for msg in self.history if msg.name == name]

        for agent_name, agent in self.agents:
            if agent_name == name:
                # ``get_context()[0]`` holds the context messages.
                return [
                    CollectorData.from_context(name, dict(item))
                    for item in agent.memory.get_context()[0]
                ]
        return []