Source code for camel.toolkits.jina_reranker_toolkit

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import List, Optional, Tuple

from camel.toolkits import FunctionTool
from camel.toolkits.base import BaseToolkit
from camel.utils import MCPServer


[docs] @MCPServer() class JinaRerankerToolkit(BaseToolkit): r"""A class representing a toolkit for reranking documents using Jina Reranker. This class provides methods for reranking documents (text or images) based on their relevance to a given query using the Jina Reranker model. """ def __init__( self, timeout: Optional[float] = None, device: Optional[str] = None, ) -> None: r"""Initializes a new instance of the JinaRerankerToolkit class. Args: timeout (Optional[float]): The timeout value for API requests in seconds. If None, no timeout is applied. (default: :obj:`None`) device (Optional[str]): Device to load the model on. If None, will use CUDA if available, otherwise CPU. (default: :obj:`None`) """ import torch from transformers import AutoModel super().__init__(timeout=timeout) self.model = AutoModel.from_pretrained( 'jinaai/jina-reranker-m0', torch_dtype="auto", trust_remote_code=True, ) DEVICE = ( device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model.to(DEVICE) self.model.eval() def _sort_documents( self, documents: List[str], scores: List[float] ) -> List[Tuple[str, float]]: r"""Sort documents by their scores in descending order. Args: documents (List[str]): List of documents to sort. scores (List[float]): Corresponding scores for each document. Returns: List[Tuple[str, float]]: Sorted list of (document, score) pairs. Raises: ValueError: If documents and scores have different lengths. """ if len(documents) != len(scores): raise ValueError("Number of documents must match number of scores") doc_score_pairs = list(zip(documents, scores)) doc_score_pairs.sort(key=lambda x: x[1], reverse=True) return doc_score_pairs
[docs] def rerank_text_documents( self, query: str, documents: List[str], max_length: int = 1024, ) -> List[Tuple[str, float]]: r"""Reranks text documents based on their relevance to a text query. Args: query (str): The text query for reranking. documents (List[str]): List of text documents to be reranked. max_length (int): Maximum token length for processing. (default: :obj:`1024`) Returns: List[Tuple[str, float]]: A list of tuples containing the reranked documents and their relevance scores. """ import torch if self.model is None: raise ValueError( "Model has not been initialized or failed to initialize." ) with torch.inference_mode(): text_pairs = [[query, doc] for doc in documents] scores = self.model.compute_score( text_pairs, max_length=max_length, doc_type="text" ) return self._sort_documents(documents, scores)
[docs] def rerank_image_documents( self, query: str, documents: List[str], max_length: int = 2048, ) -> List[Tuple[str, float]]: r"""Reranks image documents based on their relevance to a text query. Args: query (str): The text query for reranking. documents (List[str]): List of image URLs or paths to be reranked. max_length (int): Maximum token length for processing. (default: :obj:`2048`) Returns: List[Tuple[str, float]]: A list of tuples containing the reranked image URLs/paths and their relevance scores. """ import torch if self.model is None: raise ValueError( "Model has not been initialized or failed to initialize." ) with torch.inference_mode(): image_pairs = [[query, doc] for doc in documents] scores = self.model.compute_score( image_pairs, max_length=max_length, doc_type="image" ) return self._sort_documents(documents, scores)
[docs] def image_query_text_documents( self, image_query: str, documents: List[str], max_length: int = 2048, ) -> List[Tuple[str, float]]: r"""Reranks text documents based on their relevance to an image query. Args: image_query (str): The image URL or path used as query. documents (List[str]): List of text documents to be reranked. max_length (int): Maximum token length for processing. (default: :obj:`2048`) Returns: List[Tuple[str, float]]: A list of tuples containing the reranked documents and their relevance scores. """ import torch if self.model is None: raise ValueError("Model has not been initialized.") with torch.inference_mode(): image_pairs = [[image_query, doc] for doc in documents] scores = self.model.compute_score( image_pairs, max_length=max_length, query_type="image", doc_type="text", ) return self._sort_documents(documents, scores)
[docs] def image_query_image_documents( self, image_query: str, documents: List[str], max_length: int = 2048, ) -> List[Tuple[str, float]]: r"""Reranks image documents based on their relevance to an image query. Args: image_query (str): The image URL or path used as query. documents (List[str]): List of image URLs or paths to be reranked. max_length (int): Maximum token length for processing. (default: :obj:`2048`) Returns: List[Tuple[str, float]]: A list of tuples containing the reranked image URLs/paths and their relevance scores. """ import torch if self.model is None: raise ValueError("Model has not been initialized.") with torch.inference_mode(): image_pairs = [[image_query, doc] for doc in documents] scores = self.model.compute_score( image_pairs, max_length=max_length, query_type="image", doc_type="image", ) return self._sort_documents(documents, scores)
[docs] def get_tools(self) -> List[FunctionTool]: r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Returns: List[FunctionTool]: A list of FunctionTool objects representing the functions in the toolkit. """ return [ FunctionTool(self.rerank_text_documents), FunctionTool(self.rerank_image_documents), FunctionTool(self.image_query_text_documents), FunctionTool(self.image_query_image_documents), ]