
Commit 9add70f

Add files via upload
1 parent 7cdfe5d commit 9add70f

File tree

7 files changed: +373 −0 lines changed


Snake_AI.py

+147
@@ -0,0 +1,147 @@
import torch
import random
import numpy as np
from collections import deque
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot
import joblib
import os
from pathlib import Path

MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LR = 0.001


class Agent:

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration rate, decays as more games are played
        self.gamma = 0.9  # discount factor
        self.memory = deque(maxlen=MAX_MEMORY)  # oldest transitions are dropped when full
        self.model = Linear_QNet(11, 512, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        head = game.snake[0]
        point_l = Point(head.x - 20, head.y)
        point_r = Point(head.x + 20, head.y)
        point_u = Point(head.x, head.y - 20)
        point_d = Point(head.x, head.y + 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # danger straight ahead
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # danger to the right
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # danger to the left
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # current move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # food location relative to the head
            game.food.x < game.head.x,
            game.food.x > game.head.x,
            game.food.y < game.head.y,
            game.food.y > game.head.y
        ]

        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)
        else:
            mini_sample = self.memory
        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # epsilon-greedy: random moves early on, model predictions later
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move


def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    ans = int(input("Train a new model or load the saved one? [0 = train, 1 = load]: "))

    while ans not in [0, 1]:
        print("Invalid input")
        ans = int(input("Train a new model or load the saved one? [0 = train, 1 = load]: "))

    if ans == 1:
        agent = joblib.load(os.fspath(Path(__file__).resolve().parent / "best.pkl"))

    game = SnakeGameAI()

    while True:
        if game.save:
            joblib.dump(agent, os.fspath(Path(__file__).resolve().parent / "model.pkl"))
        state_old = agent.get_state(game)
        final_move = agent.get_action(state_old)
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)
        agent.train_short_memory(state_old, final_move, reward, state_new, done)
        agent.remember(state_old, final_move, reward, state_new, done)

        if done:
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                joblib.dump(agent, os.fspath(Path(__file__).resolve().parent / "model.pkl"))
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)


if __name__ == '__main__':
    train()

best.pkl

4.63 MB
Binary file not shown.

game.py

+141
@@ -0,0 +1,141 @@
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np


pygame.init()

font = pygame.font.SysFont('calibri', 20)

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

Point = namedtuple('Point', 'x, y')

WHITE = (255, 255, 255)
RED = (248, 143, 147)
GREEN = (186, 217, 181)
FONT = (66, 12, 20)
BACK = (239, 247, 207)

BLOCK_SIZE = 20
SPEED = 20

class SnakeGameAI:

    def __init__(self, w=640, h=480):
        self.w = w
        self.h = h
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake')
        self.clock = pygame.time.Clock()
        self.reset()
        self.save = False

    def reset(self):
        self.direction = Direction.RIGHT

        self.head = Point(self.w/2, self.h/2)
        self.snake = [self.head,
                      Point(self.head.x-BLOCK_SIZE, self.head.y),
                      Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]

        self.score = 0
        self.food = None
        self._place_food()
        self.loop_count = 0

    def _place_food(self):
        x = random.randint(0, (self.w-BLOCK_SIZE)//BLOCK_SIZE)*BLOCK_SIZE
        y = random.randint(0, (self.h-BLOCK_SIZE)//BLOCK_SIZE)*BLOCK_SIZE
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()

    def play_step(self, action):
        self.loop_count += 1
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()
        keys_pressed = pygame.key.get_pressed()

        # pressing 'u' asks the training loop to save the current agent
        if keys_pressed[pygame.K_u]:
            print("saving")
            self.save = True
        self._move(action)
        self.snake.insert(0, self.head)
        reward = 0
        game_over = False
        # end the episode on collision or when the snake idles too long
        if self.is_collision() or self.loop_count > 100*len(self.snake):
            reward = -10
            game_over = True
            return reward, game_over, self.score

        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            self.snake.pop()

        self._update_ui()
        self.clock.tick(SPEED)
        return reward, game_over, self.score

    def is_collision(self, pt=None):
        if pt is None:
            pt = self.head

        # hits the boundary
        if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
            return True
        # hits itself
        if pt in self.snake[1:]:
            return True

        return False

    def _update_ui(self):
        self.display.fill(BACK)

        for pt in self.snake:
            pygame.draw.rect(self.display, GREEN, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))

        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

        text = font.render(str(self.score), True, GREEN)
        self.display.blit(text, [320, 5])
        pygame.display.flip()

    def _move(self, action):
        # action is one-hot: [straight, right turn, left turn]
        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx]  # keep going straight
        elif np.array_equal(action, [0, 1, 0]):
            next_idx = (idx + 1) % 4
            new_dir = clock_wise[next_idx]  # turn right (clockwise)
        else:
            next_idx = (idx - 1) % 4
            new_dir = clock_wise[next_idx]  # turn left (counter-clockwise)
        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif self.direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif self.direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif self.direction == Direction.UP:
            y -= BLOCK_SIZE

        self.head = Point(x, y)

helper.py

+20
@@ -0,0 +1,20 @@
import matplotlib.pyplot as plt
from IPython import display

plt.ion()

def plot(scores, mean_scores):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.style.use('ggplot')
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores, linestyle=":")
    plt.ylim(bottom=0)  # the 'ymin' keyword was removed in recent matplotlib versions
    plt.text(len(scores)-1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))
    plt.show(block=False)
    plt.pause(.1)

model.pkl

4.63 MB
Binary file not shown.

model.pth

31.4 KB
Binary file not shown.

model.py

+65
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from pathlib import Path

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def save(self, file_name='model.pth'):
        torch.save(self.state_dict(), os.fspath(Path(__file__).resolve().parent / file_name))

    def load(self):
        # save() stores a state_dict, so load it back into this module
        # (rebinding the local name `self` would have no effect)
        self.load_state_dict(torch.load(os.fspath(Path(__file__).resolve().parent / 'model.pth')))


class QTrainer:
    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)

        # add a batch dimension when a single transition is passed in
        if len(state.shape) == 1:
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            done = (done, )
        pred = self.model(state)

        # Bellman update: Q_new = r + gamma * max(Q(next_state)) for non-terminal states
        target = pred.clone()
        for idx in range(len(done)):
            Q_new = reward[idx]
            if not done[idx]:
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

            target[idx][torch.argmax(action[idx]).item()] = Q_new

        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()

        self.optimizer.step()
