# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from camel.storages.vectordb_storages import (
BaseVectorStorage,
VectorDBQuery,
VectorDBQueryResult,
VectorDBStatus,
VectorRecord,
)
from camel.utils import dependencies_required
# Module-level logger; `MilvusStorage.add` and `MilvusStorage.delete` emit
# debug messages describing the result of each Milvus operation through it.
logger = logging.getLogger(__name__)
class MilvusStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    Milvus, a cloud-native vector search engine.

    The detailed information about Milvus is available at:
    `Milvus <https://milvus.io/docs/overview.md/>`_

    Args:
        vector_dim (int): The dimension of storing vectors.
        url_and_api_key (Tuple[str, str]): Tuple containing
            the URL and API key for connecting to a remote Milvus instance.
            URL maps to Milvus uri concept, typically "endpoint:port".
            API key maps to Milvus token concept, for self-hosted it's
            "username:pwd", for Zilliz Cloud (fully-managed Milvus) it's API
            Key.
        collection_name (Optional[str], optional): Name for the collection in
            the Milvus. If not provided, set it to the current time with iso
            format. (default: :obj:`None`)
        **kwargs (Any): Additional keyword arguments for initializing
            `MilvusClient`.

    Raises:
        ImportError: If `pymilvus` package is not installed.
        ValueError: If a collection named ``collection_name`` already exists
            and its vector dimension differs from ``vector_dim``.
    """

    # Description attached to the "vector" field when this class creates a
    # collection. `_get_collection_info` also uses it to locate that field,
    # so creation and lookup must share this single definition.
    _VECTOR_FIELD_DESCRIPTION = 'The numerical representation of the vector'

    @dependencies_required('pymilvus')
    def __init__(
        self,
        vector_dim: int,
        url_and_api_key: Tuple[str, str],
        collection_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        from pymilvus import MilvusClient

        self._client: MilvusClient
        self._create_client(url_and_api_key, **kwargs)
        self.vector_dim = vector_dim
        self.collection_name = (
            collection_name or self._generate_collection_name()
        )
        self._check_and_create_collection()

    def _create_client(
        self,
        url_and_api_key: Tuple[str, str],
        **kwargs: Any,
    ) -> None:
        r"""Initializes the Milvus client with the provided connection details.

        Args:
            url_and_api_key (Tuple[str, str]): The URL and API key for the
                Milvus server. They are passed as ``uri`` and ``token``.
            **kwargs: Additional keyword arguments passed to the Milvus client.
        """
        from pymilvus import MilvusClient

        self._client = MilvusClient(
            uri=url_and_api_key[0],
            token=url_and_api_key[1],
            **kwargs,
        )

    def _check_and_create_collection(self) -> None:
        r"""Checks if the specified collection exists in Milvus and creates it
        if it doesn't, ensuring it matches the specified vector dimensionality.

        Raises:
            ValueError: If the existing collection's vector dimension differs
                from ``self.vector_dim``.
        """
        if not self._collection_exists(self.collection_name):
            self._create_collection(collection_name=self.collection_name)
            return

        in_dim = self._get_collection_info(self.collection_name)["vector_dim"]
        if in_dim != self.vector_dim:
            # Refuse to silently reuse a mismatched collection; the user has
            # to pick a different name or a matching dimension.
            raise ValueError(
                "Vector dimension of the existing collection "
                f'"{self.collection_name}" ({in_dim}) is different from '
                f"the given embedding dim ({self.vector_dim})."
            )

    def _create_collection(
        self,
        collection_name: str,
        **kwargs: Any,
    ) -> None:
        r"""Creates a new collection in the database.

        The collection has a VARCHAR primary key ``id``, a FLOAT_VECTOR field
        ``vector`` of size ``self.vector_dim`` and a JSON field ``payload``.
        A COSINE ``AUTOINDEX`` is then built on ``vector``.

        Args:
            collection_name (str): Name of the collection to be created.
            **kwargs (Any): Additional keyword arguments pass to create
                collection.
        """
        from pymilvus import DataType

        # Set the schema. Ids are supplied by the caller (see
        # `_validate_and_convert_vectors`), hence `auto_id=False`.
        schema = self._client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
            description='collection schema',
        )
        # max_length reference: https://milvus.io/docs/limitations.md
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            description='A unique identifier for the vector',
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            description=self._VECTOR_FIELD_DESCRIPTION,
            dim=self.vector_dim,
        )
        schema.add_field(
            field_name="payload",
            datatype=DataType.JSON,
            description=(
                'Any additional metadata or information related '
                'to the vector'
            ),
        )

        # Create the collection
        self._client.create_collection(
            collection_name=collection_name,
            schema=schema,
            **kwargs,
        )

        # Set the index of the parameters
        index_params = self._client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="AUTOINDEX",
            index_name="vector_index",
        )
        self._client.create_index(
            collection_name=collection_name, index_params=index_params
        )

    def _delete_collection(
        self,
        collection_name: str,
    ) -> None:
        r"""Deletes an existing collection from the database.

        Args:
            collection_name (str): Name of the collection to be deleted.
        """
        self._client.drop_collection(collection_name=collection_name)

    def _collection_exists(self, collection_name: str) -> bool:
        r"""Checks whether a collection with the specified name exists in the
        database.

        Args:
            collection_name (str): The name of the collection to check.

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        return self._client.has_collection(collection_name)

    def _generate_collection_name(self) -> str:
        r"""Generates a unique name for a new collection based on the current
        timestamp. Milvus collection names can only contain alphanumeric
        characters and underscores.

        Returns:
            str: A unique, valid collection name.
        """
        timestamp = datetime.now().isoformat()
        # Replace ISO separators such as '-', ':' and '.' with underscores.
        transformed_name = re.sub(r'[^a-zA-Z0-9_]', '_', timestamp)
        # Prefix with letters so the name never starts with a digit.
        valid_name = "Time" + transformed_name
        return valid_name

    @classmethod
    def _extract_vector_dim(
        cls, fields: List[Dict[str, Any]]
    ) -> Optional[int]:
        r"""Finds the vector dimension among the described collection fields.

        Fields named ``"vector"`` or carrying this class's vector-field
        description are checked first. Otherwise the first field exposing a
        ``dim`` parameter is used.

        Args:
            fields (List[Dict[str, Any]]): The ``fields`` entry returned by
                `MilvusClient.describe_collection`.

        Returns:
            Optional[int]: The dimension, or :obj:`None` if no field has one.
        """

        def is_preferred(field: Dict[str, Any]) -> bool:
            return (
                field.get('name') == 'vector'
                or field.get('description') == cls._VECTOR_FIELD_DESCRIPTION
            )

        # `sorted` is stable, so preferred fields come first while the
        # original order is kept within each group.
        for field in sorted(fields, key=lambda f: not is_preferred(f)):
            # Fields without type parameters may have no 'params' key.
            dim = (field.get('params') or {}).get('dim')
            if dim is not None:
                return int(dim)
        return None

    def _get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        r"""Retrieves details of an existing collection.

        Args:
            collection_name (str): Name of the collection to be checked.

        Returns:
            Dict[str, Any]: A dictionary with the collection ``id``, the
                number of stored vectors (``vector_count``) and the vector
                dimension (``vector_dim``, :obj:`None` if it cannot be found).
        """
        vector_count = self._client.get_collection_stats(collection_name)[
            'row_count'
        ]
        collection_info = self._client.describe_collection(collection_name)
        return {
            "id": collection_info['collection_id'],
            "vector_count": vector_count,
            "vector_dim": self._extract_vector_dim(collection_info['fields']),
        }

    def _validate_and_convert_vectors(
        self, records: List[VectorRecord]
    ) -> List[dict]:
        r"""Validates and converts VectorRecord instances to the format
        expected by Milvus.

        Args:
            records (List[VectorRecord]): List of vector records to validate
                and convert.

        Returns:
            List[dict]: A list of dictionaries formatted for Milvus insertion.
                A missing payload is stored as an empty string.
        """
        return [
            {
                "id": record.id,
                "payload": (
                    record.payload if record.payload is not None else ''
                ),
                "vector": record.vector,
            }
            for record in records
        ]

    def add(
        self,
        records: List[VectorRecord],
        **kwargs: Any,
    ) -> None:
        r"""Adds a list of vectors to the specified collection.

        Args:
            records (List[VectorRecord]): List of vectors to be added.
            **kwargs (Any): Additional keyword arguments pass to insert.

        Raises:
            RuntimeError: If there was an error in the addition process.
        """
        validated_records = self._validate_and_convert_vectors(records)

        op_info = self._client.insert(
            collection_name=self.collection_name,
            data=validated_records,
            **kwargs,
        )
        logger.debug("Successfully added vectors in Milvus: %s", op_info)

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the
        storage. If unsure of ids you can first query the collection to grab
        the corresponding data.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to be
                deleted.
            **kwargs (Any): Additional keyword arguments passed to delete.

        Raises:
            RuntimeError: If there is an error during the deletion process.
        """
        op_info = self._client.delete(
            collection_name=self.collection_name, pks=ids, **kwargs
        )
        logger.debug("Successfully deleted vectors in Milvus: %s", op_info)

    def status(self) -> VectorDBStatus:
        r"""Retrieves the current status of the Milvus collection. This method
        provides information about the collection, including its vector
        dimensionality and the total number of vectors stored.

        Returns:
            VectorDBStatus: An object containing information about the
                collection's status.
        """
        status = self._get_collection_info(self.collection_name)
        return VectorDBStatus(
            vector_dim=status["vector_dim"],
            vector_count=status["vector_count"],
        )

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the provided
        query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments passed to search.

        Returns:
            List[VectorDBQueryResult]: Up to ``query.top_k`` vectors retrieved
                from the storage based on similarity to the query vector.
                The list is empty if nothing matches.
        """
        search_result = self._client.search(
            collection_name=self.collection_name,
            data=[query.query_vector],
            limit=query.top_k,
            output_fields=['vector', 'payload'],
            **kwargs,
        )

        # `search` returns one list of hits per query vector in `data`.
        # Every hit of every list is collected, rather than only the first
        # hit of each list. That way all `top_k` matches are returned, and
        # an empty hit list yields no results instead of an IndexError.
        query_results = []
        for hits in search_result:
            for hit in hits:
                entity = hit['entity']
                query_results.append(
                    VectorDBQueryResult.create(
                        similarity=hit['distance'],
                        id=str(hit['id']),
                        payload=entity.get('payload'),
                        vector=entity.get('vector'),
                    )
                )
        return query_results

    def clear(self) -> None:
        r"""Removes all vectors from the Milvus collection. This method
        deletes the existing collection and then recreates it with the same
        schema to effectively remove all stored vectors.
        """
        self._delete_collection(self.collection_name)
        self._create_collection(collection_name=self.collection_name)

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        self._client.load_collection(self.collection_name)

    @property
    def client(self) -> Any:
        r"""Provides direct access to the Milvus client. This property allows
        for direct interactions with the Milvus client for operations that are
        not covered by the `MilvusStorage` class.

        Returns:
            Any: The Milvus client instance.
        """
        return self._client