feat: create result file

This commit is contained in:
Martin Eyben 2025-05-16 10:51:35 +02:00
parent 6d0b682053
commit d98a817d51

View File

@ -1,5 +1,7 @@
#!python
import time
import gzip
import json
from os import path
import os
import matplotlib.pyplot as plt
@ -152,8 +154,6 @@ def test(model, device, test_loader, floss, eps=0.1):
# Import du fichier d'exemples
def get_raw_data_go():
"""Returns the set of samples from the local file or download it if it does not exists"""
import gzip, os.path
import json
raw_samples_file = "samples-8x8.json.gz"
@ -302,6 +302,14 @@ if __name__ == "__main__":
checkpoint = torch.load(args.m, weights_only=True)
mymodel.load_state_dict(checkpoint["model_state_dict"])
if args.R is not None:
if args.m is None:
print("You need to specify weights for a model")
exit(1)
with gzip.open(args.R) as fz:
data = json.loads(fz.read().decode("utf-8"))
create_result_file(data)
if args.t is not None and args.T is not None:
trainset = torch.load(args.t)