diff --git a/go_ia/main.py b/go_ia/main.py index c5ba089..146bee2 100755 --- a/go_ia/main.py +++ b/go_ia/main.py @@ -1,6 +1,7 @@ #!python import time from os import path +import os import matplotlib.pyplot as plt import argparse import torch @@ -279,6 +280,8 @@ if __name__ == "__main__": mymodel = GoModel().to(device) output_dir = args.O + if (not os.path.exists(output_dir)): + os.mkdir(output_dir) if args.m is not None: checkpoint = torch.load(args.m, weights_only=True)