# Source code for camel.storages.vectordb_storages.oceanbase

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from sqlalchemy import JSON, Column, Integer

if TYPE_CHECKING:
    from pyobvector.client import ObVecClient

from camel.storages.vectordb_storages import (
    BaseVectorStorage,
    VectorDBQuery,
    VectorDBQueryResult,
    VectorDBStatus,
    VectorRecord,
)
from camel.utils import dependencies_required

logger = logging.getLogger(__name__)


class OceanBaseStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    OceanBase Vector Database.

    Records are stored in a table with three columns: an auto-increment
    integer ``id``, a ``embedding`` vector column and a JSON ``metadata``
    column. Record IDs that cannot be parsed as integers are kept inside
    the metadata under the reserved key ``"_id"``.

    Args:
        vector_dim (int): The dimension of storing vectors.
        table_name (str): Name for the table in OceanBase.
        uri (str): Connection URI for OceanBase (host:port).
            (default: :obj:`"127.0.0.1:2881"`)
        user (str): Username for connecting to OceanBase.
            (default: :obj:`"root@test"`)
        password (str): Password for the user. (default: :obj:`""`)
        db_name (str): Database name in OceanBase.
            (default: :obj:`"test"`)
        distance (Literal["l2", "cosine"], optional): The distance metric
            for vector comparison. Options: "l2", "cosine".
            (default: :obj:`"l2"`)
        delete_table_on_del (bool, optional): Flag to determine if the
            table should be deleted upon object destruction.
            (default: :obj:`False`)
        **kwargs (Any): Additional keyword arguments for initializing
            `ObVecClient`.

    Raises:
        ImportError: If `pyobvector` package is not installed.
    """

    @dependencies_required('pyobvector')
    def __init__(
        self,
        vector_dim: int,
        table_name: str,
        uri: str = "127.0.0.1:2881",
        user: str = "root@test",
        password: str = "",
        db_name: str = "test",
        distance: Literal["l2", "cosine"] = "l2",
        delete_table_on_del: bool = False,
        **kwargs: Any,
    ) -> None:
        from pyobvector.client import (
            ObVecClient,
        )
        from pyobvector.client.index_param import (
            IndexParam,
            IndexParams,
        )
        from pyobvector.schema import VECTOR

        self.vector_dim: int = vector_dim
        self.table_name: str = table_name
        self.distance: Literal["l2", "cosine"] = distance
        self.delete_table_on_del: bool = delete_table_on_del

        # Create client
        self._client: ObVecClient = ObVecClient(
            uri=uri, user=user, password=password, db_name=db_name, **kwargs
        )

        # Map distance to distance function in OceanBase
        self._distance_func_map: Dict[str, str] = {
            "cosine": "cosine_distance",
            "l2": "l2_distance",
        }

        # Check or create table with vector index
        if not self._client.check_table_exists(self.table_name):
            # Define table schema
            columns: List[Column] = [
                Column("id", Integer, primary_key=True, autoincrement=True),
                Column("embedding", VECTOR(vector_dim)),
                Column("metadata", JSON),
            ]

            # Create table
            self._client.create_table(
                table_name=self.table_name, columns=columns
            )

            # Create vector index
            index_params: IndexParams = IndexParams()
            index_params.add_index_param(
                IndexParam(
                    index_name="embedding_idx",
                    field_name="embedding",
                    distance=self.distance,
                    type="hnsw",
                    m=16,
                    ef_construction=256,
                )
            )
            self._client.create_vidx_with_vec_index_param(
                table_name=self.table_name, vidx_param=index_params.params[0]
            )

            logger.info(f"Created table {self.table_name} with vector index")
        else:
            logger.info(f"Using existing table {self.table_name}")

    def __del__(self):
        r"""Deletes the table if :obj:`delete_table_on_del` is set to
        :obj:`True`.
        """
        # `__init__` can fail before these attributes exist (e.g. when the
        # client constructor raises); in that case there is nothing to drop
        # and no error should be reported.
        if not getattr(self, "delete_table_on_del", False):
            return
        client = getattr(self, "_client", None)
        if client is None:
            return

        try:
            client.drop_table_if_exist(self.table_name)
            logger.info(f"Deleted table {self.table_name}")
        except Exception as e:
            # Best effort only: a finalizer must never propagate.
            logger.error(f"Failed to delete table {self.table_name}: {e}")

    def add(
        self,
        records: List[VectorRecord],
        batch_size: int = 100,
        **kwargs: Any,
    ) -> None:
        r"""Saves a list of vector records to the storage.

        All vectors are validated before the first insert, so a dimension
        mismatch never leaves a partially written set of records behind.

        Args:
            records (List[VectorRecord]): List of vector records to be saved.
            batch_size (int): Number of records to insert each batch.
                Larger batches are more efficient but use more memory.
                (default: :obj:`100`)
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the saving process.
            ValueError: If any vector dimension doesn't match vector_dim.
        """
        if not records:
            return

        # Validate vector dimensions up front, before touching the database
        for i, record in enumerate(records):
            if len(record.vector) != self.vector_dim:
                raise ValueError(
                    f"Vector at index {i} has dimension "
                    f"{len(record.vector)}, expected {self.vector_dim}"
                )

        try:
            # Convert records to OceanBase format
            data: List[Dict[str, Any]] = []
            for record in records:
                item: Dict[str, Any] = {
                    "embedding": record.vector,
                    # Copy the payload: "_id" may be added below and must
                    # not leak into the caller's record.
                    "metadata": dict(record.payload)
                    if record.payload
                    else {},
                }

                # If id is specified, use it
                if record.id:
                    try:
                        # If id is numeric, use it directly
                        item["id"] = int(record.id)
                    except ValueError:
                        # If id is not numeric, store it in payload
                        item["metadata"]["_id"] = record.id

                data.append(item)

                # Batch insert when reaching batch_size
                if len(data) >= batch_size:
                    self._client.insert(self.table_name, data=data)
                    data = []

            # Insert any remaining records
            if data:
                self._client.insert(self.table_name, data=data)

        except Exception as e:
            error_msg = f"Failed to add records to OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the
        storage.

        IDs that parse as integers are matched against the primary key; all
        other IDs are matched against the ``"_id"`` entry of the JSON
        metadata column.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to
                be deleted.
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the deletion process.
        """
        if not ids:
            return

        try:
            numeric_ids: List[int] = []
            non_numeric_ids: List[str] = []

            # Separate numeric and non-numeric IDs
            for id_val in ids:
                try:
                    numeric_ids.append(int(id_val))
                except ValueError:
                    non_numeric_ids.append(id_val)

            # Delete records with numeric IDs
            if numeric_ids:
                self._client.delete(self.table_name, ids=numeric_ids)

            # Delete records with non-numeric IDs stored in metadata
            if non_numeric_ids:
                from sqlalchemy import text

                for id_val in non_numeric_ids:
                    # The ID is passed as a bound parameter instead of being
                    # formatted into the SQL text, so quotes or SQL syntax
                    # inside an ID cannot alter the statement.
                    # NOTE(review): the JSON path "$.._id" is kept exactly
                    # as it was; confirm OceanBase accepts it (records
                    # written by `add` keep "_id" at the top level).
                    self._client.delete(
                        self.table_name,
                        where_clause=[
                            text("metadata->>'$.._id' = :id_val").bindparams(
                                id_val=id_val
                            )
                        ],
                    )
        except Exception as e:
            error_msg = f"Failed to delete records from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def status(self) -> VectorDBStatus:
        r"""Returns status of the vector database.

        Returns:
            VectorDBStatus: The vector database status.

        Raises:
            RuntimeError: If the record count cannot be retrieved.
        """
        try:
            # Get count of records
            result = self._client.perform_raw_text_sql(
                f"SELECT COUNT(*) FROM {self.table_name}"
            )
            count: int = result.fetchone()[0]

            return VectorDBStatus(
                vector_dim=self.vector_dim, vector_count=count
            )
        except Exception as e:
            error_msg = f"Failed to get status from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the
        provided query.

        Rows that cannot be converted into a result are skipped with a
        warning instead of failing the whole query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            List[VectorDBQueryResult]: A list of vectors retrieved from the
                storage based on similarity to the query vector.

        Raises:
            RuntimeError: If there is an error during the query process.
            ValueError: If the query vector dimension does not match the
                storage dimension.
        """
        from sqlalchemy import func

        # Validate query vector dimensions outside the try block so the
        # documented ValueError reaches the caller unwrapped.
        if len(query.query_vector) != self.vector_dim:
            raise ValueError(
                f"Query vector dimension {len(query.query_vector)} "
                f"does not match storage dimension {self.vector_dim}"
            )

        try:
            # Get distance function name
            distance_func_name: str = self._distance_func_map.get(
                self.distance, "l2_distance"
            )
            distance_func = getattr(func, distance_func_name)

            results = self._client.ann_search(
                table_name=self.table_name,
                vec_data=query.query_vector,
                vec_column_name="embedding",
                distance_func=distance_func,
                with_dist=True,
                topk=query.top_k,
                output_column_names=["id", "embedding", "metadata"],
            )

            # Convert results to VectorDBQueryResult format
            query_results: List[VectorDBQueryResult] = []
            for row in results:
                try:
                    query_results.append(
                        self._row_to_query_result(
                            dict(row._mapping), distance_func_name
                        )
                    )
                except Exception as e:
                    logger.warning(f"Failed to process result row: {e}")
                    continue

            return query_results
        except Exception as e:
            error_msg = f"Failed to query OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _row_to_query_result(
        self, result_dict: Dict[str, Any], distance_func_name: str
    ) -> VectorDBQueryResult:
        r"""Builds one `VectorDBQueryResult` from a result row mapping."""
        id_val: Any = str(result_dict["id"])
        vector: List[float] = self._coerce_vector(
            result_dict.get("embedding")
        )
        metadata: Dict[str, Any] = self._coerce_metadata(
            result_dict.get("metadata", {})
        )

        # The distance column is located by name: any column whose name
        # contains the distance function name is taken as the distance.
        distance_value: Optional[float] = None
        for key, value in result_dict.items():
            if distance_func_name in key:
                distance_value = float(value)
                break

        if distance_value is None:
            # If we can't find the distance, use a default value
            logger.warning(
                "Could not find distance value in query results, "
                "using default"
            )
            distance_value = 0.0

        similarity: float = self._convert_distance_to_similarity(
            distance_value
        )

        # A non-numeric record id is stored in metadata by `add`; restore it
        # and keep the reserved key out of the returned payload.
        if "_id" in metadata:
            id_val = metadata.pop("_id")

        return VectorDBQueryResult.create(
            similarity=similarity,
            vector=vector,
            id=id_val,
            payload=metadata,
        )

    def _coerce_vector(self, vector: Any) -> List[float]:
        r"""Returns the stored embedding as a list of floats, or a zero
        vector when the value cannot be interpreted.
        """
        if isinstance(vector, str):
            # If vector is a string, try to parse it
            try:
                if vector.startswith('[') and vector.endswith(']'):
                    # Remove brackets and split by commas
                    vector = [
                        float(x.strip()) for x in vector[1:-1].split(',')
                    ]
            except (ValueError, TypeError) as e:
                logger.warning(f"Failed to parse vector string: {e}")

        # Ensure we have a proper vector
        if not isinstance(vector, list) or len(vector) != self.vector_dim:
            logger.warning(f"Invalid vector format, using zeros: {vector}")
            vector = [0.0] * self.vector_dim

        return vector

    @staticmethod
    def _coerce_metadata(metadata: Any) -> Dict[str, Any]:
        r"""Returns the stored metadata as a dictionary.

        JSON strings are decoded; any value that is not a dictionary
        (including JSON that decodes to a list or scalar) is wrapped as
        ``{"value": ...}``.
        """
        if isinstance(metadata, dict):
            return metadata

        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except Exception:
                return {"value": str(metadata)}
            if isinstance(metadata, dict):
                return metadata

        return {"value": metadata}

    def _convert_distance_to_similarity(self, distance: float) -> float:
        r"""Converts distance to similarity score based on distance
        metric.

        Args:
            distance (float): Distance reported by the database. Negative
                values are treated as :obj:`0.0`.

        Returns:
            float: ``1 - distance`` floored at 0 for cosine,
                ``exp(-distance)`` for l2, and ``1 - min(1, distance)``
                for any other metric.
        """
        import math

        # Ensure distance is non-negative
        distance = max(0.0, distance)

        if self.distance == "cosine":
            # Cosine distance = 1 - cosine similarity. With a non-negative
            # distance the result cannot exceed 1, so only the lower bound
            # needs clamping.
            return max(0.0, 1.0 - distance)
        if self.distance == "l2":
            # Exponential decay function for L2 distance
            return math.exp(-distance)
        # Default normalization, result is between 0 and 1
        return 1.0 - min(1.0, distance)

    def clear(self) -> None:
        r"""Remove all vectors from the storage.

        Raises:
            RuntimeError: If the records cannot be removed.
        """
        try:
            self._client.delete(self.table_name)
            logger.info(f"Cleared all records from table {self.table_name}")
        except Exception as e:
            error_msg = f"Failed to clear records from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        # OceanBase doesn't require explicit loading
        pass

    @property
    def client(self) -> "ObVecClient":
        r"""Provides access to underlying OceanBase vector database
        client.
        """
        return self._client