-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSARSA.m
106 lines (79 loc) · 2.48 KB
/
SARSA.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
% SARSA (on-policy temporal-difference control) on the cliff-walking
% gridworld. Learns an action-value table Q, then plots the greedy path
% and the per-episode learning curve.
% NOTE(review): cliffinit.mat presumably supplies the grid layout
% including 'start' and 'goal' state indices; parameters.mat presumably
% supplies alpha_p, gamma_p, epsilon_p, max_itr and max_ep_itr -- confirm
% against the files.
load('cliffinit.mat')
load('parameters.mat')
% Action-value table: one row per state (states numbered 1..goal),
% one column for each of the 4 actions.
Q = zeros(goal, 4);
% Overall episode counter, used below for periodic progress prints
% and Q-table snapshots.
count = 1;
% Total reward collected in each episode; grown by one entry per
% episode and used later for the average and the learning curve.
reward_arr = [];
% run through task max_itr number of times
% Run max_itr training episodes of SARSA.
for i = 1:max_itr
    % Reward accumulated over this episode.
    totalreward = 0;
    % Every episode begins at the start state.
    curr_state = start;
    % Possible successor states from the current state.
    % NOTE(review): find_action appears to return one column per action
    % (indexed below as action_arr(:,action)) -- confirm its definition.
    action_arr = find_action(curr_state);
    % First action chosen epsilon-greedily from Q (on-policy).
    action = egreedy(curr_state, action_arr, Q);
    % Step counter guarding against a non-terminating episode.
    count_ep = 1;
    while ( curr_state ~= goal )
        % State reached by taking the chosen action.
        next_state = action_arr(:,action);
        % Immediate reward for entering next_state.
        reward = calculate_r(next_state);
        % Stepping onto the cliff (bottom-row states, excluding the
        % start column (4) and the goal (48)) ends the episode at once.
        if ( mod(next_state,4) == 0 && next_state ~= 4 && next_state ~= 48)
            totalreward = totalreward + reward;
            break;
        end
        % Choose the successor action with the same epsilon-greedy
        % policy -- using Q(s',a') of the action actually to be taken
        % is what makes this SARSA rather than Q-learning.
        next_action_arr = find_action(next_state);
        next_action = egreedy(next_state, next_action_arr, Q);
        % SARSA update: Q(s,a) <- Q(s,a) + alpha*(r + gamma*Q(s',a') - Q(s,a)).
        % Parentheses group the TD error; the original square brackets
        % built a needless 1x1 concatenation (Code Analyzer warning).
        Q(curr_state,action) = Q(curr_state,action) + alpha_p...
            * (reward + gamma_p * Q(next_state,next_action) - Q(curr_state,action));
        % Advance: (s, a) <- (s', a').
        curr_state = next_state;
        action_arr = next_action_arr;
        action = next_action;
        % Accumulate the step reward BEFORE the step-cap check, so the
        % final step's reward is not dropped when the cap triggers
        % (the original added it after the break and lost it).
        totalreward = totalreward + reward;
        % Cap episode length so the while loop cannot run forever;
        % >= is safer than == in case the counter ever oversteps.
        count_ep = count_ep + 1;
        if ( count_ep >= max_ep_itr )
            break;
        end
    end
    % Progress print every 10 episodes.
    if ( mod(count, 10) == 0 )
        fprintf('count:');
        disp(count);
    end
    % Snapshot the Q table as an image every 50 episodes.
    if ( mod(count, 50) == 0 )
        Qimagefunc(count, Q, 'SARSA');
    end
    count = count + 1;
    % Record this episode's return for the learning curve.
    reward_arr = [reward_arr totalreward]; %#ok<AGROW>
% end of task
end
% Extract the learned policy's path by a greedy rollout through the
% cliff with the trained Q table.
[final_path, ~] = cliffrun(Q);
% Visualize the greedy path on the grid.
plotfinalpath(final_path, 'SARSA');
% Report the mean per-episode reward over all of training.
disp('SARSA average reward');
disp(mean(reward_arr));
% Learning curve: total reward of each episode against its index.
x = 1:max_itr;
figure;
plot(x,reward_arr);
xlabel('Episodes');
ylabel('Reward');
xlim([0 max_itr]);
ylim([-125 150]);
title(['SARSA algorithm; \alpha = ' num2str(alpha_p)...
' \gamma = ' num2str(gamma_p) ' \epsilon = ' num2str(epsilon_p)])