@ -168,8 +168,8 @@ class ConvGRU(nn.Module):
def forward_single_frame(self, x, h):
r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
c = self.hh(torch.cat([x, r * h], dim=1))
h = (1 - z) * h + z * c
return h, h
h = (1 - z) * c + z * h
return c, h
def forward_time_series(self, x, h):
o = []