feat: add test arg for dataset
This commit is contained in:
parent
624da7dd62
commit
15e64ffad3
@ -54,7 +54,7 @@ class GoModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GoDataset(Dataset):
|
class GoDataset(Dataset):
|
||||||
def __init__(self, data, device):
|
def __init__(self, data, device, test=False):
|
||||||
def label(d, j):
|
def label(d, j):
|
||||||
if j == 0:
|
if j == 0:
|
||||||
return d["black_wins"] / d["rollouts"]
|
return d["black_wins"] / d["rollouts"]
|
||||||
@ -71,18 +71,27 @@ class GoDataset(Dataset):
|
|||||||
return out
|
return out
|
||||||
else:
|
else:
|
||||||
return out.flipud()
|
return out.flipud()
|
||||||
|
|
||||||
dims = [1, 2]
|
if test:
|
||||||
self.boards = torch.from_numpy(np.array([
|
dims = [1, 2]
|
||||||
torch.rot90(board(d, j, k), i, dims)
|
self.boards = torch.from_numpy(np.array([
|
||||||
for d in data
|
board(d, 0, 0) for d in data
|
||||||
for k in range(2)
|
])).float().to(device)
|
||||||
for i in range(4)
|
self.labels = torch.from_numpy(np.array(
|
||||||
for j in range(2)
|
[label(d, 0) for d in data],
|
||||||
])).float().to(device)
|
)).float().to(device)
|
||||||
self.labels = torch.from_numpy(np.array(
|
else:
|
||||||
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
|
dims = [1, 2]
|
||||||
)).float().to(device)
|
self.boards = torch.from_numpy(np.array([
|
||||||
|
torch.rot90(board(d, j, k), i, dims)
|
||||||
|
for d in data
|
||||||
|
for k in range(2)
|
||||||
|
for i in range(4)
|
||||||
|
for j in range(2)
|
||||||
|
])).float().to(device)
|
||||||
|
self.labels = torch.from_numpy(np.array(
|
||||||
|
[label(d, j) for d in data for _ in range(4) for _k in range(2) for j in range(2)],
|
||||||
|
)).float().to(device)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.boards)
|
return len(self.boards)
|
||||||
@ -274,6 +283,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("-d", help="path to dataset input")
|
parser.add_argument("-d", help="path to dataset input")
|
||||||
parser.add_argument("-n", help="Number of epoch", type=int, default=25)
|
parser.add_argument("-n", help="Number of epoch", type=int, default=25)
|
||||||
parser.add_argument("-l", help="learning rate", type=float, default=0.01)
|
parser.add_argument("-l", help="learning rate", type=float, default=0.01)
|
||||||
|
parser.add_argument("-R", help="Create result file from a dataset")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
device = setup_device()
|
device = setup_device()
|
||||||
@ -301,7 +311,7 @@ if __name__ == "__main__":
|
|||||||
testdata = rawdata[:train_size]
|
testdata = rawdata[:train_size]
|
||||||
|
|
||||||
trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device)
|
trainset = torch.load(args.t) if args.t is not None else GoDataset(traindata, device)
|
||||||
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device)
|
testset = torch.load(args.T) if args.T is not None else GoDataset(testdata, device, test=True)
|
||||||
|
|
||||||
batch_size = 8192
|
batch_size = 8192
|
||||||
epochs = args.n
|
epochs = args.n
|
||||||
|
Loading…
x
Reference in New Issue
Block a user