|
|
|
@ -59,7 +59,7 @@ class WaveRNN(nn.Module) :
|
|
|
|
|
# Compute all gates for coarse and fine
|
|
|
|
|
u = F.sigmoid(R_u + I_u + self.bias_u)
|
|
|
|
|
r = F.sigmoid(R_r + I_r + self.bias_r)
|
|
|
|
|
e = F.tanh(r * R_e + I_e + self.bias_e)
|
|
|
|
|
e = torch.tanh(r * R_e + I_e + self.bias_e)
|
|
|
|
|
hidden = u * prev_hidden + (1. - u) * e
|
|
|
|
|
|
|
|
|
|
# Split the hidden state
|
|
|
|
@ -118,7 +118,7 @@ class WaveRNN(nn.Module) :
|
|
|
|
|
# Compute the coarse gates
|
|
|
|
|
u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
|
|
|
|
|
r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
|
|
|
|
|
e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
|
|
|
|
|
e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
|
|
|
|
|
hidden_coarse = u * hidden_coarse + (1. - u) * e
|
|
|
|
|
|
|
|
|
|
# Compute the coarse output
|
|
|
|
@ -138,7 +138,7 @@ class WaveRNN(nn.Module) :
|
|
|
|
|
# Compute the fine gates
|
|
|
|
|
u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
|
|
|
|
|
r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
|
|
|
|
|
e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
|
|
|
|
|
e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
|
|
|
|
|
hidden_fine = u * hidden_fine + (1. - u) * e
|
|
|
|
|
|
|
|
|
|
# Compute the fine output
|
|
|
|
|