| """ Milvus memory storage provider.""" | |
| from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections | |
| from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding | |

class MilvusMemory(MemoryProviderSingleton):
    """Milvus memory storage provider."""

    def __init__(self, cfg) -> None:
        """Construct a Milvus memory storage connection.

        Args:
            cfg (Config): Auto-GPT global config.
        """
        # Connect to the Milvus server.
        connections.connect(address=cfg.milvus_addr)
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=1536),
            FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
        ]

        # Create the collection if it does not exist, then load it.
        self.milvus_collection = cfg.milvus_collection
        self.schema = CollectionSchema(fields, "auto-gpt memory storage")
        self.collection = Collection(self.milvus_collection, self.schema)
        # Create the vector index if it does not exist.
        if not self.collection.has_index():
            self.collection.release()
            self.collection.create_index(
                "embeddings",
                {
                    "metric_type": "IP",
                    "index_type": "HNSW",
                    "params": {"M": 8, "efConstruction": 64},
                },
                index_name="embeddings",
            )
        self.collection.load()

    def add(self, data) -> str:
        """Add an embedding of data into memory.

        Args:
            data (str): The raw text to construct embedding index.

        Returns:
            str: log.
        """
        embedding = get_ada_embedding(data)
        # Field order matches the schema; the auto-generated primary key is
        # omitted, so only embeddings and raw_text are supplied.
        result = self.collection.insert([[embedding], [data]])
        _text = (
            "Inserting data into memory at primary key: "
            f"{result.primary_keys[0]}:\n data: {data}"
        )
        return _text

    def get(self, data):
        """Return the most relevant data in memory.

        Args:
            data: The data to compare to.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """Drop the collection and rebuild it with a fresh index.

        Returns:
            str: log.
        """
        self.collection.drop()
        self.collection = Collection(self.milvus_collection, self.schema)
        self.collection.create_index(
            "embeddings",
            {
                "metric_type": "IP",
                "index_type": "HNSW",
                "params": {"M": 8, "efConstruction": 64},
            },
            index_name="embeddings",
        )
        self.collection.load()
        return "Obliviated"

    def get_relevant(self, data: str, num_relevant: int = 5):
        """Return the top-k relevant data in memory.

        Args:
            data: The data to compare to.
            num_relevant (int, optional): The max number of relevant data.
                Defaults to 5.

        Returns:
            list: The top-k relevant data.
        """
        # Embed the query and search for the most relevant stored text.
        embedding = get_ada_embedding(data)
        search_params = {
            "metric_type": "IP",
            "params": {"nprobe": 8},
        }
        result = self.collection.search(
            [embedding],
            "embeddings",
            search_params,
            num_relevant,
            output_fields=["raw_text"],
        )
        return [item.entity.value_of_field("raw_text") for item in result[0]]

    def get_stats(self) -> str:
        """
        Returns: The stats of the Milvus cache.
        """
        return f"Entities num: {self.collection.num_entities}"