feat: add go ai
This commit is contained in:
parent
5f28800dbb
commit
9d28f31900
3
go_ia/.gitignore
vendored
3
go_ia/.gitignore
vendored
@ -1,3 +1,6 @@
|
||||
__pycache__
|
||||
/venv
|
||||
/.venv
|
||||
/samples-8x8.json.gz
|
||||
/samples-8x8.json
|
||||
*.pt
|
||||
|
8
go_ia/README.md
Normal file
8
go_ia/README.md
Normal file
@ -0,0 +1,8 @@
|
||||
# TP IA GO

## VENV ACTIVATE

```sh
python -m venv venv
. ./venv/bin/activate
```
|
390
go_ia/main.py
Executable file
390
go_ia/main.py
Executable file
@ -0,0 +1,390 @@
|
||||
#!python
|
||||
import time
|
||||
from os import path
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
class GoModel(nn.Module):
    """Convolutional network mapping a 3x8x8 board encoding to one value in (0, 1).

    Five 3x3 convolution stages (convolution, batch normalisation, ReLU; the
    third one also applies dropout) keep the 8x8 spatial size thanks to
    ``padding=1``. The flattened features then go through a 128-unit hidden
    layer and a single sigmoid output.

    All layers live in one ``nn.Sequential`` stored as ``self.net``. Their
    position in that sequence defines the ``state_dict`` keys, so the order
    below must not change if saved checkpoints are to stay loadable.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, dropout=None):
        """Return the layers of one stage: Conv2d, BatchNorm2d, optional Dropout, ReLU."""
        stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        if dropout is not None:
            stage.append(nn.Dropout(dropout))
        stage.append(nn.ReLU())
        return stage

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, dropout probability or None) per stage.
        stage_specs = [
            (3, 16, None),
            (16, 32, None),
            (32, 64, 0.4),
            (64, 64, None),
            (64, 64, None),
        ]

        layers = []
        for in_channels, out_channels, dropout in stage_specs:
            layers.extend(self._conv_stage(in_channels, out_channels, dropout))

        # Regression head: 64 feature maps of 8x8 -> 128 -> 1 -> sigmoid.
        layers.extend(
            [
                nn.Flatten(),
                nn.Linear(64 * 8 * 8, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            ]
        )

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the whole layer stack to ``x`` and return its output."""
        return self.net(x)
|
||||
|
||||
|
||||
class GoDataset(Dataset):
    """In-memory dataset of augmented board encodings and their win rates.

    Each raw sample is expanded into 16 examples: 2 colour perspectives
    x 2 mirror states x 4 quarter-turn rotations. The label of an example
    is the win rate of the colour stored in channel 0 of its board:
    ``black_wins / rollouts`` for the black perspective and one minus that
    value for the white perspective.
    """

    def __init__(self, data, device):
        """Build every augmented example up front and store it on ``device``.

        Args:
            data: iterable of sample dicts. The keys read here are
                ``"black_stones"``, ``"white_stones"``, ``"depth"``,
                ``"black_wins"`` and ``"rollouts"``.
            device: torch device on which ``self.boards`` and
                ``self.labels`` are stored.
        """
        spatial_dims = [1, 2]  # dim 0 holds the channels and must not be touched
        boards = []
        labels = []

        for d in data:
            black_rate = d["black_wins"] / d["rollouts"]
            # Two perspectives of the same position: the stones given first
            # go to channel 0, and the turn flag is set for that same colour.
            perspectives = (
                (
                    stones_to_board(
                        d["black_stones"], d["white_stones"], d["depth"] % 2 == 0
                    ),
                    black_rate,
                ),
                (
                    stones_to_board(
                        d["white_stones"], d["black_stones"], d["depth"] % 2 == 1
                    ),
                    1 - black_rate,
                ),
            )

            for mirrored in (False, True):
                for quarter_turns in range(4):
                    for base, rate in perspectives:
                        # Mirror along a board axis. The previous
                        # `out.flipud()` flipped dim 0, i.e. it reversed the
                        # channel order (stones plane <-> turn plane) instead
                        # of mirroring the board.
                        view = torch.flip(base, dims=[1]) if mirrored else base
                        boards.append(torch.rot90(view, quarter_turns, spatial_dims))
                        # Appended together with its board so that boards and
                        # labels cannot get out of step.
                        labels.append(rate)

        # torch.stack avoids the element-wise tensor -> numpy -> tensor trip.
        stacked = torch.stack(boards) if boards else torch.empty(0)
        self.boards = stacked.float().to(device)
        self.labels = torch.tensor(labels, dtype=torch.float32).to(device)

    def __len__(self):
        """Return the number of augmented examples."""
        return len(self.boards)

    def __getitem__(self, i):
        """Return the ``(board, label)`` pair stored at index ``i``."""
        return self.boards[i], self.labels[i]
|
||||
|
||||
|
||||
def train(model, device, train_loader, optimizer, floss, epoch):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Args:
        model: network to optimise; it is switched to training mode.
            Its output is assumed to have shape (batch, 1) - the trailing
            dimension is squeezed before the loss is computed.
        device: device the batches are moved to before the forward pass
            (``None`` leaves them where they are).
        train_loader: iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute supporting ``len()``.
        optimizer: optimiser stepping ``model``'s parameters.
        floss: loss callable invoked as ``floss(output, target)``.
        epoch: index of the current epoch; accepted for the caller's
            convenience and not used here.

    Returns:
        ``{"loss": value}`` where ``value`` is the sum of the batch losses,
        each weighted by its batch size, divided by
        ``len(train_loader.dataset)`` (NaN when the dataset is empty).
    """
    model.train()
    weighted_loss_sum = 0.0

    for data, target in train_loader:
        if device is not None:
            # No-op when the batch already lives on `device`.
            data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        # squeeze(-1) rather than squeeze(): a batch holding a single sample
        # must stay 1-D so that it keeps the same shape as `target`.
        output = model(data).squeeze(-1)

        loss = floss(output, target)
        loss.backward()
        optimizer.step()

        # .item() drops the autograd history; accumulating the tensor itself
        # would keep graph references alive for the whole epoch.
        weighted_loss_sum += loss.item() * len(data)

    sample_count = len(train_loader.dataset)
    if sample_count == 0:
        return {"loss": float("nan")}
    return {"loss": weighted_loss_sum / sample_count}
|
||||
|
||||
|
||||
def test(model, device, test_loader, floss, eps=0.1):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
            Its output is assumed to have shape (batch, 1) - the trailing
            dimension is squeezed before comparison with the targets.
        device: device the batches are moved to before the forward pass
            (``None`` leaves them where they are).
        test_loader: iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute supporting ``len()``.
        floss: loss callable invoked as ``floss(output, target)``.
        eps: a prediction counts as correct when its absolute distance to
            the target is strictly below this value.

    Returns:
        A dict with ``"loss"`` (batch losses weighted by batch size, divided
        by the dataset length), ``"correct"`` (number of correct
        predictions) and ``"accuracy"`` (``correct`` divided by the dataset
        length). ``loss`` and ``accuracy`` are NaN for an empty dataset.
    """
    model.eval()
    weighted_loss_sum = 0.0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            if device is not None:
                # No-op when the batch already lives on `device`.
                data, target = data.to(device), target.to(device)

            # squeeze(-1) rather than squeeze(): a batch holding a single
            # sample must stay 1-D so that it keeps the shape of `target`.
            output = model(data).squeeze(-1)

            weighted_loss_sum += floss(output, target).item() * len(data)
            correct += int(((output - target).abs() < eps).sum().item())

    # Use the loader that was passed in, not the script-level `testloader`.
    sample_count = len(test_loader.dataset)
    if sample_count == 0:
        return {"loss": float("nan"), "correct": 0, "accuracy": float("nan")}

    return {
        "loss": weighted_loss_sum / sample_count,
        "correct": correct,
        "accuracy": correct / sample_count,
    }


# Despite its name this is an evaluation helper, not a pytest test case;
# the flag keeps pytest from collecting it when the module is star-imported.
test.__test__ = False
|
||||
|
||||
|
||||
# Loading of the samples file
def get_raw_data_go():
    """Return the samples parsed from the local file, downloading it first if missing.

    The gzip-compressed JSON file is looked up in the current working
    directory; when it is absent it is fetched from the course web page and
    stored under the same name.

    Returns:
        The Python object decoded from the JSON content of the file.
    """
    import gzip
    import json
    import os.path

    # Single source of truth for the file name (it was previously repeated
    # as a literal for the download target and for the read).
    raw_samples_file = "samples-8x8.json.gz"
    samples_url = "https://www.labri.fr/perso/lsimon/static/inge2-ia/samples-8x8.json.gz"

    if not os.path.isfile(raw_samples_file):
        # flush=True: without a newline the message would otherwise stay in
        # the output buffer until the download has finished.
        print(
            "File",
            raw_samples_file,
            "not found, I am downloading it...",
            end="",
            flush=True,
        )
        import urllib.request

        urllib.request.urlretrieve(samples_url, raw_samples_file)
        print(" Done")

    with gzip.open(raw_samples_file, "rt", encoding="utf-8") as fz:
        return json.load(fz)
|
||||
|
||||
|
||||
def summary_of_example(data, sample_nb):
    """Print a human-readable description of the sample at index ``sample_nb``.

    Args:
        data: indexable collection of sample dicts.
        sample_nb: index of the sample to describe.
    """
    entry = data[sample_nb]

    def show_stones(caption, key):
        # One line: caption followed by the stone list stored under `key`.
        print(caption, entry[key])

    def show_wins(caption, wins_key, points_key):
        # One line summarising the rollout results of a single colour.
        print(
            "Over these",
            entry["rollouts"],
            caption,
            entry[wins_key],
            "times with",
            entry[points_key],
            "total points over all this winning games",
        )

    # end="\n\n" emits the blank separator line after the text.
    print("Sample", sample_nb, end="\n\n")
    print("Données brutes en format JSON:", entry, end="\n\n")

    print("The sample was obtained after", entry["depth"], "moves")
    print("The successive moves were", entry["list_of_moves"])

    show_stones(
        "After these moves and all the captures, there was black stones at the following position",
        "black_stones",
    )
    show_stones(
        "After these moves and all the captures, there was white stones at the following position",
        "white_stones",
    )

    print(
        "Number of rollouts (gnugo games played against itself from this position):",
        entry["rollouts"],
    )

    show_wins("games, black won", "black_wins", "black_points")
    show_wins("games, white won", "white_wins", "white_points")
|
||||
|
||||
|
||||
def stones_to_board(black_stones, white_stones, black_plays):
    """Encode a position as a float32 tensor of shape (3, 8, 8).

    Channel 0 is set to 1 at the coordinates listed in ``black_stones``,
    channel 1 at those listed in ``white_stones``, and channel 2 is filled
    with 1 when ``black_plays`` is truthy (0 otherwise).

    A coordinate is a string whose first character is a letter counted from
    ``"A"`` (first board index) and whose second character is a digit
    counted from 1 (second board index). ``"PASS"`` entries are skipped.
    """
    planes = torch.zeros((3, 8, 8), dtype=torch.float32)
    first_letter = ord("A")

    for channel, stones in enumerate((black_stones, white_stones)):
        for stone in stones:
            if stone != "PASS":
                letter_idx = ord(stone[0]) - first_letter
                digit_idx = int(stone[1]) - 1
                planes[channel, letter_idx, digit_idx] = 1

    # The turn plane is already all zeros; only the "plays" case needs work.
    if black_plays:
        planes[2].fill_(1)

    return planes
|
||||
|
||||
|
||||
def position_predict(black_stones, white_stones, depth):
    """Return the prediction of the global ``mymodel`` for one position.

    Args:
        black_stones: coordinates forwarded to ``stones_to_board`` as the
            black stones.
        white_stones: coordinates forwarded to ``stones_to_board`` as the
            white stones.
        depth: number of moves already played; an even value sets the
            ``black_plays`` flag of the board encoding.

    Returns:
        The raw output tensor of ``mymodel`` for a batch holding this single
        board, computed without gradient tracking.
    """
    board = stones_to_board(black_stones, white_stones, depth % 2 == 0)

    # The board is created on the CPU; the model may live on another device.
    model_device = next(mymodel.parameters()).device
    batch = board.unsqueeze(0).to(model_device)

    # Inference must run in evaluation mode: in training mode dropout is
    # active and batch normalisation cannot process a single-sample batch.
    # The previous mode is restored afterwards so callers are unaffected.
    was_training = mymodel.training
    mymodel.eval()
    try:
        with torch.no_grad():
            prediction = mymodel(batch)
    finally:
        mymodel.train(was_training)

    return prediction
|
||||
|
||||
|
||||
# For the hand-in: assuming `newdata` is the data structure parsed from the
# JSON file holding the new samples (which you will be given 24h before the
# deadline), the result file can be built as follows.
def create_result_file(newdata):
    """Example of a method generating the requested results file."""
    # All predictions are computed before the file is opened, so nothing is
    # written if one of them fails.
    predictions = []
    for sample in newdata:
        predictions.append(
            position_predict(
                sample["black_stones"], sample["white_stones"], sample["depth"]
            )
        )

    # One predicted value per line, in the order of `newdata`.
    with open("my_predictions.txt", "w") as out:
        out.writelines(f"{p.item()}\n" for p in predictions)
|
||||
|
||||
|
||||
def setup_device():
    """Configure matmul precision and return the device to run on.

    The float32 matmul precision is set to ``"medium"`` in every case. The
    returned device is CUDA when available (TF32 matmuls are then enabled),
    otherwise MPS when available, otherwise the CPU.
    """
    torch.set_float32_matmul_precision("medium")

    # Prefer an accelerator whenever one is present.
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.device("cuda")

    fallback = "mps" if torch.backends.mps.is_available() else "cpu"
    return torch.device(fallback)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the options, build or load the datasets,
    # train for the requested number of epochs (one checkpoint per epoch),
    # then plot the loss and accuracy curves.

    # Local import: only the entry point needs it (output directory creation).
    import os

    parser = argparse.ArgumentParser(
        prog="alphaChadGo", description="a random ai evaluating go", epilog=""
    )

    parser.add_argument("-m", help="path to model to load, default train from scratch")
    parser.add_argument("-o", help="path to file to save the current model")
    parser.add_argument(
        "-O",
        help="path dir to directory where to save each epoch",
        type=str,
        default="./current",
    )
    parser.add_argument("-t", help="path to train dataset")
    parser.add_argument("-T", help="path to test dataset")
    parser.add_argument("-d", help="path to dataset input")
    parser.add_argument("-n", help="Number of epoch", type=int, default=25)
    parser.add_argument("-l", help="learning rate", type=float, default=0.01)
    args = parser.parse_args()

    device = setup_device()
    print(device)

    # `mymodel` has to stay a module-level name: position_predict() reads it
    # as a global.
    mymodel = GoModel().to(device)
    output_dir = args.O
    # torch.save() does not create missing directories; without this the
    # first save into the default "./current" fails on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    if args.m is not None:
        checkpoint = torch.load(args.m, map_location=device, weights_only=True)
        # Accept both formats written by this script: the per-epoch
        # checkpoint dict and the bare state_dict saved through `-o`.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            model_state = checkpoint["model_state_dict"]
        else:
            model_state = checkpoint
        mymodel.load_state_dict(model_state)

    # The raw samples are only needed when at least one dataset must be built.
    if args.t is None or args.T is None:
        rawdata = get_raw_data_go()

        val_part = 0.33
        # The first `test_count` samples form the test split, the rest the
        # training split.
        test_count = int(len(rawdata) * val_part)
        traindata = rawdata[test_count:]
        testdata = rawdata[:test_count]

    # NOTE(security): the dataset files are whole pickled objects, so they
    # must be loaded with weights_only=False (recent PyTorch releases reject
    # them otherwise). Unpickling can execute arbitrary code - only load
    # files produced by this script.
    if args.t is not None:
        trainset = torch.load(args.t, map_location=device, weights_only=False)
    else:
        trainset = GoDataset(traindata, device)

    if args.T is not None:
        testset = torch.load(args.T, map_location=device, weights_only=False)
    else:
        testset = GoDataset(testdata, device)

    batch_size = 8192
    epochs = args.n
    lr = args.l
    floss = nn.MSELoss()

    print("=============================")
    print("batch_size:\t", batch_size)
    print("epochs:\t\t", epochs)
    print("learning rate:\t", lr)
    print("train size:\t", len(trainset))
    print("test size:\t", len(testset))
    print("=============================")

    # Cache freshly built datasets so later runs can pass them via -t / -T.
    if args.t is None:
        torch.save(trainset, path.join(output_dir, "trainset.pt"))
    if args.T is None:
        torch.save(testset, path.join(output_dir, "testset.pt"))

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=0
    )

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    optimizer = optim.AdamW(mymodel.parameters(), lr=lr)

    print("begin train")

    stats = []
    # Baseline evaluation before any training step.
    test_stats = test(mymodel, device, testloader, floss, 0.025)
    print(test_stats)

    for epoch in range(epochs):
        tst = time.time()
        train_stats = train(mymodel, device, trainloader, optimizer, floss, epoch)
        tnd = time.time()

        vst = time.time()
        test_stats = test(mymodel, device, testloader, floss, 0.025)
        vnd = time.time()

        print(
            f"{epoch:04d} | {tnd - tst:.2f} | {vnd - vst:.2f} | tloss {train_stats['loss']:.4f} | vloss {test_stats['loss']:.4f} | vaccuracy {100 * test_stats['accuracy']:.2f}"
        )
        stats.append({"train": train_stats, "test": test_stats})

        # One checkpoint per epoch, named after the epoch index.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": mymodel.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "trainloss": train_stats["loss"],
                "testloss": test_stats["loss"],
                "accuracy": test_stats["accuracy"],
            },
            path.join(output_dir, f"{epoch}.pt"),
        )

    if args.o is not None:
        torch.save(mymodel.state_dict(), args.o)

    trainloss = [s["train"]["loss"] for s in stats]
    testloss = [s["test"]["loss"] for s in stats]
    testacc = [s["test"]["accuracy"] for s in stats]

    plt.figure()

    # Left panel: training and test loss.
    plt.subplot(1, 2, 1)
    plt.title("loss over epochs")
    plt.plot(trainloss, label="train loss")
    plt.plot(testloss, label="test loss")
    plt.xlabel("epochs")
    plt.ylabel("loss")
    plt.grid()
    plt.legend()

    # Right panel: test accuracy.
    plt.subplot(1, 2, 2)
    plt.title("test accuracy over epochs")
    plt.plot(testacc, label="Test accuracy")
    plt.xlabel("epochs")
    plt.ylabel("accuracy")
    plt.grid()
    plt.legend()

    plt.show()

    # To produce the hand-in file, call create_result_file(newdata) with the
    # samples to predict.
|
3
go_ia/requirements.txt
Normal file
3
go_ia/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
matplotlib
|
||||
numpy
|
||||
torch
|
Loading…
x
Reference in New Issue
Block a user