forked from fairy-stockfish/variant-nnue-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
halfkp.py
75 lines (61 loc) · 2.69 KB
/
halfkp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import chess
import torch
import feature_block
from collections import OrderedDict
from feature_block import *
import variant
# Board geometry and piece counts are taken from the project's `variant` module.
NUM_SQ = variant.SQUARES
# Number of non-king (piece kind, colour) planes; presumably variant.PIECES
# counts both colours including the two kings -- TODO confirm against variant.py.
NUM_PT = variant.PIECES - 2
# One plane per (piece plane, square) plus one extra leading slot per king
# bucket; halfkp_idx always adds 1, so that slot is never emitted.
NUM_PLANES = 1 + NUM_PT * NUM_SQ
def orient(is_white_pov: bool, sq: int):
    """Map square index `sq` into the frame of the given perspective.

    White's perspective is the identity. For black, both the file
    (f -> FILES - 1 - f) and the rank (r -> RANKS - 1 - r) are mirrored,
    i.e. the board is rotated by 180 degrees.
    """
    if is_white_pov:
        return sq
    rank_idx, file_idx = divmod(sq, variant.FILES)
    return (variant.RANKS - 1 - rank_idx) * variant.FILES + (variant.FILES - 1 - file_idx)
def halfkp_idx(is_white_pov: bool, king_sq: int, sq: int, piece_type: int, color: bool):
    """Return the flat HalfKP feature index of one non-king piece.

    `king_sq` is used as-is; the caller in this file passes an already
    oriented king square. `sq` is oriented here for the given perspective.
    The result lies in king bucket `king_sq`, each bucket being NUM_PLANES
    wide; slot 0 of a bucket is never returned because of the trailing + 1.
    """
    # Pieces of the perspective's own colour take the even plane, the
    # opponent's pieces the odd plane.
    plane = 2 * (piece_type - 1) + (color != is_white_pov)
    square = orient(is_white_pov, sq)
    return king_sq * NUM_PLANES + plane * NUM_SQ + square + 1
class Features(FeatureBlock):
    """HalfKP feature block: one feature per (own king square, non-king piece, square)."""

    def __init__(self):
        super(Features, self).__init__('HalfKP', 0x5d69d5b8, OrderedDict([('HalfKP', NUM_PLANES * NUM_SQ)]))

    def get_active_features(self, board: chess.Board):
        """Return dense 0/1 HalfKP tensors for `board`, white perspective first.

        Each tensor has NUM_PLANES * NUM_SQ entries; 1.0 marks an active
        feature. Kings themselves never produce a feature.
        """
        def piece_features(turn):
            indices = torch.zeros(NUM_PLANES * NUM_SQ)
            for sq, p in board.piece_map().items():
                if p.piece_type == chess.KING:
                    continue
                # halfkp_idx takes the piece type and colour as separate
                # arguments; passing the chess.Piece object itself left
                # `color` missing and raised TypeError on every call.
                indices[halfkp_idx(turn, orient(turn, board.king(turn)), sq, p.piece_type, p.color)] = 1.0
            return indices
        return (piece_features(chess.WHITE), piece_features(chess.BLACK))

    def get_initial_psqt_features(self):
        raise Exception('Not supported yet. See HalfKA')
class FactorizedFeatures(FeatureBlock):
    """HalfKP plus two training-time factors: HalfK and P.

    Layout of each feature vector: the real HalfKP block
    (NUM_PLANES * NUM_SQ), then HalfK (NUM_SQ), then P (NUM_PT * NUM_SQ).
    """

    def __init__(self):
        # The P factor has one NUM_SQ-wide plane per non-king piece plane, i.e.
        # NUM_PT planes. It was hard-coded to 10, which only matches chess.
        super(FactorizedFeatures, self).__init__('HalfKP^', 0x5d69d5b8, OrderedDict([('HalfKP', NUM_PLANES * NUM_SQ), ('HalfK', NUM_SQ), ('P', NUM_SQ * NUM_PT)]))
        self.base = Features()

    def get_active_features(self, board: chess.Board):
        """Return (white, black) tensors: base HalfKP features followed by the factors."""
        white, black = self.base.get_active_features(board)

        def piece_features(base, color):
            # Factor part: one HalfK plane followed by NUM_PT P planes.
            indices = torch.zeros(NUM_SQ * (NUM_PT + 1))
            piece_count = 0
            # P feature
            for sq, p in board.piece_map().items():
                if p.piece_type == chess.KING:
                    continue
                piece_count += 1
                p_idx = (p.piece_type - 1) * 2 + (p.color != color)
                # "+ 1" skips the HalfK plane stored at the front of `indices`.
                indices[(p_idx + 1) * NUM_SQ + orient(color, sq)] = 1.0
            # HalfK feature: the king square, weighted by the number of non-king pieces.
            indices[orient(color, board.king(color))] = piece_count
            return torch.cat((base, indices))

        return (piece_features(white, chess.WHITE), piece_features(black, chess.BLACK))

    def get_feature_factors(self, idx):
        """Return [idx, HalfK factor index, P factor index] for real feature `idx`.

        Raises:
            Exception: if `idx` is not a real (non-factor) feature index.
        """
        if idx >= self.num_real_features:
            raise Exception('Feature must be real')

        k_idx = idx // NUM_PLANES
        # halfkp_idx adds 1 to every index, so undo it to get the piece-square
        # offset. The slot with idx % NUM_PLANES == 0 is never emitted by
        # halfkp_idx, and it would yield p_idx == -1 here.
        p_idx = idx % NUM_PLANES - 1
        return [idx, self.get_factor_base_feature('HalfK') + k_idx, self.get_factor_base_feature('P') + p_idx]

    def get_initial_psqt_features(self):
        raise Exception('Not supported yet. See HalfKA^')
def get_feature_block_clss():
    """Return the feature block classes defined in this module.

    This is used by the features module for discovery of feature blocks.
    """
    block_classes = [Features, FactorizedFeatures]
    return block_classes