|
|
|
@ -471,15 +471,13 @@ class Tacotron(nn.Module):
|
|
|
|
|
# put after encoder
|
|
|
|
|
if hparams.use_gst and self.gst is not None:
|
|
|
|
|
if style_idx >= 0 and style_idx < 10:
|
|
|
|
|
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
|
|
|
|
gst_embed = np.tile(gst_embed, (1, 8))
|
|
|
|
|
scale = np.zeros(512)
|
|
|
|
|
scale[:] = 0.3
|
|
|
|
|
speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
|
|
|
|
|
speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
|
|
|
|
|
query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
|
|
|
|
|
gst_embed = torch.tanh(self.gst.stl.embed)
|
|
|
|
|
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
|
|
|
|
style_embed = self.gst.stl.attention(query, key)
|
|
|
|
|
else:
|
|
|
|
|
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
|
|
|
|
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
|
|
|
|
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
|
|
|
|
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
|
|
|
|
# style_embed = style_embed.expand_as(encoder_seq)
|
|
|
|
|
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
|
|
|
|