feat: rotate board on evaluation

This commit is contained in:
Martin Eyben 2025-05-16 11:26:52 +02:00
parent 29ddec60f7
commit 12735a7f76

View File

@ -56,7 +56,23 @@ class GoModel(nn.Module):
)
def forward(self, x):
if self.training:
return self.net(x)
else:
y = self.net(x)
batch_size = x.size(0)
x_rotated = torch.stack([torch.rot90(x, k=k, dims=[2, 3]) for k in range(4)], dim=1) # x_rotated: [batch_size, 4, 3, 8, 8]
x_rotated = x_rotated.view(-1, 3, 8, 8) # [batch_size*4, 3, 8, 8]
with torch.no_grad():
y_rotated = self.net(x_rotated) # [batch_size*4, 1]
# Reshape to get them by rotation
y_rotated = y_rotated.view(batch_size, 4, -1) # [batch_size, 4, 1]
y_mean = y_rotated.mean(dim=1) # [batch_size, 1]
return y_mean
class GoDataset(Dataset):