pull/250/head
Jose 2 years ago committed by GitHub
parent 53d74c6826
commit 974f6fe560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -168,8 +168,8 @@ class ConvGRU(nn.Module):
def forward_single_frame(self, x, h): def forward_single_frame(self, x, h):
r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1) r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
c = self.hh(torch.cat([x, r * h], dim=1)) c = self.hh(torch.cat([x, r * h], dim=1))
h = (1 - z) * h + z * c h = (1 - z) * c + z * h
return h, h return c, h
def forward_time_series(self, x, h): def forward_time_series(self, x, h):
o = [] o = []

Loading…
Cancel
Save