feat: add gameover test in heuristic and map model to current device

This commit is contained in:
Nemo D'ACREMONT 2025-05-18 12:23:00 +02:00
parent 21392f229d
commit c48403ba16

View File

@ -7,6 +7,7 @@ Right now, this class contains the copy of the randomPlayer. But you have to cha
from sys import stderr
import time
import math
import Goban
from random import choice
from moveSearch import IDDFS, alphabeta
@ -173,7 +174,7 @@ class myPlayer(PlayerInterface):
self.model = GoModel().to(self.device)
checkpoint = torch.load("scrum.pt", weights_only=True)
checkpoint = torch.load("scrum.pt", weights_only=True, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
def getPlayerName(self):
@ -187,7 +188,10 @@ class myPlayer(PlayerInterface):
score[0] - score[1] if color == Goban.Board._BLACK else score[1] - score[0]
)
def nnheuristic(self, board, color):
def nnheuristic(self, board: Goban.Board, color):
if board.is_game_over():
return math.inf if board.winner() == color else -math.inf
go_board = torch.from_numpy(np.array([goban2Go(board)])).float().to(self.device)
self.model.eval()