feat: add test arg for dataset

This commit is contained in:
Martin Eyben 2025-05-16 10:04:16 +02:00
parent 624da7dd62
commit 15e64ffad3

View File

@ -54,7 +54,7 @@ class GoModel(nn.Module):
class GoDataset(Dataset):
def __init__(self, data, device):
def __init__(self, data, device, test=False):
def label(d, j):
if j == 0:
return d["black_wins"] / d["rollouts"]
@ -72,17 +72,26 @@ class GoDataset(Dataset):
else:
return out.flipud()
dims = [1, 2]
self.boards = torch.from_numpy(np.array([
torch.rot90(board(d, j, k), i, dims)
for d in data
for k in range(2)
for i in range(4)
for j in range(2)
])).float().to(device)
self.labels = torch.from_numpy(np.array(
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
)).float().to(device)
if test:
dims = [1, 2]
self.boards = torch.from_numpy(np.array([
board(d, 0, 0) for d in data
])).float().to(device)
self.labels = torch.from_numpy(np.array(
[label(d, 0) for d in data],
)).float().to(device)
else:
dims = [1, 2]
self.boards = torch.from_numpy(np.array([
torch.rot90(board(d, j, k), i, dims)
for d in data
for k in range(2)
for i in range(4)
for j in range(2)
])).float().to(device)
self.labels = torch.from_numpy(np.array(
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
)).float().to(device)
def __len__(self):
return len(self.boards)
@ -274,6 +283,7 @@ if __name__ == "__main__":
parser.add_argument("-d", help="path to dataset input")
parser.add_argument("-n", help="Number of epoch", type=int, default=25)
parser.add_argument("-l", help="learning rate", type=float, default=0.01)
parser.add_argument("-R", help="Create result file from a dataset")
args = parser.parse_args()
device = setup_device()
@ -301,7 +311,7 @@ if __name__ == "__main__":
testdata = rawdata[:train_size]
trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device)
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device)
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device, test=True)
batch_size = 8192
epochs = args.n