Source code for camel.terminators.response_terminator

# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

from camel.messages import BaseMessage
from camel.types import TerminationMode

from .base import ResponseTerminator


class ResponseWordsTerminator(ResponseTerminator):
    r"""Terminate agent when some words reached to occurrence limit by any
    message of the response.

    Counts are kept per message position within a response and accumulate
    across successive calls to :meth:`is_terminated` until :meth:`reset` is
    called. A word contributes at most one count per message per call, no
    matter how many times it occurs inside that message's content.

    Args:
        words_dict (dict): Dictionary of words and its occurrence threshold.
        case_sensitive (bool): Whether count the words as case-sensitive.
            (default: :obj:`False`)
        mode (TerminationMode): Whether terminate agent if any or all
            pre-set words reached the threshold.
            (default: :obj:`TerminationMode.ANY`)

    Raises:
        ValueError: If ``words_dict`` is empty or any threshold is not
            larger than 0.
    """

    def __init__(
        self,
        words_dict: Dict[str, int],
        case_sensitive: bool = False,
        mode: TerminationMode = TerminationMode.ANY,
    ):
        super().__init__()
        self.words_dict = words_dict
        self.case_sensitive = case_sensitive
        self.mode = mode
        # One counter mapping (word -> hits) per message position.
        self._word_count_dict: List[Dict[str, int]] = []
        self._validate()

    def _validate(self):
        r"""Check that ``words_dict`` is non-empty and that every threshold
        is a positive number.

        Raises:
            ValueError: If ``words_dict`` is empty or a threshold is
                smaller than or equal to 0.
        """
        if not self.words_dict:
            raise ValueError("`words_dict` cannot be empty")
        for word, threshold in self.words_dict.items():
            if threshold <= 0:
                raise ValueError(
                    f"Threshold for word `{word}` should "
                    f"be larger than 0, got `{threshold}`"
                )

    def is_terminated(
        self, messages: List[BaseMessage]
    ) -> Tuple[bool, Optional[str]]:
        r"""Whether terminate the agent by checking the occurrence of
        specified words reached to preset thresholds.

        Args:
            messages (list): List of :obj:`BaseMessage` from a response.

        Returns:
            tuple: A tuple containing whether the agent should be
                terminated and a string of termination reason.

        Raises:
            ValueError: If ``mode`` is neither :obj:`TerminationMode.ANY`
                nor :obj:`TerminationMode.ALL`.
        """
        if self._terminated:
            return True, self._termination_reason

        # Make sure there is a counter for every message position seen.
        while len(self._word_count_dict) < len(messages):
            self._word_count_dict.append(defaultdict(int))

        for position, message in enumerate(messages):
            # Normalise the content once per message, not once per word.
            content = message.content
            if not self.case_sensitive:
                content = content.lower()
            counts = self._word_count_dict[position]
            for word in self.words_dict:
                target = word if self.case_sensitive else word.lower()
                if target in content:
                    counts[word] += 1

        for position, counts in enumerate(self._word_count_dict):
            reasons: List[str] = [
                (
                    f"Word `{word}` appears {value} times in the "
                    f"{position + 1} message of the response which has "
                    f"reached termination threshold "
                    f"{self.words_dict[word]}."
                )
                for word, value in counts.items()
                if value >= self.words_dict[word]
            ]
            reached = len(reasons)

            if self.mode == TerminationMode.ANY:
                triggered = reached > 0
            elif self.mode == TerminationMode.ALL:
                triggered = reached >= len(self.words_dict)
            else:
                raise ValueError(
                    f"Unsupported termination mode " f"`{self.mode}`"
                )

            # No early exit: when several positions qualify, the reason of
            # the last qualifying position is the one that is kept.
            if triggered:
                self._terminated = True
                self._termination_reason = "\n".join(reasons)

        return self._terminated, self._termination_reason

    def reset(self):
        r"""Clear the termination state and all accumulated word counts so
        the terminator can be reused.
        """
        self._terminated = False
        self._termination_reason = None
        # Must stay a list: `is_terminated` appends to and indexes it.
        # Rebinding it to a `defaultdict` made the next call raise
        # `AttributeError` on `.append`.
        self._word_count_dict = []