Spaces:
Running
Running
| import uuid | |
| import weaviate | |
| from weaviate import Client | |
| from weaviate.embedded import EmbeddedOptions | |
| from weaviate.util import generate_uuid5 | |
| from autogpt.config import Config | |
| from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding | |
def default_schema(weaviate_index):
    """Build the Weaviate class definition used to store memory entries.

    Args:
        weaviate_index: the class name to put in the definition.

    Returns:
        A new dict with keys ``"class"`` and ``"properties"``. The
        properties list holds a single ``raw_text`` property of data type
        ``text``.
    """
    raw_text_property = {
        "name": "raw_text",
        "dataType": ["text"],
        "description": "original text for the embedding",
    }
    schema = {"class": weaviate_index}
    schema["properties"] = [raw_text_property]
    return schema
class WeaviateMemory(MemoryProviderSingleton):
    """Memory provider that stores text snippets and their embeddings in Weaviate.

    Each stored snippet is one object of the class named by
    ``cfg.memory_index`` (first character capitalised). The object keeps the
    original text in its ``raw_text`` property, and its vector comes from
    ``get_ada_embedding``.
    """

    def __init__(self, cfg):
        """Create the Weaviate client and make sure the memory class exists.

        Args:
            cfg: configuration object. This method reads
                ``weaviate_protocol``, ``weaviate_host``, ``weaviate_port``,
                ``use_weaviate_embedded``, ``weaviate_embedded_path`` and
                ``memory_index``, plus the credential fields read by
                ``_build_auth_credentials``.
        """
        auth_credentials = self._build_auth_credentials(cfg)
        url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"

        if cfg.use_weaviate_embedded:
            # On this path the client is built from EmbeddedOptions only;
            # auth_credentials is not passed.
            self.client = Client(
                embedded_options=EmbeddedOptions(
                    hostname=cfg.weaviate_host,
                    port=int(cfg.weaviate_port),
                    persistence_data_path=cfg.weaviate_embedded_path,
                )
            )
            print(
                f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
            )
        else:
            self.client = Client(url, auth_client_secret=auth_credentials)

        self.index = WeaviateMemory.format_classname(cfg.memory_index)
        self._create_schema()

    @staticmethod
    def format_classname(index):
        """Return *index* with its first character capitalised.

        Weaviate uses capitalised class names. This mirrors the formatting the
        Python client applies to index names before it creates the class, so
        ``self.index`` matches the stored class name. The rest of the string
        is left unchanged. An empty *index* raises ``IndexError``.
        """
        # For a one-character string this equals index.capitalize(), so no
        # special case is needed.
        return index[0].capitalize() + index[1:]

    def _create_schema(self):
        """Create this memory's class with the default schema if it is not present."""
        schema = default_schema(self.index)
        if not self.client.schema.contains(schema):
            self.client.schema.create_class(schema)

    def _build_auth_credentials(self, cfg):
        """Return Weaviate auth credentials derived from *cfg*, or ``None``.

        Username and password take precedence over an API key. If neither is
        configured, ``None`` is returned.
        """
        if cfg.weaviate_username and cfg.weaviate_password:
            return weaviate.AuthClientPassword(
                cfg.weaviate_username, cfg.weaviate_password
            )
        if cfg.weaviate_api_key:
            return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
        return None

    def add(self, data):
        """Store *data* together with its embedding.

        The object UUID comes from ``generate_uuid5(data, self.index)``, so
        re-adding the same text presumably targets the same object. Verify
        this against the weaviate client documentation.

        Returns:
            A human-readable message that includes the UUID and the text.
        """
        vector = get_ada_embedding(data)
        doc_uuid = generate_uuid5(data, self.index)
        data_object = {"raw_text": data}

        with self.client.batch as batch:
            batch.add_data_object(
                uuid=doc_uuid,
                data_object=data_object,
                class_name=self.index,
                vector=vector,
            )

        return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"

    def get(self, data):
        """Return a list holding at most the single most relevant text for *data*."""
        return self.get_relevant(data, 1)

    def clear(self):
        """Remove every object stored under this memory's class.

        Only this instance's class is dropped and re-created. Other classes in
        the same Weaviate instance are left untouched.

        Returns:
            The string ``"Obliterated"``.
        """
        existing_classes = {
            cls.get("class") for cls in self.client.schema.get().get("classes", [])
        }
        # There is no direct way here to empty a class in place, so drop it
        # and re-create it with the default schema. The previous
        # schema.delete_all() wiped every class in the instance, including
        # those belonging to other memory indices. The existence check avoids
        # deleting a class that is already gone.
        if self.index in existing_classes:
            self.client.schema.delete_class(self.index)
        self._create_schema()

        return "Obliterated"

    def get_relevant(self, data, num_relevant=5):
        """Return up to *num_relevant* stored texts closest to *data*.

        Args:
            data: query text, embedded with ``get_ada_embedding``.
            num_relevant: maximum number of results (passed to ``with_limit``).

        Returns:
            A list of ``raw_text`` strings. Only matches with certainty of at
            least 0.7 are requested. The list is empty when nothing matches or
            when the query or result parsing fails.
        """
        query_embedding = get_ada_embedding(data)
        try:
            results = (
                self.client.query.get(self.index, ["raw_text"])
                .with_near_vector({"vector": query_embedding, "certainty": 0.7})
                .with_limit(num_relevant)
                .do()
            )
            # Kept inside the try on purpose: a response without the expected
            # keys, a null class entry, or an item missing "raw_text" all
            # degrade to "no memories" instead of raising.
            return [
                str(item["raw_text"]) for item in results["data"]["Get"][self.index]
            ]
        except Exception as err:
            print(f"Unexpected error {err=}, {type(err)=}")
            return []

    def get_stats(self):
        """Return the aggregate ``meta`` entry (object count) for this class.

        Returns ``{}`` when the aggregate query yields no rows for the class.
        """
        result = self.client.query.aggregate(self.index).with_meta_count().do()
        class_data = result["data"]["Aggregate"][self.index]

        return class_data[0]["meta"] if class_data else {}