From 15e64ffad3eb1bc8a4334b3bf02f27e17f60e70e Mon Sep 17 00:00:00 2001 From: Martin Eyben Date: Fri, 16 May 2025 10:04:16 +0200 Subject: [PATCH] feat: add test arg for dataset --- go_ia/main.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/go_ia/main.py b/go_ia/main.py index 83130f2..0cb88af 100755 --- a/go_ia/main.py +++ b/go_ia/main.py @@ -54,7 +54,7 @@ class GoModel(nn.Module): class GoDataset(Dataset): - def __init__(self, data, device): + def __init__(self, data, device, test=False): def label(d, j): if j == 0: return d["black_wins"] / d["rollouts"] @@ -71,18 +71,27 @@ class GoDataset(Dataset): return out else: return out.flipud() - - dims = [1, 2] - self.boards = torch.from_numpy(np.array([ - torch.rot90(board(d, j, k), i, dims) - for d in data - for k in range(2) - for i in range(4) - for j in range(2) - ])).float().to(device) - self.labels = torch.from_numpy(np.array( - [label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)], - )).float().to(device) + + if test: + dims = [1, 2] + self.boards = torch.from_numpy(np.array([ + board(d, 0, 0) for d in data + ])).float().to(device) + self.labels = torch.from_numpy(np.array( + [label(d, 0) for d in data], + )).float().to(device) + else: + dims = [1, 2] + self.boards = torch.from_numpy(np.array([ + torch.rot90(board(d, j, k), i, dims) + for d in data + for k in range(2) + for i in range(4) + for j in range(2) + ])).float().to(device) + self.labels = torch.from_numpy(np.array( + [label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)], + )).float().to(device) def __len__(self): return len(self.boards) @@ -274,6 +283,7 @@ if __name__ == "__main__": parser.add_argument("-d", help="path to dataset input") parser.add_argument("-n", help="Number of epoch", type=int, default=25) parser.add_argument("-l", help="learning rate", type=float, default=0.01) + parser.add_argument("-R", help="Create result file from a dataset") args = parser.parse_args() device = setup_device() @@ -301,7 +311,7 @@ if __name__ == "__main__": testdata = rawdata[:train_size] trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device) - testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device) + testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device, test=True) batch_size = 8192 epochs = args.n