-
Notifications
You must be signed in to change notification settings - Fork 2
/
policies.py
72 lines (55 loc) · 2.08 KB
/
policies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import random
import utilities
__author__ = 'miljan'
# Action constants: integer codes for the two actions a policy can select.
# They are returned by the policy functions in this module and serve as the
# action component of (player, dealer, action) value-function keys.
HIT = 0
STICK = 1
def epsilon_greedy(epsilon, value_function, state):
    """ Epsilon greedy action selection over a tabular (state, action) value function.

    With probability epsilon a random action is taken; otherwise the action with
    the larger stored value is taken, and a tie is resolved by a random action.
    :param epsilon: random action probability
    :param value_function: mapping indexed by (player_sum, dealer_first_card, action)
    :param state: current state of the game; its player_sum and dealer_first_card are read
    :return: action to take (HIT or STICK)
    """
    # Explore: ignore the value estimates altogether.
    if random.random() < epsilon:
        return _random_action()

    # Exploit: look up both action values for the current state.
    player_total, dealer_card = state.player_sum, state.dealer_first_card
    hit_value = value_function[(player_total, dealer_card, HIT)]
    stick_value = value_function[(player_total, dealer_card, STICK)]

    if hit_value > stick_value:
        return HIT
    if stick_value > hit_value:
        return STICK

    # Neither value is strictly larger, so fall back to a random choice.
    return _random_action()
def epsilon_greedy_lfa(epsilon, theta, state_features):
    """ Epsilon greedy policy for linear function approximation.

    Returns a random action with probability epsilon, otherwise the action with the
    highest approximated value. If neither action value is strictly larger, the
    action is chosen at random.
    :param epsilon: random action probability
    :param theta: parameters that the state-action feature vector is dotted with
                  to obtain an action value (presumably the weight vector)
    :param state_features: current state feature vector
    :return: (value, action) tuple - the approximated value of the selected action
             and the action itself (HIT or STICK)
    """
    def action_value(action):
        # Approximated value of `action`: state-action features dotted with theta.
        return utilities.get_state_action_features(state_features, action).dot(theta)

    # exploration
    if random.random() < epsilon:
        action = _random_action()
        return action_value(action), action

    # exploitation
    value_hit = action_value(HIT)
    value_stick = action_value(STICK)
    if value_hit > value_stick:
        return value_hit, HIT
    if value_stick > value_hit:
        return value_stick, STICK

    # tie: pick either action at random and report its value
    action = _random_action()
    return action_value(action), action
def _random_action():
    """ Choose between the two actions with a single random draw.
    :return: HIT when the draw is below 0.5, STICK otherwise
    """
    if random.random() < 0.5:
        return HIT
    return STICK