feat: rotate board on evaluation
This commit is contained in:
parent
29ddec60f7
commit
12735a7f76
@ -56,7 +56,23 @@ class GoModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
if self.training:
|
||||||
|
return self.net(x)
|
||||||
|
else:
|
||||||
|
y = self.net(x)
|
||||||
|
batch_size = x.size(0)
|
||||||
|
|
||||||
|
x_rotated = torch.stack([torch.rot90(x, k=k, dims=[2, 3]) for k in range(4)], dim=1) # x_rotated: [batch_size, 4, 3, 8, 8]
|
||||||
|
x_rotated = x_rotated.view(-1, 3, 8, 8) # [batch_size*4, 3, 8, 8]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
y_rotated = self.net(x_rotated) # [batch_size*4, 1]
|
||||||
|
|
||||||
|
# Reshape to get them by rotation
|
||||||
|
y_rotated = y_rotated.view(batch_size, 4, -1) # [batch_size, 4, 1]
|
||||||
|
y_mean = y_rotated.mean(dim=1) # [batch_size, 1]
|
||||||
|
|
||||||
|
return y_mean
|
||||||
|
|
||||||
|
|
||||||
class GoDataset(Dataset):
|
class GoDataset(Dataset):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user