feat: add test arg for dataset

This commit is contained in:
Martin Eyben 2025-05-16 10:04:16 +02:00
parent 624da7dd62
commit 15e64ffad3

View File

@ -54,7 +54,7 @@ class GoModel(nn.Module):
class GoDataset(Dataset): class GoDataset(Dataset):
def __init__(self, data, device): def __init__(self, data, device, test=False):
def label(d, j): def label(d, j):
if j == 0: if j == 0:
return d["black_wins"] / d["rollouts"] return d["black_wins"] / d["rollouts"]
@ -72,17 +72,26 @@ class GoDataset(Dataset):
else: else:
return out.flipud() return out.flipud()
dims = [1, 2] if test:
self.boards = torch.from_numpy(np.array([ dims = [1, 2]
torch.rot90(board(d, j, k), i, dims) self.boards = torch.from_numpy(np.array([
for d in data board(d, 0, 0) for d in data
for k in range(2) ])).float().to(device)
for i in range(4) self.labels = torch.from_numpy(np.array(
for j in range(2) [label(d, 0) for d in data],
])).float().to(device) )).float().to(device)
self.labels = torch.from_numpy(np.array( else:
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)], dims = [1, 2]
)).float().to(device) self.boards = torch.from_numpy(np.array([
torch.rot90(board(d, j, k), i, dims)
for d in data
for k in range(2)
for i in range(4)
for j in range(2)
])).float().to(device)
self.labels = torch.from_numpy(np.array(
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
)).float().to(device)
def __len__(self): def __len__(self):
return len(self.boards) return len(self.boards)
@ -274,6 +283,7 @@ if __name__ == "__main__":
parser.add_argument("-d", help="path to dataset input") parser.add_argument("-d", help="path to dataset input")
parser.add_argument("-n", help="Number of epoch", type=int, default=25) parser.add_argument("-n", help="Number of epoch", type=int, default=25)
parser.add_argument("-l", help="learning rate", type=float, default=0.01) parser.add_argument("-l", help="learning rate", type=float, default=0.01)
parser.add_argument("-R", help="Create result file from a dataset")
args = parser.parse_args() args = parser.parse_args()
device = setup_device() device = setup_device()
@ -301,7 +311,7 @@ if __name__ == "__main__":
testdata = rawdata[:train_size] testdata = rawdata[:train_size]
trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device) trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device)
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device) testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device, test=True)
batch_size = 8192 batch_size = 8192
epochs = args.n epochs = args.n