feat: add test arg for dataset
This commit is contained in:
parent
624da7dd62
commit
15e64ffad3
@ -54,7 +54,7 @@ class GoModel(nn.Module):
|
||||
|
||||
|
||||
class GoDataset(Dataset):
|
||||
def __init__(self, data, device):
|
||||
def __init__(self, data, device, test=False):
|
||||
def label(d, j):
|
||||
if j == 0:
|
||||
return d["black_wins"] / d["rollouts"]
|
||||
@ -71,18 +71,27 @@ class GoDataset(Dataset):
|
||||
return out
|
||||
else:
|
||||
return out.flipud()
|
||||
|
||||
dims = [1, 2]
|
||||
self.boards = torch.from_numpy(np.array([
|
||||
torch.rot90(board(d, j, k), i, dims)
|
||||
for d in data
|
||||
for k in range(2)
|
||||
for i in range(4)
|
||||
for j in range(2)
|
||||
])).float().to(device)
|
||||
self.labels = torch.from_numpy(np.array(
|
||||
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
|
||||
)).float().to(device)
|
||||
|
||||
if test:
|
||||
dims = [1, 2]
|
||||
self.boards = torch.from_numpy(np.array([
|
||||
board(d, 0, 0) for d in data
|
||||
])).float().to(device)
|
||||
self.labels = torch.from_numpy(np.array(
|
||||
[label(d, 0) for d in data],
|
||||
)).float().to(device)
|
||||
else:
|
||||
dims = [1, 2]
|
||||
self.boards = torch.from_numpy(np.array([
|
||||
torch.rot90(board(d, j, k), i, dims)
|
||||
for d in data
|
||||
for k in range(2)
|
||||
for i in range(4)
|
||||
for j in range(2)
|
||||
])).float().to(device)
|
||||
self.labels = torch.from_numpy(np.array(
|
||||
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
|
||||
)).float().to(device)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.boards)
|
||||
@ -274,6 +283,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-d", help="path to dataset input")
|
||||
parser.add_argument("-n", help="Number of epoch", type=int, default=25)
|
||||
parser.add_argument("-l", help="learning rate", type=float, default=0.01)
|
||||
parser.add_argument("-R", help="Create result file from a dataset")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = setup_device()
|
||||
@ -301,7 +311,7 @@ if __name__ == "__main__":
|
||||
testdata = rawdata[:train_size]
|
||||
|
||||
trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device)
|
||||
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device)
|
||||
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device, test=True)
|
||||
|
||||
batch_size = 8192
|
||||
epochs = args.n
|
||||
|
Loading…
x
Reference in New Issue
Block a user