feat: add test arg for dataset

This commit is contained in:
Martin Eyben 2025-05-16 10:04:16 +02:00
parent 624da7dd62
commit 15e64ffad3

View File

@ -54,7 +54,7 @@ class GoModel(nn.Module):
class GoDataset(Dataset): class GoDataset(Dataset):
def __init__(self, data, device): def __init__(self, data, device, test=False):
def label(d, j): def label(d, j):
if j == 0: if j == 0:
return d["black_wins"] / d["rollouts"] return d["black_wins"] / d["rollouts"]
@ -72,6 +72,15 @@ class GoDataset(Dataset):
else: else:
return out.flipud() return out.flipud()
if test:
dims = [1, 2]
self.boards = torch.from_numpy(np.array([
board(d, 0, 0) for d in data
])).float().to(device)
self.labels = torch.from_numpy(np.array(
[label(d, 0) for d in data],
)).float().to(device)
else:
dims = [1, 2] dims = [1, 2]
self.boards = torch.from_numpy(np.array([ self.boards = torch.from_numpy(np.array([
torch.rot90(board(d, j, k), i, dims) torch.rot90(board(d, j, k), i, dims)
@ -274,6 +283,7 @@ if __name__ == "__main__":
parser.add_argument("-d", help="path to dataset input") parser.add_argument("-d", help="path to dataset input")
parser.add_argument("-n", help="Number of epoch", type=int, default=25) parser.add_argument("-n", help="Number of epoch", type=int, default=25)
parser.add_argument("-l", help="learning rate", type=float, default=0.01) parser.add_argument("-l", help="learning rate", type=float, default=0.01)
parser.add_argument("-R", help="Create result file from a dataset")
args = parser.parse_args() args = parser.parse_args()
device = setup_device() device = setup_device()
@ -301,7 +311,7 @@ if __name__ == "__main__":
testdata = rawdata[:train_size] testdata = rawdata[:train_size]
trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device) trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device)
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device) testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device, test=True)
batch_size = 8192 batch_size = 8192
epochs = args.n epochs = args.n