
Commit c012be5
Parent: 9affcff

added try/except to training evaluation and unsqueeze correctly to model

File tree: 3 files changed, +29 −16 lines


README.md (+6)

@@ -1,2 +1,8 @@
 # python-ner
 A neural NER implementation built first for unstructured text
+
+# Resources:
+https://medium.com/@rohit.sharma_7010/a-complete-tutorial-for-named-entity-recognition-and-extraction-in-natural-language-processing-71322b6fb090
+https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html
+https://pytorch-crf.readthedocs.io/en/stable/
+https://www.aclweb.org/anthology/Y18-1061

model.py (+3 −2)

@@ -42,13 +42,14 @@ def forward(self, sentence):
         :return: a list of words
         """
         self.hidden = self.init_hidden()
-        # we unsqueeze because elmo only takes batches
+        # embeds = [1, sentence_len, 1024]
         embeds = self.elmo([sentence])
         # lstm_out = (1, seq_len, hidden_size * 2)
         lstm_out, hidden_out = self.lstm(embeds, self.hidden)
         predictions = self.linear(lstm_out)
         # squeeze the result to get rid of the batch for (seq_len, 2)
-        return predictions.squeeze()
+        predictions = predictions.squeeze(0)
+        return predictions

     def evaluate(self, sentence):
         """

train.py (+20 −14)

@@ -11,7 +11,7 @@
 from entity_recognition_datasets.src import utils
 from model import BiLSTM
 
-EPOCHS = 5
+EPOCHS = 4
 LOSS_FUNC = nn.CrossEntropyLoss()
 
 
@@ -50,7 +50,7 @@ def train_single(sentence, optimizer, backprop):
         optimizer.step()
         loss = loss.item()
     except Exception as e:
-        print("e")
+        print(e)
         print("words: {}".format(words))
         print("tags: {}".format(tags))
         print("continuing...")
@@ -72,32 +72,32 @@ def train():
     dev_losses = []
     for epoch in range(EPOCHS):
         print("EPOCH {}/{}".format(epoch + 1, EPOCHS))
-        start = time.time()
 
         # run train epoch
+        start = time.time()
         train_loss = 0
         for sentence in tqdm(train_data, desc="train-set"):
             loss = train_single(sentence, optimizer, backprop=True)
             train_loss += loss
         train_loss /= len(train_data)
         train_losses.append(train_loss)
+        duration = time.time() - start
+        print("train set completed in {:.3f}s, {:.3f}s per iteration".format(duration, duration / len(train_data)))
 
         # run a dev epoch
+        start = time.time()
         dev_loss = 0
         with torch.no_grad():
             for sentence in tqdm(dev_data, desc="dev-set"):
                 loss = train_single(sentence, optimizer, backprop=False)
-                train_loss += loss
+                dev_loss += loss
         dev_loss /= len(dev_data)
         dev_losses.append(dev_loss)
+        duration = time.time() - start
+        print("dev set completed in {:.3f}s, {:.3f}s per iteration".format(duration, duration / len(dev_data)))
 
         print("train loss = {}".format(train_loss))
         print("dev loss = {}".format(dev_loss))
-        duration = time.time() - start
-        print("epoch completed in {:.3f}s, {:.3f}s per iteration".format(
-            duration,
-            duration / (len(train_data) + len(dev_data))
-        ))
 
     losses = {
         "train": train_losses,
@@ -120,11 +120,17 @@ def evaluate():
     with torch.no_grad():
         confusion = np.zeros((2, 2))
         for sentence in tqdm(test_data, desc="train-set"):
-            words, tags = get_words_and_tags(sentence)
-            pred = model.evaluate(words)
-            assert len(pred) == len(tags)
-            for i in range(len(pred)):
-                confusion[pred[i]][tags[i]] += 1
+            try:
+                words, tags = get_words_and_tags(sentence)
+                pred = model.evaluate(words)
+                assert len(pred) == len(tags)
+                for i in range(len(pred)):
+                    confusion[pred[i]][tags[i]] += 1
+            except Exception as e:
+                print(e)
+                print("words: {}".format(words))
+                print("tags: {}".format(tags))
+                print("continuing...")
 
     confusion /= np.sum(confusion)
     return confusion
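
The try/except mirrors the guard already used in train_single, so a single malformed sentence no longer aborts the whole evaluation run. For reference, this is how the normalized 2x2 confusion matrix in evaluate() accumulates; the predictions and gold tags below are made-up values for illustration:

import numpy as np

confusion = np.zeros((2, 2))  # rows: predicted tag, cols: gold tag
preds = [0, 1, 1, 0]          # hypothetical model outputs
tags = [0, 1, 0, 0]           # hypothetical gold labels
for p, t in zip(preds, tags):
    confusion[p][t] += 1

confusion /= np.sum(confusion)  # each cell becomes a fraction of all tokens
print(confusion)                # [[0.5, 0.0], [0.25, 0.25]]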
