feat: create result file

This commit is contained in:
Martin Eyben 2025-05-16 10:51:35 +02:00
parent 6d0b682053
commit d98a817d51

View File

@ -1,5 +1,7 @@
#!python #!python
import time import time
import gzip
import json
from os import path from os import path
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -152,8 +154,6 @@ def test(model, device, test_loader, floss, eps=0.1):
# Import du fichier d'exemples # Import du fichier d'exemples
def get_raw_data_go(): def get_raw_data_go():
"""Returns the set of samples from the local file or download it if it does not exists""" """Returns the set of samples from the local file or download it if it does not exists"""
import gzip, os.path
import json
raw_samples_file = "samples-8x8.json.gz" raw_samples_file = "samples-8x8.json.gz"
@ -302,6 +302,14 @@ if __name__ == "__main__":
checkpoint = torch.load(args.m, weights_only=True) checkpoint = torch.load(args.m, weights_only=True)
mymodel.load_state_dict(checkpoint["model_state_dict"]) mymodel.load_state_dict(checkpoint["model_state_dict"])
if args.R is not None:
if args.m is None:
print("You need to specify weights for a model")
exit(1)
with gzip.open(args.R) as fz:
data = json.loads(fz.read().decode("utf-8"))
create_result_file(data)
if args.t is not None and args.T is not None: if args.t is not None and args.T is not None:
trainset = torch.load(args.t) trainset = torch.load(args.t)