Source code for camel.memories.context_creators.score_based
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import List, Tuple
from pydantic import BaseModel
from camel.memories.base import BaseContextCreator
from camel.memories.records import ContextRecord
from camel.messages import OpenAIMessage
from camel.utils import BaseTokenCounter
class _ContextUnit(BaseModel):
idx: int
record: ContextRecord
num_tokens: int
[docs]
class ScoreBasedContextCreator(BaseContextCreator):
r"""A default implementation of context creation strategy, which inherits
from :obj:`BaseContextCreator`.
This class provides a strategy to generate a conversational context from
a list of chat history records while ensuring the total token count of
the context does not exceed a specified limit. It prunes messages based
on their score if the total token count exceeds the limit.
Args:
token_counter (BaseTokenCounter): An instance responsible for counting
tokens in a message.
token_limit (int): The maximum number of tokens allowed in the
generated context.
"""
def __init__(
self, token_counter: BaseTokenCounter, token_limit: int
) -> None:
self._token_counter = token_counter
self._token_limit = token_limit
@property
def token_counter(self) -> BaseTokenCounter:
return self._token_counter
@property
def token_limit(self) -> int:
return self._token_limit
[docs]
def create_context(
self,
records: List[ContextRecord],
) -> Tuple[List[OpenAIMessage], int]:
r"""Creates conversational context from chat history while respecting
token limits.
Constructs the context from provided records and ensures that the total
token count does not exceed the specified limit by pruning the least
score messages if necessary.
Args:
records (List[ContextRecord]): A list of message records from which
to generate the context.
Returns:
Tuple[List[OpenAIMessage], int]: A tuple containing the constructed
context in OpenAIMessage format and the total token count.
Raises:
RuntimeError: If it's impossible to create a valid context without
exceeding the token limit.
"""
# Create unique context units list
uuid_set = set()
context_units = []
for idx, record in enumerate(records):
if record.memory_record.uuid not in uuid_set:
uuid_set.add(record.memory_record.uuid)
context_units.append(
_ContextUnit(
idx=idx,
record=record,
num_tokens=self.token_counter.count_tokens_from_messages(
[record.memory_record.to_openai_message()]
),
)
)
# TODO: optimize the process, may give information back to memory
# If not exceed token limit, simply return
total_tokens = sum([unit.num_tokens for unit in context_units])
if total_tokens <= self.token_limit:
return self._create_output(context_units)
# Sort by score
context_units = sorted(
context_units, key=lambda unit: unit.record.score
)
# Remove the least score messages until total token number is smaller
# than token limit
truncate_idx = None
for i, unit in enumerate(context_units):
if unit.record.score == 1:
raise RuntimeError(
"Cannot create context: exceed token limit.", total_tokens
)
total_tokens -= unit.num_tokens
if total_tokens <= self.token_limit:
truncate_idx = i
break
if truncate_idx is None:
raise RuntimeError(
"Cannot create context: exceed token limit.", total_tokens
)
return self._create_output(context_units[truncate_idx + 1 :])
def _create_output(
self, context_units: List[_ContextUnit]
) -> Tuple[List[OpenAIMessage], int]:
r"""Helper method to generate output from context units.
This method converts the provided context units into a format suitable
for output, specifically a list of OpenAIMessages and an integer
representing the total token count.
"""
context_units = sorted(context_units, key=lambda unit: unit.idx)
return [
unit.record.memory_record.to_openai_message()
for unit in context_units
], sum([unit.num_tokens for unit in context_units])