diff --git a/torchvggish/vggish.py b/torchvggish/vggish.py
index de9de97..5645b46 100644
--- a/torchvggish/vggish.py
+++ b/torchvggish/vggish.py
@@ -44,41 +44,20 @@ class Postprocessor(nn.Module):
     the same PCA (with whitening) and quantization transformations."
     """
 
-    def __init__(self, pretrained, params_url, progress=True):
+    def __init__(self):
         """Constructs a postprocessor."""
         super(Postprocessor, self).__init__()
-        if pretrained:
-            self.init_params_pth_url(params_url, progress=progress)
-        else:
-            # Create empty matrix, for user's state_dict to load
-            self.pca_matrix = torch.empty(
-                (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
-                dtype=torch.float,
-            )
-            self.pca_means = torch.empty(
-                (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
-            )
-
-        self.pca_matrix = nn.Parameter(self.pca_matrix, requires_grad=False)
-        self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
-
-    def init_params_pth_url(self, pca_params_dict_url, progress=True):
-        params = hub.load_state_dict_from_url(pca_params_dict_url, progress=progress)
-        self.pca_matrix = torch.as_tensor(
-            params[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
+        # Create empty matrix, for user's state_dict to load
+        self.pca_eigen_vectors = torch.empty(
+            (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
+            dtype=torch.float,
         )
-        self.pca_means = torch.as_tensor(
-            params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
+        self.pca_means = torch.empty(
+            (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
         )
 
-        assert self.pca_matrix.shape == (
-            vggish_params.EMBEDDING_SIZE,
-            vggish_params.EMBEDDING_SIZE,
-        ), "Bad PCA matrix shape: %r" % (self.pca_matrix.shape,)
-        assert self.pca_means.shape == (
-            vggish_params.EMBEDDING_SIZE,
-            1,
-        ), "Bad PCA means shape: %r" % (self.pca_means.shape,)
+        self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False)
+        self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
 
     def postprocess(self, embeddings_batch):
         """Applies tensor postprocessing to a batch of embeddings.
@@ -105,7 +84,7 @@ def postprocess(self, embeddings_batch):
         # - Premultiply by PCA matrix of shape [output_dims, input_dims]
         #   where both are equal to embedding_size in our case.
         # - Transpose result back to [batch_size, embedding_size].
-        pca_applied = torch.mm(self.pca_matrix, (embeddings_batch.T - self.pca_means)).T
+        pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.T - self.pca_means)).T
 
         # Quantize by:
         # - clipping to [min, max] range
@@ -162,15 +141,27 @@
 
 
 class VGGish(VGG):
-    def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True):
+    def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, progress=True):
         super().__init__(make_layers())
         if pretrained:
-            state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=True)
+            state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress)
             super().load_state_dict(state_dict)
 
         self.preprocess = preprocess
         self.postprocess = postprocess
-        self.pproc = Postprocessor(urls['pca'])
+        if self.postprocess:
+            self.pproc = Postprocessor()
+            if pretrained:
+                state_dict = hub.load_state_dict_from_url(urls['pca'], progress=progress)
+                # TODO: Convert the state_dict to torch
+                state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
+                    state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
+                )
+                state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
+                    state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
+                )
+
+                self.pproc.load_state_dict(state_dict)
 
     def forward(self, x, fs=None):
         if self.preprocess:
@@ -190,4 +181,4 @@ def _preprocess(self, x, fs):
         return x
 
     def _postprocess(self, x):
-        return self.pproc.postprocess(x)
+        return self.pproc(x)
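A note on the tensor math this patch renames but does not change: postprocess transposes the batch to [embedding_size, batch_size], subtracts the column-vector PCA means (broadcast across the batch), premultiplies by the square PCA matrix, and transposes back. The standalone sketch below replays that step plus the 8-bit quantization described in the comments. It is an illustration, not code from this patch: the embedding size of 128 matches VGGish's published vggish_params.EMBEDDING_SIZE, while the [-2.0, 2.0] clip bounds and the random stand-in tensors are assumptions (the real values arrive via load_state_dict).

    import torch

    EMBEDDING_SIZE = 128                    # VGGish embedding width
    QUANTIZE_MIN, QUANTIZE_MAX = -2.0, 2.0  # assumed clip bounds for the 8-bit quantization

    # Random stand-ins for the PCA params that load_state_dict would normally supply.
    pca_eigen_vectors = torch.randn(EMBEDDING_SIZE, EMBEDDING_SIZE)
    pca_means = torch.randn(EMBEDDING_SIZE, 1)

    embeddings_batch = torch.randn(8, EMBEDDING_SIZE)  # [batch_size, embedding_size]

    # PCA step, as in Postprocessor.postprocess:
    # transpose -> subtract means (broadcast) -> premultiply -> transpose back.
    pca_applied = torch.mm(pca_eigen_vectors, (embeddings_batch.T - pca_means)).T

    # Quantize: clip to [min, max], rescale to [0, 255], truncate to uint8.
    clipped = torch.clamp(pca_applied, QUANTIZE_MIN, QUANTIZE_MAX)
    quantized = ((clipped - QUANTIZE_MIN) * (255.0 / (QUANTIZE_MAX - QUANTIZE_MIN))).to(torch.uint8)
    assert quantized.shape == (8, EMBEDDING_SIZE)

Registering pca_eigen_vectors and pca_means as nn.Parameters with requires_grad=False keeps them in the module's state_dict and moves them with the module across devices, which is what lets the constructor shrink to empty tensors filled in by load_state_dict. Note that the final hunk's self.pproc(x) goes through nn.Module.__call__, which presumes Postprocessor defines a forward that delegates to postprocess (not shown in the hunks above).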