From 12735a7f765d3757b91f36a04cab49c30efce7d0 Mon Sep 17 00:00:00 2001 From: Martin Eyben Date: Fri, 16 May 2025 11:26:52 +0200 Subject: [PATCH] feat: rotate board on evaluation --- go_ia/main.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/go_ia/main.py b/go_ia/main.py index 42e559a..d78b84f 100755 --- a/go_ia/main.py +++ b/go_ia/main.py @@ -56,7 +56,23 @@ class GoModel(nn.Module): ) def forward(self, x): - return self.net(x) + if self.training: + return self.net(x) + else: + y = self.net(x) + batch_size = x.size(0) + + x_rotated = torch.stack([torch.rot90(x, k=k, dims=[2, 3]) for k in range(4)], dim=1) # x_rotated: [batch_size, 4, 3, 8, 8] + x_rotated = x_rotated.view(-1, 3, 8, 8) # [batch_size*4, 3, 8, 8] + + with torch.no_grad(): + y_rotated = self.net(x_rotated) # [batch_size*4, 1] + + # Reshape to get them by rotation + y_rotated = y_rotated.view(batch_size, 4, -1) # [batch_size, 4, 1] + y_mean = y_rotated.mean(dim=1) # [batch_size, 1] + + return y_mean class GoDataset(Dataset):