diff --git a/go_player/myPlayer.py b/go_player/myPlayer.py index 64ccf5d..63139d2 100644 --- a/go_player/myPlayer.py +++ b/go_player/myPlayer.py @@ -7,6 +7,7 @@ Right now, this class contains the copy of the randomPlayer. But you have to cha from sys import stderr import time +import math import Goban from random import choice from moveSearch import IDDFS, alphabeta @@ -173,7 +174,7 @@ class myPlayer(PlayerInterface): self.model = GoModel().to(self.device) - checkpoint = torch.load("scrum.pt", weights_only=True) + checkpoint = torch.load("scrum.pt", weights_only=True, map_location=self.device) self.model.load_state_dict(checkpoint["model_state_dict"]) def getPlayerName(self): @@ -187,7 +188,10 @@ class myPlayer(PlayerInterface): score[0] - score[1] if color == Goban.Board._BLACK else score[1] - score[0] ) - def nnheuristic(self, board, color): + def nnheuristic(self, board: Goban.Board, color): + if board.is_game_over(): + return math.inf if board.winner() == color else -math.inf + go_board = torch.from_numpy(np.array([goban2Go(board)])).float().to(self.device) self.model.eval()