diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 5c3fce6..534b0fa 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -127,7 +127,7 @@ class CBHG(nn.Module):
         # Although we `_flatten_parameters()` on init, when using DataParallel
         # the model gets replicated, making it no longer guaranteed that the
         # weights are contiguous in GPU memory. Hence, we must call it again
-        self._flatten_parameters()
+        self.rnn.flatten_parameters()
 
         # Save these for later
         residual = x
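
Calling `flatten_parameters()` on the underlying `nn.GRU` is the standard way to re-compact cuDNN weight buffers after DataParallel replication, and it avoids relying on a module-level helper. A minimal sketch of the pattern, with illustrative names (`RNNBlock` is not from this file):

```python
import torch
import torch.nn as nn

class RNNBlock(nn.Module):
    def __init__(self, in_dims, hidden_dims):
        super().__init__()
        self.rnn = nn.GRU(in_dims, hidden_dims, batch_first=True, bidirectional=True)

    def forward(self, x):
        # After DataParallel replication the flattened weight buffer may be stale,
        # so re-flatten on every forward pass before running the GRU.
        self.rnn.flatten_parameters()
        out, _ = self.rnn(x)
        return out
```
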
@@ -214,7 +214,7 @@ class LSA(nn.Module):
         self.attention = None
 
     def init_attention(self, encoder_seq_proj):
-        device = next(self.parameters()).device  # use same device as parameters
+        device = encoder_seq_proj.device  # use same device as the input tensor
         b, t, c = encoder_seq_proj.size()
         self.cumulative = torch.zeros(b, t, device=device)
         self.attention = torch.zeros(b, t, device=device)
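
Taking the device from the incoming tensor (`encoder_seq_proj.device`) rather than from `next(self.parameters()).device` keeps new buffers co-located with the data, and it still works when the module is replicated by DataParallel or holds no parameters of its own. A hedged sketch of the same pattern, with made-up names:

```python
import torch
import torch.nn as nn

class Scratchpad(nn.Module):
    """Illustrative only: allocates per-batch buffers on the input's device."""
    def forward(self, x):
        b, t, _ = x.size()
        # x.device follows the data, so the zeros land on the same GPU (or CPU)
        # as the batch without ever touching self.parameters().
        cumulative = torch.zeros(b, t, device=x.device)
        return cumulative
```
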
@@ -265,9 +265,8 @@ class Decoder(nn.Module):
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
         self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
 
-    def zoneout(self, prev, current, p=0.1):
-        device = next(self.parameters()).device  # Use same device as parameters
-        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
+    def zoneout(self, prev, current, device, p=0.1):
+        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
         return prev * mask + current * (1 - mask)
 
     def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
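
Zoneout keeps each hidden unit at its previous value with probability `p` and takes the freshly computed value otherwise; the new `device` argument only controls where the Bernoulli mask is allocated. A standalone sketch of the same formula with a toy usage example:

```python
import torch

def zoneout(prev, current, device, p=0.1):
    # mask == 1 with probability p: the unit keeps its previous hidden state;
    # mask == 0 otherwise: the unit takes the freshly computed state.
    mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
    return prev * mask + current * (1 - mask)

prev = torch.ones(2, 4)
current = torch.zeros(2, 4)
mixed = zoneout(prev, current, device=prev.device, p=0.5)
# each entry of `mixed` is 1.0 (kept) or 0.0 (updated), roughly half and half
```
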
@@ -275,7 +274,7 @@ class Decoder(nn.Module):
 
         # Need this for reshaping mels
         batch_size = encoder_seq.size(0)
-
+        device = encoder_seq.device
         # Unpack the hidden and cell states
         attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
         rnn1_cell, rnn2_cell = cell_states
@@ -301,7 +300,7 @@ class Decoder(nn.Module):
         # Compute first Residual RNN
         rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
         if self.training:
-            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
+            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
         else:
             rnn1_hidden = rnn1_hidden_next
         x = x + rnn1_hidden
@@ -309,7 +308,7 @@ class Decoder(nn.Module):
         # Compute second Residual RNN
         rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
         if self.training:
-            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
+            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
         else:
             rnn2_hidden = rnn2_hidden_next
         x = x + rnn2_hidden
@@ -374,7 +373,7 @@ class Tacotron(nn.Module):
         return outputs
 
     def forward(self, texts, mels, speaker_embedding):
-        device = next(self.parameters()).device  # use same device as parameters
+        device = texts.device  # use same device as the input tensor
 
         self.step += 1
         batch_size, _, steps  = mels.size()
@@ -440,7 +439,7 @@ class Tacotron(nn.Module):
 
     def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
-        device = next(self.parameters()).device  # use same device as parameters
+        device = x.device  # use same device as the input tensor
 
         batch_size, _  = x.size()
 
@@ -542,8 +541,7 @@ class Tacotron(nn.Module):
 
     def load(self, path, optimizer=None):
-        # Use device of model params as location for loaded state
-        device = next(self.parameters()).device
-        checkpoint = torch.load(str(path), map_location=device)
+        # Map the checkpoint to CPU so loading does not depend on the params' device
+        checkpoint = torch.load(str(path), map_location="cpu")
         self.load_state_dict(checkpoint["model_state"], strict=False)
 
         if "optimizer_state" in checkpoint and optimizer is not None: