diff --git a/synthesizer/models/global_style_token.py b/synthesizer/models/global_style_token.py
index 79282c2..cef3009 100644
--- a/synthesizer/models/global_style_token.py
+++ b/synthesizer/models/global_style_token.py
@@ -6,15 +6,22 @@ from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
 
 
 class GlobalStyleToken(nn.Module):
-
-    def __init__(self):
+    """
+    inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
+    speaker_embedding: speaker embedding [batch_size, speaker_embedding_dim]
+    outputs: style embedding [batch_size, embedding_dim]
+    """
+    def __init__(self, speaker_embedding_dim=None):
         super().__init__()
         self.encoder = ReferenceEncoder()
-        self.stl = STL()
+        self.stl = STL(speaker_embedding_dim)
 
-    def forward(self, inputs):
+    def forward(self, inputs, speaker_embedding=None):
         enc_out = self.encoder(inputs)
+        # concat the speaker embedding onto the reference encoding, following https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
+        if speaker_embedding is not None:
+            enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
         style_embed = self.stl(enc_out)
 
         return style_embed
 
@@ -73,13 +80,15 @@ class STL(nn.Module):
     inputs --- [N, E//2]
     '''
 
-    def __init__(self):
+    def __init__(self, speaker_embedding_dim=None):
         super().__init__()
         self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
         d_q = hp.E // 2
         d_k = hp.E // hp.num_heads
         # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
+        if speaker_embedding_dim:
+            d_q += speaker_embedding_dim  # queries grow to hold the concatenated speaker embedding
         self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
 
         init.normal_(self.embed, mean=0, std=0.5)
 
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index e7c26f2..87a481a 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -338,7 +338,7 @@ class Tacotron(nn.Module):
         self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                                encoder_K, num_highways, dropout)
         self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size + gst_hp.E, decoder_dims, bias=False)
-        self.gst = GlobalStyleToken()
+        self.gst = GlobalStyleToken(speaker_embedding_size)
         self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                                dropout, speaker_embedding_size)
         self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
@@ -359,6 +359,34 @@ class Tacotron(nn.Module):
     def r(self, value):
         self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
 
+    def compute_gst(self, inputs, style_input, speaker_embedding=None):
+        """Compute the global style token and concatenate it onto inputs."""
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, gst_hp.E // 2).to(device)
+            if speaker_embedding is not None:
+                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
+
+            _GST = torch.tanh(self.gst.stl.embed)
+            gst_outputs = torch.zeros(1, 1, gst_hp.E).to(device)
+            for k_token, v_amplifier in style_input.items():  # token index -> amplifier weight
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst.stl.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            gst_outputs = torch.zeros(1, 1, gst_hp.E).to(device)
+        else:
+            gst_outputs = self.gst(style_input, speaker_embedding)  # pylint: disable=not-callable
+        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
+        return inputs
+
+    @staticmethod
+    def _concat_speaker_embedding(outputs, speaker_embeddings):
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
+        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
+        return outputs
+
     def forward(self, texts, mels, speaker_embedding):
         device = next(self.parameters()).device  # use same device as parameters
 
@@ -387,9 +415,10 @@ class Tacotron(nn.Module):
         encoder_seq = self.encoder(texts, speaker_embedding)
         # put after encoder
         if self.gst is not None:
-            style_embed = self.gst(speaker_embedding)
-            style_embed = style_embed.expand_as(encoder_seq)
-            encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            style_embed = self.gst(speaker_embedding, speaker_embedding)  # during training the speaker embedding serves as both the style input and the reference
+            # style_embed = style_embed.expand_as(encoder_seq)
+            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
@@ -454,11 +483,14 @@ class Tacotron(nn.Module):
             gst_embed = np.tile(gst_embed, (1, 8))
             scale = np.zeros(512)
             scale[:] = 0.3
-            speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
-            speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
-            style_embed = self.gst(speaker_embedding)
-            style_embed = style_embed.expand_as(encoder_seq)
-            encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
+            speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
+        else:
+            speaker_embedding_style = torch.zeros(x.shape[0], 1, self.speaker_embedding_size).to(device)
+        style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+        encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
+        # style_embed = style_embed.expand_as(encoder_seq)
+        # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
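
A note on the dimension bookkeeping above: STL's attention is now built with query_dim d_q = hp.E // 2 plus speaker_embedding_dim, so any query assembled by hand (as in the style-dict branch of compute_gst) must be extended the same way. A minimal runnable sketch, with E = 256 and a 256-dim speaker embedding assumed purely for illustration:

import torch

E, speaker_dim = 256, 256                    # assumed values, not taken from gst_hyperparameters
speaker_embedding = torch.randn(speaker_dim)

query = torch.zeros(1, 1, E // 2)            # zero query, as in compute_gst
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
print(query.shape)                           # torch.Size([1, 1, 384]) == E // 2 + speaker_dim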
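
In compute_gst, style_input may be a dict mapping a token index to an "amplifier" weight: each requested token is attended to as the sole key, and the attended outputs are summed with those weights. A self-contained sketch of that branch, with the repo's MultiHeadAttention stubbed out and all dimensions assumed:

import torch

E, num_heads, token_num = 256, 4, 10             # assumed GST hyperparameters
tokens = torch.randn(token_num, E // num_heads)  # stands in for self.gst.stl.embed

def attend(query, key):
    # Stub for STL's attention(query, key); the real module returns a [1, 1, E] style vector.
    return torch.ones(1, 1, E)

query = torch.zeros(1, 1, E // 2)
style_input = {"3": 0.3, "7": -0.1}              # token index -> amplifier
gst_outputs = torch.zeros(1, 1, E)
for k_token, v_amplifier in style_input.items():
    key = torch.tanh(tokens[int(k_token)]).view(1, 1, -1)
    gst_outputs = gst_outputs + attend(query, key) * v_amplifier
print(gst_outputs[0, 0, 0])                      # tensor(0.2000) == 0.3 - 0.1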
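
_concat_speaker_embedding replaces the old expand_as/torch.cat pair: it broadcasts a [batch, 1, E] style embedding across the encoder's time axis as a view (no copy) and concatenates it onto the [batch, T, D] encoder sequence. A runnable sketch with made-up shapes:

import torch

batch, T, D, E = 2, 5, 8, 4                        # made-up shapes
encoder_seq = torch.randn(batch, T, D)
style_embed = torch.randn(batch, 1, E)

style_expanded = style_embed.expand(batch, T, -1)  # a view; no data is copied
encoder_seq = torch.cat([encoder_seq, style_expanded], dim=-1)
print(encoder_seq.shape)                           # torch.Size([2, 5, 12])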