|
|
@ -168,8 +168,8 @@ class ConvGRU(nn.Module):
|
|
|
|
def forward_single_frame(self, x, h):
|
|
|
|
def forward_single_frame(self, x, h):
|
|
|
|
r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
|
|
|
|
r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
|
|
|
|
c = self.hh(torch.cat([x, r * h], dim=1))
|
|
|
|
c = self.hh(torch.cat([x, r * h], dim=1))
|
|
|
|
h = (1 - z) * h + z * c
|
|
|
|
h = (1 - z) * c + z * h
|
|
|
|
return h, h
|
|
|
|
return c, h
|
|
|
|
|
|
|
|
|
|
|
|
def forward_time_series(self, x, h):
|
|
|
|
def forward_time_series(self, x, h):
|
|
|
|
o = []
|
|
|
|
o = []
|
|
|
|