Source code for camel.retrievers.hybrid_retrival

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Collection, Dict, List, Optional, Sequence, Union

import numpy as np

from camel.embeddings import BaseEmbedding
from camel.retrievers import BaseRetriever, BM25Retriever, VectorRetriever
from camel.storages import BaseVectorStorage


class HybridRetriever(BaseRetriever):
    r"""Retriever that fuses dense (vector) and sparse (BM25) retrieval.

    The same content is indexed by both an internal :class:`VectorRetriever`
    and an internal :class:`BM25Retriever`; queries combine the two result
    lists with Reciprocal Rank Fusion.
    """

    def __init__(
        self,
        embedding_model: Optional[BaseEmbedding] = None,
        vector_storage: Optional[BaseVectorStorage] = None,
    ) -> None:
        r"""Initializes the HybridRetriever with optional embedding model
        and vector storage.

        Args:
            embedding_model (Optional[BaseEmbedding]): An optional embedding
                model used by the VectorRetriever. Defaults to None.
            vector_storage (Optional[BaseVectorStorage]): An optional vector
                storage used by the VectorRetriever. Defaults to None.
        """
        # Dense retriever (embeddings) and sparse retriever (BM25); both
        # are fed identical content in process().
        self.vr = VectorRetriever(embedding_model, vector_storage)
        self.bm25 = BM25Retriever()
[docs] def process(self, content_input_path: str) -> None: r"""Processes the content input path for both vector and BM25 retrievers. Args: content_input_path (str): File path or URL of the content to be processed. Raises: ValueError: If the content_input_path is empty. """ if not content_input_path: raise ValueError("content_input_path cannot be empty.") self.content_input_path = content_input_path self.vr.process(content=self.content_input_path) self.bm25.process(content_input_path=self.content_input_path)
def _sort_rrf_scores( self, vector_retriever_results: List[Dict[str, Any]], bm25_retriever_results: List[Dict[str, Any]], top_k: int, vector_weight: float, bm25_weight: float, rank_smoothing_factor: float, ) -> List[Dict[str, Union[str, float]]]: r"""Sorts and combines results from vector and BM25 retrievers using Reciprocal Rank Fusion (RRF). Args: vector_retriever_results: A list of dictionaries containing the results from the vector retriever, where each dictionary contains a 'text' entry. bm25_retriever_results: A list of dictionaries containing the results from the BM25 retriever, where each dictionary contains a 'text' entry. top_k: The number of top results to return after sorting by RRF score. vector_weight: The weight to assign to the vector retriever results in the RRF calculation. bm25_weight: The weight to assign to the BM25 retriever results in the RRF calculation. rank_smoothing_factor: A hyperparameter for the RRF calculation that helps smooth the rank positions. Returns: List[Dict[str, Union[str, float]]]: A list of dictionaries representing the sorted results. Each dictionary contains the 'text'from the retrieved items and their corresponding 'rrf_score'. Raises: ValueError: If any of the input weights are negative. 
References: https://medium.com/@devalshah1619/mathematical-intuition-behind-reciprocal-rank-fusion-rrf-explained-in-2-mins-002df0cc5e2a https://colab.research.google.com/drive/1iwVJrN96fiyycxN1pBqWlEr_4EPiGdGy#scrollTo=0qh83qGV2dY8 """ text_to_id = {} id_to_info = {} current_id = 1 # Iterate over vector_retriever_results for rank, result in enumerate(vector_retriever_results, start=1): text = result.get('text', None) # type: ignore[attr-defined] if text is None: raise KeyError("Each result must contain a 'text' key") if text not in text_to_id: text_to_id[text] = current_id id_to_info[current_id] = {'text': text, 'vector_rank': rank} current_id += 1 else: id_to_info[text_to_id[text]]['vector_rank'] = rank # Iterate over bm25_retriever_results for rank, result in enumerate(bm25_retriever_results, start=1): text = result['text'] if text not in text_to_id: text_to_id[text] = current_id id_to_info[current_id] = {'text': text, 'bm25_rank': rank} current_id += 1 else: id_to_info[text_to_id[text]].setdefault('bm25_rank', rank) vector_ranks = np.array( [ info.get('vector_rank', float('inf')) for info in id_to_info.values() ] ) bm25_ranks = np.array( [ info.get('bm25_rank', float('inf')) for info in id_to_info.values() ] ) # Calculate RRF scores vector_rrf_scores = vector_weight / ( rank_smoothing_factor + vector_ranks ) bm25_rrf_scores = bm25_weight / (rank_smoothing_factor + bm25_ranks) rrf_scores = vector_rrf_scores + bm25_rrf_scores for idx, (_, info) in enumerate(id_to_info.items()): info['rrf_score'] = rrf_scores[idx] sorted_results = sorted( id_to_info.values(), key=lambda x: x['rrf_score'], reverse=True ) return sorted_results[:top_k]
[docs] def query( self, query: str, top_k: int = 20, vector_weight: float = 0.8, bm25_weight: float = 0.2, rank_smoothing_factor: int = 60, vector_retriever_top_k: int = 50, vector_retriever_similarity_threshold: float = 0.5, bm25_retriever_top_k: int = 50, return_detailed_info: bool = False, ) -> Union[ dict[str, Sequence[Collection[str]]], dict[str, Sequence[Union[str, float]]], ]: r"""Executes a hybrid retrieval query using both vector and BM25 retrievers. Args: query (str): The search query. top_k (int): Number of top results to return (default 20). vector_weight (float): Weight for vector retriever results in RRF. bm25_weight (float): Weight for BM25 retriever results in RRF. rank_smoothing_factor (int): RRF hyperparameter for rank smoothing. vector_retriever_top_k (int): Top results from vector retriever. vector_retriever_similarity_threshold (float): Similarity threshold for vector retriever. bm25_retriever_top_k (int): Top results from BM25 retriever. return_detailed_info (bool): Return detailed info if True. Returns: Union[ dict[str, Sequence[Collection[str]]], dict[str, Sequence[Union[str, float]]] ]: By default, returns only the text information. If `return_detailed_info` is `True`, return detailed information including rrf scores. """ if top_k > max(vector_retriever_top_k, bm25_retriever_top_k): raise ValueError( "top_k needs to be less than or equal to the " "maximum value among vector_retriever_top_k and " "bm25_retriever_top_k." ) if vector_weight < 0 or bm25_weight < 0: raise ValueError( "Neither `vector_weight` nor `bm25_weight` can be negative." 
) vr_raw_results: List[Dict[str, Any]] = self.vr.query( query=query, top_k=vector_retriever_top_k, similarity_threshold=vector_retriever_similarity_threshold, ) # if the number of results is less than top_k, return all results with_score = [ info for info in vr_raw_results if 'similarity score' in info ] vector_retriever_results = sorted( with_score, key=lambda x: x['similarity score'], reverse=True ) bm25_retriever_results = self.bm25.query( query=query, top_k=bm25_retriever_top_k, ) all_retrieved_info = self._sort_rrf_scores( vector_retriever_results, bm25_retriever_results, top_k, vector_weight, bm25_weight, rank_smoothing_factor, ) retrieved_info = { "Original Query": query, "Retrieved Context": ( all_retrieved_info if return_detailed_info else [item['text'] for item in all_retrieved_info] # type: ignore[misc] ), } return retrieved_info