feat: add gameover test in heuristic and map model to current device
This commit is contained in:
parent
21392f229d
commit
c48403ba16
@ -7,6 +7,7 @@ Right now, this class contains the copy of the randomPlayer. But you have to cha
|
||||
|
||||
from sys import stderr
|
||||
import time
|
||||
import math
|
||||
import Goban
|
||||
from random import choice
|
||||
from moveSearch import IDDFS, alphabeta
|
||||
@ -173,7 +174,7 @@ class myPlayer(PlayerInterface):
|
||||
|
||||
self.model = GoModel().to(self.device)
|
||||
|
||||
checkpoint = torch.load("scrum.pt", weights_only=True)
|
||||
checkpoint = torch.load("scrum.pt", weights_only=True, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
def getPlayerName(self):
|
||||
@ -187,7 +188,10 @@ class myPlayer(PlayerInterface):
|
||||
score[0] - score[1] if color == Goban.Board._BLACK else score[1] - score[0]
|
||||
)
|
||||
|
||||
def nnheuristic(self, board, color):
|
||||
def nnheuristic(self, board: Goban.Board, color):
|
||||
if board.is_game_over():
|
||||
return math.inf if board.winner() == color else -math.inf
|
||||
|
||||
go_board = torch.from_numpy(np.array([goban2Go(board)])).float().to(self.device)
|
||||
|
||||
self.model.eval()
|
||||
|
Loading…
x
Reference in New Issue
Block a user