go-ai/go_ia/main.py
2025-05-16 10:49:03 +02:00

410 lines
12 KiB
Python
Executable File

#!python
import time
from os import path
import os
import matplotlib.pyplot as plt
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
import torch.optim as optim
class GoModel(nn.Module):
    """Convolutional network that scores a board encoded as 3 planes of 8x8.

    The network is a stack of six 3x3 convolution stages (each followed by
    batch normalisation and ReLU; the third stage also applies dropout before
    its ReLU), followed by a two-layer fully connected head ending in a
    sigmoid, so the output is one value in (0, 1) per input sample.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels, dropout=None):
            # padding=1 with a 3x3 kernel keeps the 8x8 spatial size unchanged.
            stage = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
            if dropout is not None:
                stage.append(nn.Dropout(dropout))
            stage.append(nn.ReLU())
            return stage

        # (in_channels, out_channels, dropout probability or None)
        stage_specs = (
            (3, 16, None),
            (16, 32, None),
            (32, 64, 0.4),
            (64, 128, None),
            (128, 128, None),
            (128, 128, None),
        )
        layers = []
        for in_channels, out_channels, dropout in stage_specs:
            layers.extend(conv_stage(in_channels, out_channels, dropout))

        # Fully connected head: 128 feature maps of 8x8 -> 128 -> 1.
        layers.extend(
            [
                nn.Flatten(),
                nn.Linear(128 * 8 * 8, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            ]
        )
        # The module order (and therefore the state_dict keys "net.<idx>.*")
        # must stay stable so that saved checkpoints keep loading.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the batch ``x`` through the network and return its output."""
        return self.net(x)
class GoDataset(Dataset):
    """In-memory dataset of encoded boards and their win-rate labels.

    Each record of ``data`` is a mapping that provides the keys
    ``"black_stones"``, ``"white_stones"``, ``"depth"``, ``"black_wins"`` and
    ``"rollouts"``.  The label of a record is ``black_wins / rollouts``.

    Args:
        data: sequence of records as described above.
        device: device on which the stacked tensors are stored.
        test: when true, exactly one sample is produced per record (no
            augmentation).  Otherwise every record yields 16 samples:
            2 (mirrored or not) x 4 (quarter-turn rotations) x 2 (original
            colours, or colours swapped with the side-to-move flag inverted
            and the label replaced by ``1 - label``).
    """

    def __init__(self, data, device, test=False):
        def label(d, j):
            black_rate = d["black_wins"] / d["rollouts"]
            return black_rate if j == 0 else 1 - black_rate

        def board(d, j, k):
            if j == 0:
                out = stones_to_board(d["black_stones"], d["white_stones"], d["depth"] % 2 == 0)
            else:
                # Colour-swapped view of the same position.
                out = stones_to_board(d["white_stones"], d["black_stones"], d["depth"] % 2 == 1)
            if k == 0:
                return out
            # Mirror along a spatial axis.  The tensor layout is
            # (plane, row, column), so flipud() would reverse the planes
            # (swapping the stone plane with the side-to-move plane) instead
            # of mirroring the board.
            return out.flip(1)

        # Boards and labels are appended in the same loop so they can never
        # get out of step with each other.
        boards = []
        labels = []
        if test:
            for d in data:
                boards.append(board(d, 0, 0))
                labels.append(label(d, 0))
        else:
            spatial_dims = [1, 2]
            for d in data:
                for k in range(2):
                    # Build each base board once and reuse it for the 4 rotations.
                    bases = [board(d, j, k) for j in range(2)]
                    for i in range(4):
                        for j in range(2):
                            boards.append(torch.rot90(bases[j], i, spatial_dims))
                            labels.append(label(d, j))

        if boards:
            self.boards = torch.stack(boards).float().to(device)
        else:
            # torch.stack() rejects an empty list; keep an empty dataset usable.
            self.boards = torch.empty(0, dtype=torch.float32).to(device)
        self.labels = torch.tensor(labels, dtype=torch.float32).to(device)

    def __len__(self):
        return len(self.boards)

    def __getitem__(self, i):
        return self.boards[i], self.labels[i]
def train(model, device, train_loader, optimizer, floss, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        device: unused, kept for interface compatibility.  Batches are fed to
            the model as yielded by the loader, without being moved.
        train_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (used for the sample count).
        optimizer: optimizer stepping the parameters of ``model``.
        floss: loss callable ``floss(output, target)`` returning a scalar tensor.
        epoch: unused, kept for interface compatibility.

    Returns:
        ``{"loss": float}`` - the per-sample average of the batch losses
        (``0.0`` when the dataset is empty).
    """
    model.train()
    total_loss = 0.0
    for data, target in train_loader:
        optimizer.zero_grad()
        # Drop only the trailing unit dimension so that a batch holding a
        # single sample keeps its batch dimension and still matches `target`.
        output = model(data).squeeze(-1)
        loss = floss(output, target)
        loss.backward()
        optimizer.step()
        # .item() detaches the value: accumulating the loss tensor itself
        # would keep autograd history of every batch alive for the whole epoch.
        total_loss += loss.item() * len(data)  # weight by batch size to average per sample
    n_samples = len(train_loader.dataset)
    if n_samples == 0:
        return {"loss": 0.0}
    return {"loss": total_loss / n_samples}
def test(model, device, test_loader, floss, eps=0.1):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        device: unused, kept for interface compatibility.  Batches are fed to
            the model as yielded by the loader, without being moved.
        test_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (used for the sample count).
        floss: loss callable ``floss(output, target)`` returning a scalar tensor.
        eps: a prediction counts as correct when it is strictly closer than
            ``eps`` to its target.

    Returns:
        dict with ``"loss"`` (per-sample average loss), ``"correct"`` (number
        of predictions within ``eps``) and ``"accuracy"`` (``correct`` divided
        by the dataset size).  All three are zero for an empty dataset.
    """

    def isok(output, target, eps):
        return abs(output - target) < eps

    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Drop only the trailing unit dimension so that a batch holding a
            # single sample keeps its batch dimension.
            output = model(data).squeeze(-1)
            loss = floss(output, target)
            total_loss += loss.item() * len(data)
            correct += int(isok(output, target, eps).sum())
    # Use the loader that was passed in, not a module-level global.
    n_samples = len(test_loader.dataset)
    if n_samples == 0:
        return {"loss": 0.0, "correct": 0, "accuracy": 0.0}
    return {
        "loss": total_loss / n_samples,
        "correct": correct,
        "accuracy": correct / n_samples,
    }


# The name starts with "test", so pytest would otherwise try to collect this
# helper as a test case whenever it is imported into a test module.
test.__test__ = False
# Loading of the samples file
def get_raw_data_go():
    """Return the samples decoded from the local gzip-compressed JSON file.

    The file is looked up in the current working directory and downloaded
    there first when it does not exist.
    """
    import gzip
    import json
    import os.path

    raw_samples_file = "samples-8x8.json.gz"
    already_present = os.path.isfile(raw_samples_file)
    if not already_present:
        print("File", raw_samples_file, "not found, I am downloading it...", end="")
        import urllib.request

        source_url = "https://www.labri.fr/perso/lsimon/static/inge2-ia/samples-8x8.json.gz"
        urllib.request.urlretrieve(source_url, raw_samples_file)
        print(" Done")

    # Decompress and decode as UTF-8 text, then parse the JSON document.
    with gzip.open(raw_samples_file, "rt", encoding="utf-8") as fz:
        return json.load(fz)
def summary_of_example(data, sample_nb):
    """Print a human-readable description of the sample ``data[sample_nb]``."""
    sample = data[sample_nb]

    print(f"Sample {sample_nb}")
    print()
    print(f"Données brutes en format JSON: {sample}")
    print()
    print(f"The sample was obtained after {sample['depth']} moves")
    print(f"The successive moves were {sample['list_of_moves']}")

    # Stone positions, one line per colour.
    for colour in ("black", "white"):
        stones = sample[colour + "_stones"]
        print(
            f"After these moves and all the captures, there was {colour} stones at the following position {stones}"
        )

    print(
        f"Number of rollouts (gnugo games played against itself from this position): {sample['rollouts']}"
    )

    # Outcome of the rollouts, one line per colour.
    for colour in ("black", "white"):
        rollouts = sample["rollouts"]
        wins = sample[colour + "_wins"]
        points = sample[colour + "_points"]
        print(
            f"Over these {rollouts} games, {colour} won {wins} times with {points} total points over all this winning games"
        )
def stones_to_board(black_stones, white_stones, black_plays):
    """Encode a position as a float32 tensor of shape (3, 8, 8).

    Plane 0 marks the stones of ``black_stones`` and plane 1 those of
    ``white_stones``.  Plane 2 is filled with 1 when ``black_plays`` is truthy
    and with 0 otherwise.  A stone is a string whose first character is a
    letter counted from "A" (first board index) and whose second character is
    a digit counted from 1 (second board index); the entry "PASS" is ignored.
    """
    board = torch.zeros((3, 8, 8), dtype=torch.float32)
    for plane, stones in enumerate((black_stones, white_stones)):
        for stone in stones:
            if stone != "PASS":
                row = ord(stone[0]) - ord("A")
                col = int(stone[1]) - 1
                board[plane, row, col] = 1
    turn_value = 1 if black_plays else 0
    board[2, :, :] = turn_value
    return board
def position_predict(black_stones, white_stones, depth, model=None):
    """Return the model output for a single position.

    Args:
        black_stones: stones passed as first argument to ``stones_to_board``.
        white_stones: stones passed as second argument to ``stones_to_board``.
        depth: number of moves played so far; an even depth sets the
            ``black_plays`` flag of ``stones_to_board``.
        model: network to query.  Defaults to the module-level ``mymodel``
            so existing three-argument callers keep working.

    Returns:
        The tensor returned by the model for a batch containing only this
        position, computed without gradient tracking.
    """
    if model is None:
        model = mymodel
    board = stones_to_board(black_stones, white_stones, depth % 2 == 0)
    # The board is built on the CPU; move it next to the model weights so the
    # call also works when the model lives on an accelerator.
    model_device = next(model.parameters()).device
    batch = board.unsqueeze(0).to(model_device)
    # Predict in evaluation mode (a batch of one sample cannot go through
    # batch normalisation in training mode, and dropout must be disabled),
    # then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prediction = model(batch)
    finally:
        model.train(was_training)
    return prediction
# For the final submission, assuming newdata is the data structure loaded from the JSON file containing the new data that
# you will be given 24h before the deadline, you can build the result file as follows
def create_result_file(newdata):
    """Example of how to generate the requested results file.

    Writes one prediction per line to ``my_predictions.txt`` (in the current
    working directory), in the order of the records of ``newdata``.  Each
    record must provide the keys ``"black_stones"``, ``"white_stones"`` and
    ``"depth"``.
    """
    # All predictions are computed before the output file is opened.
    predictions = []
    for record in newdata:
        predictions.append(
            position_predict(record["black_stones"], record["white_stones"], record["depth"])
        )
    with open("my_predictions.txt", "w") as f:
        f.writelines(f"{prediction.item()}\n" for prediction in predictions)
def setup_device():
    """Select the compute device: CUDA first, then MPS, otherwise the CPU.

    Also sets the float32 matmul precision to "medium" and, when CUDA is
    selected, enables TF32 for CUDA matrix multiplications.
    """
    torch.set_float32_matmul_precision("medium")
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
if __name__ == "__main__":
    # Command-line entry point: build or load the datasets, optionally resume
    # from a checkpoint, train for the requested number of epochs while saving
    # a checkpoint per epoch, then plot the loss and accuracy curves.

    # Register GoDataset as an allowed global so that torch.load can rebuild
    # the pickled datasets when it runs in weights-only mode.
    torch.serialization.add_safe_globals([GoDataset])
    parser = argparse.ArgumentParser(
        prog="alphaChadGo", description="a random ai evaluating go", epilog=""
    )
    parser.add_argument("-m", help="path to model to load, default train from scratch")
    parser.add_argument("-o", help="path to file to save the current model")
    parser.add_argument(
        "-O",
        help="path dir to directory where to save each epoch",
        type=str,
        default="./current",
    )
    parser.add_argument("-t", help="path to train dataset")
    parser.add_argument("-T", help="path to test dataset")
    parser.add_argument("-d", help="path to dataset input")
    parser.add_argument("-n", help="Number of epoch", type=int, default=25)
    parser.add_argument("-l", help="learning rate", type=float, default=0.01)
    parser.add_argument("-R", help="Create result file from a dataset")
    args = parser.parse_args()

    device = setup_device()
    print(device)
    # Kept at module level: position_predict() falls back to this global.
    mymodel = GoModel().to(device)

    output_dir = args.O
    # Also creates missing parent directories and tolerates an existing one.
    os.makedirs(output_dir, exist_ok=True)

    if args.m is not None:
        # map_location lets a checkpoint written on one device be loaded on another.
        checkpoint = torch.load(args.m, map_location=device, weights_only=True)
        # Accept both the per-epoch checkpoints written below (a dict holding
        # "model_state_dict") and the bare state_dict written by -o.
        if "model_state_dict" in checkpoint:
            model_state = checkpoint["model_state_dict"]
        else:
            model_state = checkpoint
        mymodel.load_state_dict(model_state)

    # The raw samples are only needed when at least one dataset has to be built.
    if args.t is None or args.T is None:
        rawdata = get_raw_data_go()
        val_part = 0.33
        # The first `test_size` samples are held out for evaluation.
        test_size = int(len(rawdata) * val_part)
    # The training/evaluation loops never move batches, so the dataset tensors
    # must live on the same device as the model.
    if args.t is not None:
        trainset = torch.load(args.t, map_location=device)
    else:
        trainset = GoDataset(rawdata[test_size:], device)
    if args.T is not None:
        testset = torch.load(args.T, map_location=device)
    else:
        testset = GoDataset(rawdata[:test_size], device, test=True)

    batch_size = 8192
    epochs = args.n
    lr = args.l
    floss = nn.MSELoss()
    print("=============================")
    print("batch_size:\t", batch_size)
    print("epochs:\t\t", epochs)
    print("learning rate:\t", lr)
    print("train size:\t", len(trainset))
    print("test size:\t", len(testset))
    print("=============================")

    # Cache freshly built datasets so later runs can reuse them via -t / -T.
    if args.t is None:
        torch.save(trainset, path.join(output_dir, "trainset.pt"))
    if args.T is None:
        torch.save(testset, path.join(output_dir, "testset.pt"))

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=0
    )
    optimizer = optim.AdamW(mymodel.parameters(), lr=lr)

    print("begin train")
    stats = []
    # Baseline evaluation before any training step.
    test_stats = test(mymodel, device, testloader, floss, 0.025)
    print(test_stats)
    for epoch in range(epochs):
        tst = time.time()
        train_stats = train(mymodel, device, trainloader, optimizer, floss, epoch)
        tnd = time.time()
        vst = time.time()
        test_stats = test(mymodel, device, testloader, floss, 0.025)
        vnd = time.time()
        print(
            f"{epoch:04d} | {tnd - tst:.2f} | {vnd - vst:.2f} | tloss {train_stats['loss']:.4f} | vloss {test_stats['loss']:.4f} | vaccuracy {100 * test_stats['accuracy']:.2f}"
        )
        stats.append({"train": train_stats, "test": test_stats})
        # One checkpoint per epoch, named after the epoch index.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": mymodel.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "trainloss": train_stats["loss"],
                "testloss": test_stats["loss"],
                "accuracy": test_stats["accuracy"],
            },
            path.join(output_dir, f"{epoch}.pt"),
        )

    if args.o is not None:
        torch.save(mymodel.state_dict(), args.o)

    trainloss = [s["train"]["loss"] for s in stats]
    testloss = [s["test"]["loss"] for s in stats]
    testacc = [s["test"]["accuracy"] for s in stats]

    plt.figure()
    # Left panel: training and test loss.
    plt.subplot(1, 2, 1)
    plt.title("loss over epochs")
    plt.plot(trainloss, label="train loss")
    plt.plot(testloss, label="test loss")
    plt.xlabel("epochs")
    plt.ylabel("loss")
    plt.grid()
    plt.legend()
    # Right panel: test accuracy.
    plt.subplot(1, 2, 2)
    plt.title("test accuracy over epochs")
    plt.plot(testacc, label="Test accuracy")
    plt.xlabel("epochs")
    plt.ylabel("accuracy")
    plt.grid()
    plt.legend()
    plt.show()
    # NOTE(review): the -d and -R options are parsed but not used anywhere in
    # the script as shown; the results file can be produced with
    # create_result_file(newdata) once the new samples have been loaded.