Source code for camel.retrievers.cohere_rerank_retriever

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Dict, List, Optional

from camel.retrievers import BaseRetriever
from camel.utils import dependencies_required

DEFAULT_TOP_K_RESULTS = 1


[docs] class CohereRerankRetriever(BaseRetriever): r"""An implementation of the `BaseRetriever` using the `Cohere Re-ranking` model. Attributes: model_name (str): The model name to use for re-ranking. api_key (Optional[str]): The API key for authenticating with the Cohere service. References: https://txt.cohere.com/rerank/ """ @dependencies_required('cohere') def __init__( self, model_name: str = "rerank-multilingual-v2.0", api_key: Optional[str] = None, ) -> None: r"""Initializes an instance of the CohereRerankRetriever. This constructor sets up a client for interacting with the Cohere API using the specified model name and API key. If the API key is not provided, it attempts to retrieve it from the COHERE_API_KEY environment variable. Args: model_name (str): The name of the model to be used for re-ranking. Defaults to 'rerank-multilingual-v2.0'. api_key (Optional[str]): The API key for authenticating requests to the Cohere API. If not provided, the method will attempt to retrieve the key from the environment variable 'COHERE_API_KEY'. Raises: ImportError: If the 'cohere' package is not installed. ValueError: If the API key is neither passed as an argument nor set in the environment variable. """ import cohere try: self.api_key = api_key or os.environ["COHERE_API_KEY"] except ValueError as e: raise ValueError( "Must pass in cohere api key or specify via COHERE_API_KEY" " environment variable." ) from e self.co = cohere.Client(self.api_key) self.model_name = model_name
[docs] def query( self, query: str, retrieved_result: List[Dict[str, Any]], top_k: int = DEFAULT_TOP_K_RESULTS, ) -> List[Dict[str, Any]]: r"""Queries and compiles results using the Cohere re-ranking model. Args: query (str): Query string for information retriever. retrieved_result (List[Dict[str, Any]]): The content to be re-ranked, should be the output from `BaseRetriever` like `VectorRetriever`. top_k (int, optional): The number of top results to return during retriever. Must be a positive integer. Defaults to `DEFAULT_TOP_K_RESULTS`. Returns: List[Dict[str, Any]]: Concatenated list of the query results. """ rerank_results = self.co.rerank( query=query, documents=retrieved_result, top_n=top_k, model=self.model_name, ) formatted_results = [] for result in rerank_results.results: selected_chunk = retrieved_result[result.index] selected_chunk['similarity score'] = result.relevance_score formatted_results.append(selected_chunk) return formatted_results