Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from .em import EM, EmptyClusterResolveError | |
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords or centroids of size
    block_size from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights of a fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D), its columns are split into
              blocks of size block_size.
            - If W is convolutional (4D), its filters are split along the
              spatial dimension.
    (2) We apply the standard EM/k-means algorithm to the resulting
        reshaped matrix.

    Args:
        - W: weight matrix to quantize, either 2D of size
          (out_features x in_features) or 4D of size
          (out_channels x in_channels x k_h x k_w)
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives: for cluster reassignment when an empty cluster
          is found
        - verbose: print information after each iteration

    Remarks:
        - block_size must divide in_features (2D case) or
          in_channels * k_h * k_w (4D case); otherwise an AssertionError
          is raised at construction time.
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        # block_size has to be set before _reshape, which reads it.
        self.block_size = block_size
        W_reshaped = self._reshape(W)
        super(PQ, self).__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W as explained in step (1).

        Records the layer geometry on self (out_features / in_features for
        a 2D weight; out_channels / in_channels / k_h / k_w for a 4D
        weight) so that decode() can restore the original layout.

        Returns:
            A tensor of size (block_size, n_blocks * n_rows), where n_rows
            is the first dimension of W and n_blocks is the number of
            blocks each row is split into.

        Raises:
            AssertionError: if block_size does not divide the row length.
            NotImplementedError: if W is neither 2D nor 4D.
        """
        size = W.size()
        n_dims = len(size)

        # fully connected: by convention the weight has size
        # out_features x in_features
        if n_dims == 2:
            self.out_features, self.in_features = size
            n_rows = self.out_features
            row_length = self.in_features
            error_message = "Linear: in_features must be a multiple of block_size"
        # convolutional: we reshape along the spatial dimension
        elif n_dims == 4:
            self.out_channels, self.in_channels, self.k_h, self.k_w = size
            n_rows = self.out_channels
            row_length = self.in_channels * self.k_h * self.k_w
            error_message = (
                "Conv2d: in_channels * k_h * k_w must be a multiple of block_size"
            )
        # not implemented
        else:
            raise NotImplementedError(size)

        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        if row_length % self.block_size != 0:
            raise AssertionError(error_message)

        # (n_rows, n_blocks, block_size) -> (block_size, n_blocks, n_rows)
        # -> (block_size, n_blocks * n_rows)
        return (
            W.reshape(n_rows, -1, self.block_size)
            .permute(2, 1, 0)
            .flatten(1, 2)
        )

    def encode(self):
        """
        Initializes the centroids, then performs up to self.n_iter EM steps.

        Iteration stops early (without raising) as soon as a step raises
        EmptyClusterResolveError; the state reached so far is kept.
        """
        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                break

    def decode(self):
        """
        Returns the decoded (reconstructed) full weight matrix, laid out
        like the W given at construction. Must be called after the encode
        function.
        """
        # One centroid per block, looked up through its assignment.
        # NOTE(review): centroids / assignments are maintained by the EM
        # base class; presumably this gives (n_blocks * n_rows, block_size)
        # -- confirm against EM.
        blocks = self.centroids[self.assignments]

        # fully connected case (k_h is only recorded for 4D weights)
        if "k_h" not in self.__dict__:
            return (
                blocks.reshape(-1, self.out_features, self.block_size)
                .permute(1, 0, 2)
                .flatten(1, 2)
            )

        # convolutional case
        return (
            blocks.reshape(-1, self.out_channels, self.block_size)
            .permute(1, 0, 2)
            .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
        )