feat: double convolution
This commit is contained in:
parent
15e64ffad3
commit
7670ecdabf
@ -30,17 +30,21 @@ class GoModel(nn.Module):
|
|||||||
nn.Dropout(0.4),
|
nn.Dropout(0.4),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
|
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(64),
|
nn.BatchNorm2d(128),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
|
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(64),
|
nn.BatchNorm2d(128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(128),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
nn.Flatten(),
|
nn.Flatten(),
|
||||||
|
|
||||||
nn.Linear(64 * 8 * 8, 128),
|
nn.Linear(128 * 8 * 8, 128),
|
||||||
nn.BatchNorm1d(128),
|
nn.BatchNorm1d(128),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
@ -298,6 +302,15 @@ if __name__ == "__main__":
|
|||||||
checkpoint = torch.load(args.m, weights_only=True)
|
checkpoint = torch.load(args.m, weights_only=True)
|
||||||
mymodel.load_state_dict(checkpoint["model_state_dict"])
|
mymodel.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
|
if args.R is not None:
|
||||||
|
if args.m is None:
|
||||||
|
print("You need to specify weights for a model")
|
||||||
|
return
|
||||||
|
|
||||||
|
with gzip.open(args.R) as fz:
|
||||||
|
data = json.loads(fz.read().decode("utf-8"))
|
||||||
|
create_result_file(data)
|
||||||
|
|
||||||
if args.t is not None and args.T is not None:
|
if args.t is not None and args.T is not None:
|
||||||
trainset = torch.load(args.t)
|
trainset = torch.load(args.t)
|
||||||
testset = torch.load(args.T)
|
testset = torch.load(args.T)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user