|
|
|
@ -470,7 +470,9 @@ class Tacotron(nn.Module):
|
|
|
|
|
# put after encoder
|
|
|
|
|
if hparams.use_gst and self.gst is not None:
|
|
|
|
|
if style_idx >= 0 and style_idx < 10:
|
|
|
|
|
query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
|
|
|
|
|
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
|
|
|
|
if device.type == 'cuda':
|
|
|
|
|
query = query.cuda()
|
|
|
|
|
gst_embed = torch.tanh(self.gst.stl.embed)
|
|
|
|
|
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
|
|
|
|
style_embed = self.gst.stl.attention(query, key)
|
|
|
|
@ -539,9 +541,9 @@ class Tacotron(nn.Module):
|
|
|
|
|
with open(path, "a") as f:
|
|
|
|
|
print(msg, file=f)
|
|
|
|
|
|
|
|
|
|
def load(self, path, optimizer=None):
|
|
|
|
|
def load(self, path, device, optimizer=None):
|
|
|
|
|
# Use device of model params as location for loaded state
|
|
|
|
|
checkpoint = torch.load(str(path))
|
|
|
|
|
checkpoint = torch.load(str(path), map_location=device)
|
|
|
|
|
self.load_state_dict(checkpoint["model_state"], strict=False)
|
|
|
|
|
|
|
|
|
|
if "optimizer_state" in checkpoint and optimizer is not None:
|
|
|
|
|