# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from sqlalchemy import JSON, Column, Integer
if TYPE_CHECKING:
from pyobvector.client import ObVecClient
from camel.storages.vectordb_storages import (
BaseVectorStorage,
VectorDBQuery,
VectorDBQueryResult,
VectorDBStatus,
VectorRecord,
)
from camel.utils import dependencies_required
logger = logging.getLogger(__name__)
[docs]
class OceanBaseStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    OceanBase Vector Database.

    Args:
        vector_dim (int): The dimension of storing vectors.
        table_name (str): Name for the table in OceanBase.
        uri (str): Connection URI for OceanBase (host:port).
            (default: :obj:`"127.0.0.1:2881"`)
        user (str): Username for connecting to OceanBase.
            (default: :obj:`"root@test"`)
        password (str): Password for the user. (default: :obj:`""`)
        db_name (str): Database name in OceanBase.
            (default: :obj:`"test"`)
        distance (Literal["l2", "cosine"], optional): The distance metric for
            vector comparison. Options: "l2", "cosine". (default: :obj:`"l2"`)
        delete_table_on_del (bool, optional): Flag to determine if the
            table should be deleted upon object destruction.
            (default: :obj:`False`)
        **kwargs (Any): Additional keyword arguments for initializing
            `ObVecClient`.

    Raises:
        ImportError: If `pyobvector` package is not installed.
    """

    @dependencies_required('pyobvector')
    def __init__(
        self,
        vector_dim: int,
        table_name: str,
        uri: str = "127.0.0.1:2881",
        user: str = "root@test",
        password: str = "",
        db_name: str = "test",
        distance: Literal["l2", "cosine"] = "l2",
        delete_table_on_del: bool = False,
        **kwargs: Any,
    ) -> None:
        # Imported lazily so the module can be imported without `pyobvector`
        # installed; the decorator above reports the missing dependency.
        from pyobvector.client import (
            ObVecClient,
        )
        from pyobvector.client.index_param import (
            IndexParam,
            IndexParams,
        )
        from pyobvector.schema import VECTOR

        self.vector_dim: int = vector_dim
        self.table_name: str = table_name
        self.distance: Literal["l2", "cosine"] = distance
        self.delete_table_on_del: bool = delete_table_on_del

        # Create client
        self._client: ObVecClient = ObVecClient(
            uri=uri, user=user, password=password, db_name=db_name, **kwargs
        )

        # Map distance to distance function in OceanBase
        self._distance_func_map: Dict[str, str] = {
            "cosine": "cosine_distance",
            "l2": "l2_distance",
        }

        # Check or create table with vector index
        if not self._client.check_table_exists(self.table_name):
            # Define table schema
            columns: List[Column] = [
                Column("id", Integer, primary_key=True, autoincrement=True),
                Column("embedding", VECTOR(vector_dim)),
                Column("metadata", JSON),
            ]

            # Create table
            self._client.create_table(
                table_name=self.table_name, columns=columns
            )

            # Create vector index
            index_params: IndexParams = IndexParams()
            index_params.add_index_param(
                IndexParam(
                    index_name="embedding_idx",
                    field_name="embedding",
                    distance=self.distance,
                    type="hnsw",
                    m=16,
                    ef_construction=256,
                )
            )
            self._client.create_vidx_with_vec_index_param(
                table_name=self.table_name, vidx_param=index_params.params[0]
            )

            logger.info(f"Created table {self.table_name} with vector index")
        else:
            logger.info(f"Using existing table {self.table_name}")

    def __del__(self):
        r"""Deletes the table if :obj:`delete_table_on_del` is set to
        :obj:`True`.
        """
        # `__init__` may have failed part-way, so neither attribute is
        # guaranteed to exist when the object is finalized.
        if not getattr(self, "delete_table_on_del", False):
            return
        client = getattr(self, "_client", None)
        if client is None:
            return

        try:
            client.drop_table_if_exist(self.table_name)
            logger.info(f"Deleted table {self.table_name}")
        except Exception as e:
            # Never let an exception escape from a finalizer.
            logger.error(f"Failed to delete table {self.table_name}: {e}")

    def add(
        self,
        records: List[VectorRecord],
        batch_size: int = 100,
        **kwargs: Any,
    ) -> None:
        r"""Saves a list of vector records to the storage.

        All vectors are validated before the first insert, so a dimension
        mismatch does not leave earlier batches written. Record payloads are
        copied before being stored; the caller's objects are not modified.

        Args:
            records (List[VectorRecord]): List of vector records to be saved.
            batch_size (int): Number of records to insert each batch.
                Larger batches are more efficient but use more memory.
                (default: :obj:`100`)
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the saving process.
            ValueError: If any vector dimension doesn't match vector_dim.
        """
        if not records:
            return

        # Validate vector dimensions up front, before touching the database.
        for i, record in enumerate(records):
            if len(record.vector) != self.vector_dim:
                raise ValueError(
                    f"Vector at index {i} has dimension "
                    f"{len(record.vector)}, expected {self.vector_dim}"
                )

        try:
            # Convert records to OceanBase format
            data: List[Dict[str, Any]] = []
            for record in records:
                # Copy the payload: `_id` may be added below and must not
                # leak back into the caller's record.
                metadata: Dict[str, Any] = (
                    dict(record.payload) if record.payload else {}
                )
                item: Dict[str, Any] = {
                    "embedding": record.vector,
                    "metadata": metadata,
                }

                # If id is specified, use it
                if record.id:
                    try:
                        # If id is numeric, use it as the primary key
                        item["id"] = int(record.id)
                    except ValueError:
                        # If id is not numeric, store it in the metadata
                        metadata["_id"] = record.id

                data.append(item)

                # Batch insert when reaching batch_size
                if len(data) >= batch_size:
                    self._client.insert(self.table_name, data=data)
                    data = []

            # Insert any remaining records
            if data:
                self._client.insert(self.table_name, data=data)
        except ValueError:
            # Propagate ValueError unchanged (kept for backward
            # compatibility with callers that catch it).
            raise
        except Exception as e:
            error_msg = f"Failed to add records to OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the storage.

        IDs that parse as integers are matched against the primary key; all
        other IDs are matched against the `_id` entry of the metadata column.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to
                be deleted.
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the deletion process.
        """
        if not ids:
            return

        try:
            numeric_ids: List[int] = []
            non_numeric_ids: List[str] = []

            # Separate numeric and non-numeric IDs
            for id_val in ids:
                try:
                    numeric_ids.append(int(id_val))
                except ValueError:
                    non_numeric_ids.append(id_val)

            # Delete records with numeric IDs
            if numeric_ids:
                self._client.delete(self.table_name, ids=numeric_ids)

            # Delete records with non-numeric IDs stored in metadata
            if non_numeric_ids:
                from sqlalchemy import text

                # Bind the id as a parameter instead of interpolating it
                # into the SQL text, so quotes or colons in an id can neither
                # break nor alter the statement.
                # NOTE(review): the JSON path `$.._id` is kept as-is; `add`
                # stores `_id` at the top level, so `$._id` may have been
                # intended -- confirm against OceanBase JSON path syntax.
                id_filter = text("metadata->>'$.._id' = :id_val")
                for id_val in non_numeric_ids:
                    self._client.delete(
                        self.table_name,
                        where_clause=[id_filter.bindparams(id_val=id_val)],
                    )
        except Exception as e:
            error_msg = f"Failed to delete records from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def status(self) -> VectorDBStatus:
        r"""Returns status of the vector database.

        Returns:
            VectorDBStatus: The vector database status.

        Raises:
            RuntimeError: If the record count cannot be retrieved.
        """
        try:
            # Get count of records
            result = self._client.perform_raw_text_sql(
                f"SELECT COUNT(*) FROM {self.table_name}"
            )
            count: int = result.fetchone()[0]
            return VectorDBStatus(
                vector_dim=self.vector_dim, vector_count=count
            )
        except Exception as e:
            error_msg = f"Failed to get status from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the
        provided query.

        Rows that cannot be converted are skipped with a warning instead of
        failing the whole query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            List[VectorDBQueryResult]: A list of vectors retrieved from the
                storage based on similarity to the query vector.

        Raises:
            RuntimeError: If there is an error during the query process.
            ValueError: If the query vector dimension does not match the
                storage dimension.
        """
        from sqlalchemy import func

        # Validate outside the `try` so the documented ValueError reaches
        # the caller instead of being wrapped in a RuntimeError.
        if len(query.query_vector) != self.vector_dim:
            raise ValueError(
                f"Query vector dimension {len(query.query_vector)} "
                f"does not match storage dimension {self.vector_dim}"
            )

        try:
            # Get distance function name
            distance_func_name: str = self._distance_func_map.get(
                self.distance, "l2_distance"
            )
            distance_func = getattr(func, distance_func_name)

            results = self._client.ann_search(
                table_name=self.table_name,
                vec_data=query.query_vector,
                vec_column_name="embedding",
                distance_func=distance_func,
                with_dist=True,
                topk=query.top_k,
                output_column_names=["id", "embedding", "metadata"],
            )

            # Convert results to VectorDBQueryResult format
            query_results: List[VectorDBQueryResult] = []
            for row in results:
                try:
                    query_results.append(
                        self._row_to_result(row, distance_func_name)
                    )
                except Exception as e:
                    logger.warning(f"Failed to process result row: {e}")
                    continue

            return query_results
        except Exception as e:
            error_msg = f"Failed to query OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _row_to_result(
        self, row: Any, distance_func_name: str
    ) -> VectorDBQueryResult:
        r"""Converts one search result row into a `VectorDBQueryResult`."""
        result_dict: Dict[str, Any] = dict(row._mapping)

        # Extract data
        id_val: str = str(result_dict["id"])

        # Handle vector - ensure it's a proper list of floats
        vector: Any = result_dict.get("embedding")
        if isinstance(vector, str):
            # If vector is a string, try to parse it
            try:
                if vector.startswith('[') and vector.endswith(']'):
                    # Remove brackets and split by commas
                    vector = [
                        float(x.strip()) for x in vector[1:-1].split(',')
                    ]
            except (ValueError, TypeError) as e:
                logger.warning(f"Failed to parse vector string: {e}")

        # Ensure we have a proper vector
        if not isinstance(vector, list) or len(vector) != self.vector_dim:
            logger.warning(f"Invalid vector format, using zeros: {vector}")
            vector = [0.0] * self.vector_dim

        # Ensure metadata is a dictionary
        metadata: Any = result_dict.get("metadata", {})
        if not isinstance(metadata, dict):
            # Convert to dict if it's not already
            try:
                if isinstance(metadata, str):
                    metadata = json.loads(metadata)
                else:
                    metadata = {"value": metadata}
            except Exception:
                metadata = {"value": str(metadata)}

        # The distance column is located by name: the first key containing
        # the distance function name is used.
        distance_value: Optional[float] = None
        for key in result_dict:
            if distance_func_name in key:
                distance_value = float(result_dict[key])
                break

        if distance_value is None:
            # If we can't find the distance, use a default value
            logger.warning(
                "Could not find distance value in query results, "
                "using default"
            )
            distance_value = 0.0

        similarity: float = self._convert_distance_to_similarity(
            distance_value
        )

        # A non-numeric id is stored in the metadata by `add`; restore it
        # and strip it from the payload returned to the caller.
        if isinstance(metadata, dict) and "_id" in metadata:
            id_val = metadata.pop("_id")

        # Create query result
        return VectorDBQueryResult.create(
            similarity=similarity,
            vector=vector,
            id=id_val,
            payload=metadata,
        )

    def _convert_distance_to_similarity(self, distance: float) -> float:
        r"""Converts distance to similarity score based on distance metric.

        Args:
            distance (float): Distance reported by the search. Negative
                values are treated as :obj:`0.0`.

        Returns:
            float: A similarity score between 0 and 1.
        """
        # Ensure distance is non-negative
        distance = max(0.0, distance)

        if self.distance == "cosine":
            # Cosine distance = 1 - cosine similarity
            # Ensure similarity is between 0 and 1
            return max(0.0, min(1.0, 1.0 - distance))
        elif self.distance == "l2":
            import math

            # Exponential decay function for L2 distance
            return math.exp(-distance)
        else:
            # Default normalization, ensure result is between 0 and 1
            return max(0.0, min(1.0, 1.0 - min(1.0, distance)))

    def clear(self) -> None:
        r"""Remove all vectors from the storage.

        Raises:
            RuntimeError: If the records cannot be removed.
        """
        try:
            # Calling delete without ids or a where clause targets the
            # whole table.
            self._client.delete(self.table_name)
            logger.info(f"Cleared all records from table {self.table_name}")
        except Exception as e:
            error_msg = f"Failed to clear records from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        # OceanBase doesn't require explicit loading
        pass

    @property
    def client(self) -> "ObVecClient":
        r"""Provides access to underlying OceanBase vector database client."""
        return self._client