diff --git a/go_ia/main.py b/go_ia/main.py index 15f4f92..fdab1ed 100755 --- a/go_ia/main.py +++ b/go_ia/main.py @@ -1,5 +1,7 @@ #!python import time +import gzip +import json from os import path import os import matplotlib.pyplot as plt @@ -152,8 +154,6 @@ def test(model, device, test_loader, floss, eps=0.1): # Import du fichier d'exemples def get_raw_data_go(): """Returns the set of samples from the local file or download it if it does not exists""" - import gzip, os.path - import json raw_samples_file = "samples-8x8.json.gz" @@ -302,6 +302,14 @@ if __name__ == "__main__": checkpoint = torch.load(args.m, weights_only=True) mymodel.load_state_dict(checkpoint["model_state_dict"]) + if args.R is not None: + if args.m is None: + print("You need to specify weights for a model") + exit(1) + + with gzip.open(args.R) as fz: + data = json.loads(fz.read().decode("utf-8")) + create_result_file(data) if args.t is not None and args.T is not None: trainset = torch.load(args.t)