import copy


class CliffWalkingEnv:
    """Cliff Walking environment on an ``ncol`` x ``nrow`` grid.

    The agent starts at the bottom-left cell. Every bottom-row cell between
    the start and the bottom-right goal is a cliff: stepping into it gives
    reward -100, any other move gives -1, and both cliff and goal cells are
    terminal (absorbing).
    """

    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # grid width (number of columns)
        self.nrow = nrow  # grid height (number of rows)
        # P[state][action] -> [(prob, next_state, reward, done)]
        self.P = self.createP()

    def createP(self):
        """Build the deterministic transition table.

        Actions are indexed 0 = up, 1 = down, 2 = left, 3 = right.
        States are numbered row-major: state = row * ncol + col.
        """
        P = [[[] for _ in range(4)] for _ in range(self.nrow * self.ncol)]
        # (dx, dy) per action; x is the column index, y is the row index.
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for row in range(self.nrow):
            for col in range(self.ncol):
                state = row * self.ncol + col
                for action in range(4):
                    # Terminal cells (cliff or goal) absorb with zero reward.
                    if row == self.nrow - 1 and col > 0:
                        P[state][action] = [(1, state, 0, True)]
                        continue
                    # Clamp the move so the agent cannot leave the grid.
                    nx = min(self.ncol - 1, max(0, col + change[action][0]))
                    ny = min(self.nrow - 1, max(0, row + change[action][1]))
                    next_state = ny * self.ncol + nx
                    reward, done = -1, False
                    if ny == self.nrow - 1 and nx > 0:
                        done = True
                        if nx != self.ncol - 1:
                            # Fell off the cliff (the goal is the last column).
                            reward = -100
                    P[state][action] = [(1, next_state, reward, done)]
        return P
class PolicyIteration:
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Works on any environment exposing ``ncol``, ``nrow`` and a transition
    table ``P[state][action] -> [(prob, next_state, reward, done)]``.
    """

    def __init__(self, env, theta, gamma):
        self.env = env
        n_states = env.ncol * env.nrow
        self.v = [0] * n_states  # state-value estimates
        # Start from the uniform random policy over the 4 actions.
        self.pi = [[0.25, 0.25, 0.25, 0.25] for _ in range(n_states)]
        self.theta = theta  # convergence threshold for evaluation sweeps
        self.gamma = gamma  # discount factor

    def policy_evaluation(self):
        """Iteratively evaluate self.pi until the max value change < theta."""
        sweeps = 0
        n_states = self.env.ncol * self.env.nrow
        while True:
            sweeps += 1
            new_v = [0] * n_states
            for s in range(n_states):
                value = 0
                for a in range(4):
                    qsa = 0
                    for p, nxt, r, done in self.env.P[s][a]:
                        # Terminal transitions contribute no future value.
                        qsa += p * (r + self.gamma * self.v[nxt] * (1 - done))
                    value += self.pi[s][a] * qsa
                new_v[s] = value
            delta = max(abs(self.v[s] - new_v[s]) for s in range(n_states))
            self.v = new_v
            if delta < self.theta:
                break
        print(f"策略评估{sweeps}次完成")

    def policy_improvement(self):
        """Make self.pi greedy with respect to the current value function."""
        for s in range(self.env.ncol * self.env.nrow):
            q_values = []
            for a in range(4):
                qsa = 0
                for p, nxt, r, done in self.env.P[s][a]:
                    qsa += p * (r + self.gamma * self.v[nxt] * (1 - done))
                q_values.append(qsa)
            best = max(q_values)
            ties = q_values.count(best)
            # Split probability evenly among all maximizing actions.
            self.pi[s] = [1 / ties if q == best else 0 for q in q_values]
        return self.pi

    def policy_iteration(self):
        """Run evaluation/improvement rounds until the policy is stable."""
        rounds = 0
        while True:
            rounds += 1
            old_pi = copy.deepcopy(self.pi)
            self.policy_evaluation()
            new_pi = self.policy_improvement()
            print("策略提升!")
            if old_pi == new_pi:
                break
        print(f"策略迭代{rounds}轮完成!")
        return new_pi

    def print_result(self):
        """Print the value function and the greedy policy as text grids."""
        print("状态价值:")
        for row in range(self.env.nrow):
            for col in range(self.env.ncol):
                print("{:.3f}".format(self.v[row * self.env.ncol + col]).center(8), end="")
            print('\n')
        print("策略:")
        actions = ['^', 'v', '<', '>']
        for row in range(self.env.nrow):
            for col in range(self.env.ncol):
                if row == self.env.nrow - 1 and col == self.env.ncol - 1:
                    print("goal".center(5), end="")
                elif row == self.env.nrow - 1 and 0 < col < self.env.ncol - 1:
                    print("x".center(5), end="")
                else:
                    probs = self.pi[row * self.env.ncol + col]
                    # Mark every action with non-zero probability; 'o' otherwise.
                    marks = ''.join(actions[k] if probs[k] > 0 else 'o'
                                    for k in range(len(probs)))
                    print(marks.center(5), end="")
            print('\n')
if __name__ == '__main__':
    # Solve the default 12x4 cliff-walking MDP with policy iteration
    # and display the resulting values and greedy policy.
    env = CliffWalkingEnv()
    agent = PolicyIteration(env, theta=0.001, gamma=0.9)
    agent.policy_iteration()
    agent.print_result()