From 7670ecdabf7630d821065404be1c61cf1729afcb Mon Sep 17 00:00:00 2001 From: Martin Eyben Date: Fri, 16 May 2025 10:48:14 +0200 Subject: [PATCH] feat: double convolution --- go_ia/main.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/go_ia/main.py b/go_ia/main.py index 0cb88af..3b86227 100755 --- a/go_ia/main.py +++ b/go_ia/main.py @@ -30,17 +30,21 @@ class GoModel(nn.Module): nn.Dropout(0.4), torch.nn.ReLU(), - nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(64), + nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), torch.nn.ReLU(), - nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(64), + nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), + torch.nn.ReLU(), + + nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), torch.nn.ReLU(), nn.Flatten(), - nn.Linear(64 * 8 * 8, 128), + nn.Linear(128 * 8 * 8, 128), nn.BatchNorm1d(128), torch.nn.ReLU(), @@ -298,6 +302,15 @@ if __name__ == "__main__": checkpoint = torch.load(args.m, weights_only=True) mymodel.load_state_dict(checkpoint["model_state_dict"]) + if args.R is not None: + if args.m is None: + print("You need to specify weights for a model") + return + + with gzip.open(args.R) as fz: + data = json.loads(fz.read().decode("utf-8")) + create_result_file(data) + if args.t is not None and args.T is not None: trainset = torch.load(args.t) testset = torch.load(args.T)