Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import h5py | |
| import numpy as np | |
| from functools import partial | |
| from utils.gen_utils import map_nlist, vround | |
| import regex as re | |
| from spacyface.simple_spacy_token import SimpleSpacyToken | |
| from data_processing.sentence_data_wrapper import SentenceH5Data, TokenH5Data | |
| from utils.f import ifnone | |
# Number of decimal places each sentence index occupies in its HDF5 group key.
ZERO_BUFFER = 12
# Format template that zero-pads an integer out to ZERO_BUFFER digits,
# matching the group names written by the data/processing module.
main_key = "{{:0{}}}".format(ZERO_BUFFER)

def to_idx(idx: int) -> str:
    """Return the zero-padded HDF5 group key for sentence index *idx*."""
    return main_key.format(idx)
def zip_len_check(*iters):
    """Zip iterables with a check that they are all the same length.

    Raises:
        ValueError: if fewer than 2 iterables are given, or if any
            iterable's length differs from the first's.
    """
    if len(iters) < 2:
        raise ValueError(f"Expected at least 2 iterables to combine. Got {len(iters)} iterables")

    expected = len(iters[0])
    # Comparing against the first length; iters[0] itself trivially matches.
    for seq in iters[1:]:
        actual = len(seq)
        if actual != expected:
            raise ValueError(f"Expected all iterations to have len {expected} but found {actual}")

    return zip(*iters)
class CorpusDataWrapper:
    """A wrapper for both the token embeddings and the head context.

    This class allows access into an HDF5 file designed according to the
    data/processing module's contents as if it were an in-memory dictionary.
    Indexing by int or slice yields `SentenceH5Data` objects; iterating
    yields every sentence in index order.
    """

    def __init__(self, fname, name=None):
        """Open an hdf5 file of the format designed and provide easy access to its contents.

        Args:
            fname: Path to the HDF5 file (opened read-only).
            name: Optional display name for ``repr``; defaults to "CorpusData".
        """
        # For iterating through the dataset
        self.__curr = 0
        self.__name = ifnone(name, "CorpusData")
        self.fname = fname
        self.data = h5py.File(fname, 'r')

        main_keys = self.data.keys()
        self.__len = len(main_keys)
        assert self.__len > 0, "Cannot process an empty file"

        # Probe the first sentence to discover the embedding geometry.
        embeds = self[0].embeddings
        self.embedding_dim = embeds.shape[-1]
        self.n_layers = embeds.shape[0] - 1  # 1 was added for the input layer
        self.refmap, self.total_vectors = self._init_vector_map()

    def __del__(self):
        # Best-effort close of the underlying HDF5 handle.
        try:
            self.data.close()
        # If run as a script, won't be able to close because of an import error
        except ImportError:
            pass
        except AttributeError:
            # __init__ failed before self.data was assigned.
            print(f"Never successfully loaded {self.fname}")

    def __iter__(self):
        return self

    def __len__(self):
        return self.__len

    def __next__(self):
        # Yield sentences in index order; reset the cursor on exhaustion so
        # the wrapper can be iterated again from the start.
        if self.__curr >= self.__len:
            self.__curr = 0
            raise StopIteration
        out = self[self.__curr]
        self.__curr += 1
        return out

    def __getitem__(self, idx):
        """Index into the embeddings by int or slice.

        Returns:
            A single ``SentenceH5Data`` for an int index, or a list of them
            for a slice.

        Raises:
            NotImplementedError: for any other index type.
        """
        if isinstance(idx, slice):
            start = idx.start or 0
            step = idx.step or 1
            # BUG FIX: the stop default was previously (len - 1), which
            # silently dropped the final sentence from e.g. corpus[:].
            # An explicit `is None` check also handles stop=0 correctly.
            stop = self.__len if idx.stop is None else idx.stop
            stop = min(stop, self.__len)
            out = []
            i = start
            while i < stop:
                out.append(self[i])
                i += step
            return out
        elif isinstance(idx, int):
            # Support negative indexing from the end.
            i = self.__len + idx if idx < 0 else idx
            key = to_idx(i)
            return SentenceH5Data(self.data[key])
        else:
            raise NotImplementedError

    def __repr__(self):
        return f"{self.__name}: containing {self.__len} items"

    def _init_vector_map(self):
        """Create main hashmap for all vectors to get their metadata.

        TODO Initialization is a little slow... Should this be stored in a separate hdf5 file?
        This doesn't change. Check for special hdf5 file and see if it exists already. If it does, open it.
        If not, create it

        Returns:
            Tuple of (refmap, n_vec): a dict mapping the global vector index
            to its ``TokenH5Data``, and the total number of token vectors.
        """
        refmap = {}
        print("Initializing reference map for embedding vector...")
        n_vec = 0
        for sentence in self:
            for i in range(len(sentence)):
                refmap[n_vec] = TokenH5Data(sentence, i)
                n_vec += 1
        return refmap, n_vec

    def extract(self, layer):
        """Extract embeddings from a particular layer from the dataset.

        For all examples.

        Args:
            layer: Layer index into each sentence's embedding stack.

        Returns:
            np.ndarray vertically stacking that layer's embeddings across
            every sentence in the corpus.
        """
        embeddings = [embeds[layer] for embeds in self]
        return np.vstack(embeddings)

    def find(self, vec_num):
        """Find a vector's metadata (by id) in the hdf5 file. Needed to find sentence info and other attr"""
        return self.refmap[vec_num]

    def find2d(self, idxs):
        """Find a vector's metadata in the hdf5 file. Needed to find sentence info and other attr.

        Args:
            idxs: Nested (2D) iterable of vector ids.

        Returns:
            Nested list of the corresponding ``TokenH5Data`` entries.
        """
        return [[self.refmap[i] for i in row] for row in idxs]