feat: rotate board on evaluation
This commit is contained in:
parent
29ddec60f7
commit
12735a7f76
@ -56,7 +56,23 @@ class GoModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
if self.training:
|
||||
return self.net(x)
|
||||
else:
|
||||
y = self.net(x)
|
||||
batch_size = x.size(0)
|
||||
|
||||
x_rotated = torch.stack([torch.rot90(x, k=k, dims=[2, 3]) for k in range(4)], dim=1) # x_rotated: [batch_size, 4, 3, 8, 8]
|
||||
x_rotated = x_rotated.view(-1, 3, 8, 8) # [batch_size*4, 3, 8, 8]
|
||||
|
||||
with torch.no_grad():
|
||||
y_rotated = self.net(x_rotated) # [batch_size*4, 1]
|
||||
|
||||
# Reshape to get them by rotation
|
||||
y_rotated = y_rotated.view(batch_size, 4, -1) # [batch_size, 4, 1]
|
||||
y_mean = y_rotated.mean(dim=1) # [batch_size, 1]
|
||||
|
||||
return y_mean
|
||||
|
||||
|
||||
class GoDataset(Dataset):
|
||||
|
Loading…
x
Reference in New Issue
Block a user