Commit 28b3c51e authored by Julian Rudolf's avatar Julian Rudolf
Browse files

finished basic qLearning implementation

parent 5e11cfcc
from snake_logic import step, reset, render, init_game
from state import State
import random
class QLearning:
    """Tabular Q-learning agent for the snake game.

    Attributes:
        alpha: learning rate in [0, 1].
        gamma: discount factor in [0, 1].
        epsilon: exploration probability for epsilon-greedy action selection.
        qTable: dict mapping (state, action) -> Q-value; entries are created
            lazily, so an unseen pair is treated as Q = 0.
        rounds / step_sum / reward_sum: per-run bookkeeping counters.
        num_states: highest state id seen (only used by print_table).
    """

    def __init__(self, alpha_, g_, epsilon_):
        self.alpha = alpha_
        self.gamma = g_
        self.qTable = {}
        self.rounds = 0
        self.step_sum = 0
        self.reward_sum = 0
        self.epsilon = epsilon_
        self.num_states = 0

    def update(self, state, next_state, action, reward):
        """Apply the Q-learning update rule to the (state, action) entry.

        Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a'))
        """
        old_val = self.qTable[(state, action)]
        self.qTable[(state, action)] = (1 - self.alpha) * old_val + \
            self.alpha * reward + \
            self.alpha * self.gamma * self.max_val(next_state)

    def max_val(self, state):
        """Return the maximum Q-value over the possible actions in *state*.

        Actions not yet present in the table count as 0; an empty action
        list yields 0 so a dead-end state cannot crash the update.
        """
        actions = possible_directions(state)
        poss = []
        for action in actions:
            if (state, action) in self.qTable:
                poss.append(self.qTable[(state, action)])
            else:
                poss.append(0)
        return max(poss) if poss else 0

    def best_direction(self, state, poss_actions):
        """Return the highest-valued direction, inserting Q = 0 for any
        (state, direction) pair not yet present in the table."""
        best_val = -100000
        best_dir = "none"
        for direction in poss_actions:
            if (state, direction) in self.qTable:
                val = self.qTable[(state, direction)]
                if val > best_val:
                    best_dir = direction
                    best_val = val
            else:
                # First time we see this pair: initialise it to 0 and let
                # it compete against the values found so far.
                self.qTable[(state, direction)] = 0
                if 0 > best_val:
                    best_dir = direction
                    best_val = 0
        return best_dir

    def choose_direction(self, state):
        """Epsilon-greedy policy: explore with probability epsilon,
        otherwise take the greedy (best known) direction."""
        poss_actions = possible_directions(state)
        rand_action = random.choice(poss_actions)
        best_direction = self.best_direction(state, poss_actions)
        if random.random() > self.epsilon:
            return best_direction
        else:
            return rand_action

    def print_table(self):
        """Dump the Q-table grouped by state id, one state per line.

        NOTE(review): relies on states having an ``id`` attribute, which is
        currently commented out in State.__init__ — confirm before using.
        """
        dirs = ("left", "right", "up", "down")
        print(dirs)
        i = 0
        while True:
            for key, value in self.qTable.items():
                if key[0].id == i:
                    print(key[1], "(", value, ")", " | ", end='')
            print(" ")
            i += 1
            if i > self.num_states:
                break
# returns all possible actions
def possible_actions(state):
# TODO: maybe "possible" should mean more than just "does not kill"
def possible_directions(state):
poss_actions = []
if not state.kill_l:
poss_actions.append("left")
......@@ -42,10 +108,76 @@ def possible_actions(state):
return poss_actions
# Quick manual smoke test: set up a game, take one step, and dump the
# resulting state observation. (The stale extra `step("right")` call left
# over from the previous revision is removed — it advanced the game once
# before the step whose state is actually inspected.)
init_game()
render()
s, r, round_over = step("down")
render()
print("dfl, dfr, dfu, dfd, killl, killr, killu, killd")
print(s.df_l, s.df_r, s.df_u, s.df_d, " ", s.kill_l, s.kill_r, s.kill_u, s.kill_d)
def play_game_learning(qagent, q_0):
    """Run one training episode from initial state *q_0*.

    Resets the agent's per-game counters, then repeatedly chooses a
    direction (epsilon-greedy), steps the game, and applies the Q-update
    until the game reports it is over.
    """
    state = q_0
    qagent.step_sum = 0
    qagent.reward_sum = 0
    game_over = False
    while not game_over:
        chosen = qagent.choose_direction(state)
        next_state, reward, game_over = step(chosen)
        qagent.step_sum += 1
        qagent.reward_sum += reward
        qagent.update(state, next_state, chosen, reward)
        state = next_state
def learning():
    """Train a fresh QLearning agent for a fixed number of games.

    Prints per-game statistics and the final table, then returns the
    learned Q-table (dict mapping (state, action) -> value).
    """
    alpha = 0.1
    gamma = 0.5
    epsilon = 0.3
    max_rounds = 100
    qagent = QLearning(alpha, gamma, epsilon)
    q_0 = init_game()
    print("Starting learning process!")
    for game_idx in range(max_rounds):
        play_game_learning(qagent, q_0)
        print("Round ", game_idx + 1)
        print("Reward for this game: ", qagent.reward_sum)
        print("Agent stepped ", qagent.step_sum, " times!")
        print("-------------------------------------------")
    print(qagent.qTable)
    return qagent.qTable
def choose_best_direction(qtable, state):
    """Greedy policy: return the possible direction with the highest Q-value.

    Directions missing from the table are treated as Q = 0 via dict.get,
    matching QLearning.max_val's treatment of unseen entries — the previous
    direct indexing raised KeyError for any (state, direction) pair never
    stored during the learning phase.
    """
    best_dir = "none"
    best_val = -10000
    for direction in possible_directions(state):
        val = qtable.get((state, direction), 0)
        if val > best_val:
            best_val = val
            best_dir = direction
    return best_dir
def play_game_testing(qtable, q_0):
    """Play one greedy (no exploration, no learning) game using *qtable*.

    Fix: the loop body only chose an action but never stepped the game, so
    game_over could never become True and the loop never terminated. The
    step call mirrors play_game_learning; the reward is ignored here.
    """
    state = q_0
    game_over = False
    while not game_over:
        action = choose_best_direction(qtable, state)
        state, reward, game_over = step(action)
def testing(qtable):
    """Evaluate the learned *qtable* by letting the agent play a fixed
    number of greedy games against the shield snake."""
    max_rounds = 10
    q_0 = init_game()
    for round_idx in range(max_rounds):
        print("Round ", round_idx, " started!")
        play_game_testing(qtable, q_0)
# Entry point: run the full training loop and keep the learned Q-table.
table = learning()
# Evaluation run is currently disabled.
# testing(table)
......@@ -464,20 +464,23 @@ def init_game():
startsnake2["vel"].pop()
# create snakes
snake1 = Snake(startsnake1["pos"], startsnake1["vel"], startsnake1["angle"], 0, act_shield=True, length=snake_length)
snake2 = Snake(startsnake2["pos"], startsnake2["vel"], startsnake2["angle"], 1, act_shield=False, length=snake_length, dir=False)
snake1 = Snake(startsnake1["pos"], startsnake1["vel"], startsnake1["angle"], 0, act_shield=False, length=snake_length)
snake2 = Snake(startsnake2["pos"], startsnake2["vel"], startsnake2["angle"], 1, act_shield=True, length=snake_length, dir=False)
snake1.set_enemy_snake(snake2.shield_snake)
snake2.set_enemy_snake(snake1.shield_snake)
snake1.set_enemy_norm_snake(snake2)
snake2.set_enemy_norm_snake(snake1)
# snake needs to be on crossing for shield
# but step function thinks it needs action on first crossing
step("init")
state, _, _, = step("init")
return state
# step function
# inputs action and steps snake to next crossing
# returns state, reward and if game is over
# pro step -1 reward, win +100, loose -100, apple +10
# print: win/loose, reward, steps
def step(action):
global snake1, snake2
game_exit = False
......@@ -491,12 +494,18 @@ def step(action):
if playerwin != 0:
if playerwin == 1 or (playerwin == 3 and snake1.getscore() > snake2.getscore()):
overall_score[0] += 1
print("-------------------------------------------")
print("Snake 1 wins")
reward += 10
print(overall_score)
print("Apples eaten: ", snake1.getscore())
reward += 100
if playerwin == 2 or (playerwin == 3 and snake2.getscore() > snake1.getscore()):
overall_score[1] += 1
print("-------------------------------------------")
print("Snake 2 wins")
reward -= 10
print(overall_score)
print("Apples eaten: ", snake1.getscore())
reward -= 100
if playerwin == 3 and snake1.getscore() == snake2.getscore():
print("Tie")
reward = 0
......@@ -527,24 +536,27 @@ def step(action):
elif action == "init":
init = True
else:
print(action)
assert False, 'action not known'
action_done = True
# snake 2 acts random
if check_if_crossing(snake2.pos[0] / block_size, snake2.pos[1] / block_size):
random_choice = random.randint(1,4)
if random_choice == 1:
snake2.key_event("left")
# print("Random snake chose left(" + str(random_choice) + ")")
elif random_choice == 2:
snake2.key_event("right")
# print("Random snake chose right(" + str(random_choice) + ")")
elif random_choice == 3:
snake2.key_event("down")
# print("Random snake chose down(" + str(random_choice) + ")")
elif random_choice == 4:
snake2.key_event("up")
# print("Random snake chose up(" + str(random_choice) + ")")
poss_actions = ("left", "right", "up", "down")
snake2.key_event(random.choice(poss_actions))
# random_choice = random.randint(1,4)
# if random_choice == 1:
# snake2.key_event("left")
# # print("Random snake chose left(" + str(random_choice) + ")")
# elif random_choice == 2:
# snake2.key_event("right")
# # print("Random snake chose right(" + str(random_choice) + ")")
# elif random_choice == 3:
# snake2.key_event("down")
# # print("Random snake chose down(" + str(random_choice) + ")")
# elif random_choice == 4:
# snake2.key_event("up")
# # print("Random snake chose up(" + str(random_choice) + ")")
# determine if a crash happened
crash1 = snake1.update()
......@@ -554,7 +566,7 @@ def step(action):
playerwin = snake_crash([snake1, snake2]) if playerwin == 0 else playerwin
if snake1.eat():
reward += 1
reward += 10
snake2.eat()
if snake1.getscore() == apple_win_count:
......@@ -564,6 +576,7 @@ def step(action):
if check_if_crossing(snake1.pos[0] / block_size, snake1.pos[1] / block_size):
# print("Action needed!")
reward -= 1
agent_apples = []
for a, id in apples:
for apple in a:
......@@ -588,6 +601,15 @@ def render():
elif pos == [int(snake2.pos[1]//path), int(snake2.pos[0]//path)]:
print("+S+", end='')
isMap = False
else:
for tail in snake1.poslist:
if pos == [int(tail[1]//path), int(tail[0]//path)]:
print("-T-", end='')
isMap = False
for tail in snake2.poslist:
if pos == [int(tail[1]//path), int(tail[0]//path)]:
print("+T+", end='')
isMap = False
if not isMap:
continue
......@@ -620,17 +642,15 @@ def reset():
global snake1, snake2
global gameExit
global playerwin
rounds += 1
print("Round: " + str(rounds))
for a, id in apples:
a.clear()
while apple_count > len(a):
apple = gen_rand_apple(id)
a.add(apple)
snake1 = Snake(startsnake1["pos"], startsnake1["vel"], startsnake1["angle"], 0, act_shield=True,
snake1 = Snake(startsnake1["pos"], startsnake1["vel"], startsnake1["angle"], 0, act_shield=False,
length=snake_length)
snake2 = Snake(startsnake2["pos"], startsnake2["vel"], startsnake2["angle"], 1, act_shield=False,
snake2 = Snake(startsnake2["pos"], startsnake2["vel"], startsnake2["angle"], 1, act_shield=True,
length=snake_length, dir=False)
snake1.set_enemy_snake(snake2.shield_snake)
snake2.set_enemy_snake(snake1.shield_snake)
......
......@@ -16,6 +16,7 @@ class State:
# describes in which direction a crash would be inevitable
# binary: 0 no crash || 1 crash
def __init__(self, df_l_, df_r_, df_u_, df_d_, kill_l_, kill_r_, kill_u_, kill_d_):
# self.id = -1
self.df_l = df_l_
self.df_r = df_r_
self.df_d = df_d_
......@@ -25,6 +26,13 @@ class State:
self.kill_d = kill_d_
self.kill_u = kill_u_
    def __hash__(self):
        # Hash over all eight observation fields (apple-distance and kill
        # flags in each direction) so State can key the agent's Q-table;
        # must stay consistent with __eq__, which compares the same tuple.
        return hash((self.df_l, self.df_r, self.df_d, self.df_u, self.kill_l, self.kill_r, self.kill_d, self.kill_u))
    def __eq__(self, other):
        # Two states are equal iff every observation field matches; uses the
        # same field tuple as __hash__ so equal states hash identically.
        # NOTE(review): assumes *other* is a State — no isinstance guard,
        # comparing against a different type raises AttributeError.
        return (self.df_l, self.df_r, self.df_d, self.df_u, self.kill_l, self.kill_r, self.kill_d, self.kill_u) == \
            (other.df_l, other.df_r, other.df_d, other.df_u, other.kill_l, other.kill_r, other.kill_d, other.kill_u)
# checks which directions are possible
def get_surround(x, y):
......
No preview for this file type
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment