diff --git a/go_ia/main.py b/go_ia/main.py index 42e559a..d78b84f 100755 --- a/go_ia/main.py +++ b/go_ia/main.py @@ -56,7 +56,23 @@ class GoModel(nn.Module): ) def forward(self, x): - return self.net(x) + if self.training: + return self.net(x) + else: + y = self.net(x) + batch_size = x.size(0) + + x_rotated = torch.stack([torch.rot90(x, k=k, dims=[2, 3]) for k in range(4)], dim=1) # x_rotated: [batch_size, 4, 3, 8, 8] + x_rotated = x_rotated.view(-1, 3, 8, 8) # [batch_size*4, 3, 8, 8] + + with torch.no_grad(): + y_rotated = self.net(x_rotated) # [batch_size*4, 1] + + # Reshape to get them by rotation + y_rotated = y_rotated.view(batch_size, 4, -1) # [batch_size, 4, 1] + y_mean = y_rotated.mean(dim=1) # [batch_size, 1] + + return y_mean class GoDataset(Dataset):