# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
from typing import Any, Dict, List, Optional, Sequence
from tqdm import tqdm
from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent
from camel.datagen.source2synth.user_data_processor_config import (
ProcessorConfig,
)
from camel.logger import get_logger
logger = get_logger(__name__)
# [docs] -- stray Sphinx source-link marker (HTML extraction residue), not code
class UserDataProcessor:
    r"""A processor for generating multi-hop question-answer pairs from user
    data.

    The processor wraps raw text into ``{'text', 'source'}`` records, hands
    them to an :obj:`ExampleConstructor` to build examples, and then passes
    the examples through a :obj:`DataCurator` to obtain the final dataset.

    Attributes:
        config (ProcessorConfig): Configuration for data processing parameters.
        rng (random.Random): Random number generator seeded with
            ``config.seed``, shared by every curation run of this instance.
        multi_hop_agent (Optional[MultiHopGeneratorAgent]): Agent for
            generating QA pairs. Taken from ``config.hop_generating_agent``
            when ``config.use_ai_model`` is truthy, otherwise :obj:`None`.
    """

    def __init__(self, config: Optional[ProcessorConfig] = None):
        r"""Initialize the UserDataProcessor.

        Args:
            config (Optional[ProcessorConfig], optional): Configuration for
                data processing. A default :obj:`ProcessorConfig` is created
                when omitted. (default: :obj:`None`)
        """
        self.config = config or ProcessorConfig()
        self.rng = random.Random(self.config.seed)
        self.multi_hop_agent = (
            self.config.hop_generating_agent
            if self.config.use_ai_model
            else None
        )

    def process_text(
        self, text: str, source: str = "user_input"
    ) -> List[Dict[str, Any]]:
        r"""Process a single text to generate multi-hop QA pairs.

        Args:
            text (str): The input text to process.
            source (str, optional): Source identifier for the text.
                (default: :obj:`"user_input"`)

        Returns:
            List[Dict[str, Any]]: List of processed examples with QA pairs and
                metadata.
        """
        # A single text is just a batch of one record.
        return self._run_pipeline([{'text': text, 'source': source}])

    def process_batch(
        self, texts: List[str], sources: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        r"""Process multiple texts in batch to generate multi-hop QA pairs.

        Args:
            texts (List[str]): List of input texts to process.
            sources (Optional[List[str]], optional): List of source
                identifiers, one per text. When omitted every text is
                labelled ``"user_input"``. (default: :obj:`None`)

        Returns:
            List[Dict[str, Any]]: List of processed examples with QA pairs and
                metadata.

        Raises:
            ValueError: If length of sources doesn't match length of texts.
        """
        if sources is None:
            sources = ["user_input"] * len(texts)
        elif len(sources) != len(texts):
            raise ValueError("Length of sources must match length of texts")

        raw_data = [
            {'text': text, 'source': source}
            for text, source in zip(texts, sources)
        ]
        return self._run_pipeline(raw_data)

    def _run_pipeline(
        self, raw_data: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Build examples from raw records and curate them.

        Shared by :meth:`process_text` and :meth:`process_batch` so both entry
        points go through exactly the same construction and curation steps.

        Args:
            raw_data (List[Dict[str, Any]]): Records with ``'text'`` and
                ``'source'`` keys.

        Returns:
            List[Dict[str, Any]]: The curated dataset.
        """
        # Construct examples
        constructor = ExampleConstructor(self.config, self.multi_hop_agent)
        examples = constructor.construct_examples(raw_data)

        # Manage data; the instance-level RNG is passed on so sampling is
        # driven by the configured seed.
        curator = DataCurator(self.config, self.rng)
        return curator.curate_dataset(examples)
# [docs] -- stray Sphinx source-link marker (HTML extraction residue), not code
class ExampleConstructor:
    r"""Constructs training examples from raw text data.

    This class handles the construction of training examples by preprocessing
    text, extracting information pairs, and generating question-answer pairs.

    Attributes:
        config (ProcessorConfig): Configuration for example construction.
            ``min_length`` and ``max_length`` are read during preprocessing.
        multi_hop_agent (Optional[MultiHopGeneratorAgent]): Agent for QA
            generation. Without an agent no QA pairs are produced.
    """

    def __init__(
        self,
        config: ProcessorConfig,
        multi_hop_agent: Optional[MultiHopGeneratorAgent] = None,
    ):
        r"""Initialize the ExampleConstructor.

        Args:
            config (ProcessorConfig): Configuration for example construction.
            multi_hop_agent (Optional[MultiHopGeneratorAgent], optional):
                Agent for generating multi-hop QA pairs. (default: :obj:`None`)
        """
        self.config = config
        self.multi_hop_agent = multi_hop_agent

    def construct_examples(
        self, raw_data: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Construct training examples from raw data.

        Records whose text fails preprocessing are skipped. Every other record
        yields one example, even when no QA pair could be generated for it.

        Args:
            raw_data (List[Dict[str, Any]]): List of raw data dictionaries
                containing text and metadata.

        Returns:
            List[Dict[str, Any]]: List of constructed examples with QA pairs
                and metadata (``source``, ``timestamp``, ``complexity``).
        """
        logger.info("Starting to construct training examples...")
        examples = []

        for data in tqdm(raw_data, desc="Constructing examples"):
            # 1. Text preprocessing
            processed_text = self._preprocess_text(data.get('text', ''))
            if not processed_text:
                continue

            # 2. Generate key information pairs
            info_pairs = self._extract_info_pairs(processed_text)

            # 3. Construct question-answer pairs
            qa_pairs = self._generate_qa_pairs(info_pairs)

            # 4. Add metadata
            examples.append(
                {
                    'text': processed_text,
                    'qa_pairs': qa_pairs,
                    'metadata': {
                        'source': data.get('source', 'unknown'),
                        'timestamp': data.get('timestamp', ''),
                        'complexity': self._calculate_complexity(qa_pairs),
                    },
                }
            )

        logger.info(f"Successfully constructed {len(examples)} examples")
        return examples

    def _preprocess_text(self, text: str) -> str:
        r"""Preprocess input text for example construction.

        Args:
            text (str): Input text to preprocess.

        Returns:
            str: The stripped text, or empty string if the input is not a
                string, its stripped length is outside
                ``[config.min_length, config.max_length]``, or it fails the
                quality check.
        """
        if not isinstance(text, str):
            return ''

        # 1. Basic cleaning
        text = text.strip()

        # 2. Length check (both bounds inclusive)
        if not (
            self.config.min_length <= len(text) <= self.config.max_length
        ):
            return ''

        # 3. Quality check
        if not self._check_text_quality(text):
            return ''

        return text

    def _check_text_quality(self, text: str) -> bool:
        r"""Check the quality of input text.

        Args:
            text (str): Text to check quality for.

        Returns:
            bool: True if the text contains at least two ``'.'`` characters
                and at most 30% of its characters are neither alphanumeric
                nor whitespace, False otherwise.
        """
        # 1. Basic quality check: two periods stand in for "two sentences".
        # This also guarantees the text is non-empty for the division below.
        if text.count('.') < 2:
            return False

        # 2. Special character ratio check (no more than 30%)
        special_char_count = sum(
            1 for c in text if not c.isalnum() and not c.isspace()
        )
        return special_char_count / len(text) <= 0.3

    def _extract_info_pairs(self, text: str) -> List[Dict[str, Sequence[str]]]:
        r"""Extract information pairs and relationships from text.

        The text is split on ``'.'`` and a three-sentence window
        (premise, intermediate, conclusion) is slid over the result. A window
        is kept only when both the premise and the intermediate sentence are
        longer than 10 characters.

        Args:
            text (str): Input text to extract information from.

        Returns:
            List[Dict[str, Sequence[str]]]: List of dictionaries containing
                premise, intermediate, conclusion, and related contexts.
                Empty when the text has fewer than three sentences.
        """
        # Split into sentences
        sentences = [s.strip() for s in text.split('.') if s.strip()]

        info_pairs = []
        # The range stops two short of the end, so every window has a real
        # conclusion sentence; no empty-string fallback is needed.
        for i in range(len(sentences) - 2):
            premise, intermediate, conclusion = sentences[i : i + 3]
            if len(premise) <= 10 or len(intermediate) <= 10:
                continue

            # Up to 2 other sufficiently long sentences as extra context.
            # Only the premise and intermediate are excluded, so the
            # conclusion itself may appear here.
            related_contexts = [
                s
                for j, s in enumerate(sentences)
                if j != i and j != i + 1 and len(s) > 10
            ][:2]

            info_pairs.append(
                {
                    'premise': premise,
                    'intermediate': intermediate,
                    'conclusion': conclusion,
                    'related_contexts': related_contexts,
                }
            )

        return info_pairs

    def _generate_qa_pairs(
        self, info_pairs: List[Dict[str, Sequence[str]]]
    ) -> List[Dict[str, str]]:
        r"""Generate multi-hop question-answer pairs from information pairs.

        Args:
            info_pairs (List[Dict[str, Sequence[str]]]): List of information
                pairs extracted from text.

        Returns:
            List[Dict[str, str]]: List of generated QA pairs. Always empty
                when no ``multi_hop_agent`` is set, since generation is
                delegated entirely to the agent.
        """
        agent = self.multi_hop_agent
        if not agent:
            return []

        qa_pairs = []
        for pair in info_pairs:
            # Construct full context from the three-sentence window
            context = (
                f"{pair['premise']}. {pair['intermediate']}."
                f" {pair['conclusion']}"
            )
            response = agent.generate_multi_hop_qa(context)
            # A falsy response means the agent produced nothing for this
            # window; skip it.
            if response:
                qa_pairs.append(response.value.dict())

        return qa_pairs

    def _calculate_complexity(self, qa_pairs: List[Dict[str, Any]]) -> float:
        r"""Calculate the complexity score for a set of QA pairs.

        Each pair is scored as a weighted sum of four capped factors:
        reasoning steps (weight 0.4, capped at 3), supporting facts
        (weight 0.3, capped at 3), question length in words (weight 0.15,
        capped at 20) and answer length in words (weight 0.15, capped at 50).
        The result is the mean over all pairs.

        Args:
            qa_pairs (List[Dict[str, Any]]): List of QA pairs to calculate
                complexity for.

        Returns:
            float: Complexity score between 0.0 and 1.0; 0.0 for an empty
                list.
        """
        if not qa_pairs:
            return 0.0

        complexities = []
        for qa in qa_pairs:
            # 1. Number of reasoning steps
            reasoning_steps_count = len(qa.get('reasoning_steps', []))
            # 2. Number of supporting facts
            supporting_facts_count = len(qa.get('supporting_facts', []))
            # 3. Question length (words)
            question_length = len(qa.get('question', '').split())
            # 4. Answer length (words)
            answer_length = len(qa.get('answer', '').split())

            # Calculate complexity of a single QA pair
            qa_complexity = (
                min(reasoning_steps_count / 3, 1.0) * 0.4
                + min(supporting_facts_count / 3, 1.0) * 0.3
                + min(question_length / 20, 1.0) * 0.15
                + min(answer_length / 50, 1.0) * 0.15
            )
            complexities.append(qa_complexity)

        return sum(complexities) / len(complexities)
# [docs] -- stray Sphinx source-link marker (HTML extraction residue), not code
class DataCurator:
    r"""Manages and curates datasets of multi-hop question-answer pairs.

    This class handles dataset management tasks including quality filtering,
    complexity filtering, deduplication, and dataset sampling.

    Attributes:
        config (ProcessorConfig): Configuration for data curation parameters.
            ``complexity_threshold`` and ``dataset_size`` are read here.
        rng (random.Random): Random number generator for reproducible sampling.
    """

    def __init__(self, config: ProcessorConfig, rng: random.Random):
        r"""Initialize the DataCurator.

        Args:
            config (ProcessorConfig): Configuration for data curation.
            rng (random.Random): Random number generator for reproducibility.
        """
        self.config = config
        self.rng = rng

    def curate_dataset(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Manage and curate a dataset through multiple filtering stages.

        Stages run in order: quality filter, complexity filter,
        deduplication, then sampling down to the configured dataset size.
        The number of surviving examples is logged after each stage.

        Args:
            examples (List[Dict[str, Any]]): List of examples to curate.

        Returns:
            List[Dict[str, Any]]: Curated dataset meeting quality criteria.
        """
        logger.info("Starting dataset management...")

        # 1. Quality filtering
        quality_filtered = self._quality_filter(examples)
        logger.info(
            f"Remaining examples after quality filtering:"
            f" {len(quality_filtered)}"
        )

        # 2. Complexity filtering
        complexity_filtered = self._complexity_filter(quality_filtered)
        logger.info(
            f"Remaining examples after complexity filtering:"
            f" {len(complexity_filtered)}"
        )

        # 3. Deduplication
        deduplicated = self._remove_duplicates(complexity_filtered)
        logger.info(
            f"Remaining examples after deduplication: {len(deduplicated)}"
        )

        # 4. Sample to target size
        final_dataset = self._sample_dataset(deduplicated)
        logger.info(f"Final dataset size: {len(final_dataset)}")

        return final_dataset

    def _quality_filter(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Filter examples based on quality criteria.

        An example is kept when its QA pairs pass
        :meth:`_check_qa_quality` and its text has at least 20
        whitespace-separated words.

        Args:
            examples (List[Dict[str, Any]]): List of examples to filter.

        Returns:
            List[Dict[str, Any]]: Examples that pass quality checks.
        """
        filtered = []

        for example in examples:
            # 1. Check QA pair quality
            qa_quality = self._check_qa_quality(example.get('qa_pairs', []))

            # 2. Check text quality (at least 20 words)
            text_quality = len(example.get('text', '').split()) >= 20

            if qa_quality and text_quality:
                filtered.append(example)

        return filtered

    def _check_qa_quality(self, qa_pairs: List[Dict[str, str]]) -> bool:
        r"""Check the quality of question-answer pairs.

        Args:
            qa_pairs (List[Dict[str, str]]): List of QA pairs to check.

        Returns:
            bool: True only if the list is non-empty and every pair has a
                question of at least 10 characters, an answer of at least
                5 characters, and a question that differs from its answer.
        """
        if not qa_pairs:
            return False

        for qa in qa_pairs:
            question = qa.get('question', '')
            answer = qa.get('answer', '')

            # 1. Length check
            if len(question) < 10 or len(answer) < 5:
                return False

            # 2. QA pair duplication check
            if question == answer:
                return False

        return True

    def _complexity_filter(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Filter examples based on complexity threshold.

        Removes examples with complexity scores below the configured
        threshold. An example without a ``metadata.complexity`` value is
        treated as having complexity 0.

        Args:
            examples (List[Dict[str, Any]]): List of examples to filter.

        Returns:
            List[Dict[str, Any]]: Examples meeting complexity threshold.
        """
        threshold = self.config.complexity_threshold
        return [
            example
            for example in examples
            if example.get('metadata', {}).get('complexity', 0) >= threshold
        ]

    def _remove_duplicates(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Remove duplicate examples from the dataset.

        Two examples are duplicates when both their text and the string form
        of their QA pairs are equal. The first occurrence is kept and the
        original order is preserved.

        Args:
            examples (List[Dict[str, Any]]): List of examples to deduplicate.

        Returns:
            List[Dict[str, Any]]: Deduplicated examples.
        """
        seen = set()
        unique_examples = []

        for example in examples:
            # Key on the (text, QA-pairs) tuple itself rather than on a
            # pre-computed hash of their concatenation. The set then compares
            # keys by equality, so a hash collision between two different
            # examples cannot drop one of them, and the tuple keeps the
            # boundary between the two fields unambiguous.
            identifier = (
                example.get('text', ''),
                str(example.get('qa_pairs', [])),
            )
            if identifier in seen:
                continue
            seen.add(identifier)
            unique_examples.append(example)

        return unique_examples

    def _sample_dataset(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Sample examples to match target dataset size.

        Args:
            examples (List[Dict[str, Any]]): List of examples to sample from.

        Returns:
            List[Dict[str, Any]]: The input list itself when it already fits
                within ``config.dataset_size``; otherwise a random sample of
                exactly that size drawn with ``self.rng``.
        """
        if len(examples) <= self.config.dataset_size:
            return examples

        return self.rng.sample(examples, self.config.dataset_size)