Training for n epochs runs one extra step in epoch n+1
```python
while state["step_count"] < max_steps and train_iterator.epoch < train.epochs:
    state["iter_num"] += 1
    iter_t0 = time.perf_counter()
    batch = next(train_iterator)
```
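The problem is that `next(train_iterator)` is called after the `train_iterator.epoch < train.epochs` check. When the iterator wraps around, the first batch of epoch n+1 is still fetched and trained on before the loop condition is evaluated again.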
```
Epoch 1 | iter 1 step 1 | loss train: 2.112, val: n/a | iter time: 3882.13 ms (step)
Epoch 1 | iter 1 step 1 | loss train: 2.167, val: n/a | iter time: 3887.25 ms (step)
Epoch 1 | iter 2 step 2 | loss train: 1.257, val: n/a | iter time: 11192.95 ms (step)
Epoch 1 | iter 2 step 2 | loss train: 1.258, val: n/a | iter time: 11224.30 ms (step)
Epoch 2 | iter 3 step 3 | loss train: 2.108, val: n/a | iter time: 3683.60 ms (step)
Epoch 2 | iter 3 step 3 | loss train: 2.165, val: n/a | iter time: 3726.28 ms (step)
```
This output is for `epochs = 1`: training should stop after step 2, but step 3 is run in epoch 2.
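To make the off-by-one concrete, here is a small self-contained simulation. `CycleIterator` below is a simplified stand-in, assumed to increment `epoch` inside `next()` when the underlying data wraps around (as the loop's use of `train_iterator.epoch` suggests). `run_original`, `run_fixed`, and the 1-based epoch labels are hypothetical names for illustration, and gradient accumulation is ignored (one batch = one step). Re-checking the epoch right after `next()` is one possible fix, not necessarily the one the maintainers adopted.

```python
class CycleIterator:
    """Simplified stand-in (assumption): cycles over `iterable` forever and
    bumps `epoch` each time the data is exhausted and restarted."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.epoch = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # Data exhausted: restart and count one finished epoch.
            self._iterator = iter(self.iterable)
            self.epoch += 1
            return next(self._iterator)


def run_original(batches, epochs, max_steps=100):
    """Mirrors the reported ordering: check the epoch, then fetch a batch."""
    train_iterator = CycleIterator(batches)
    steps = []
    while len(steps) < max_steps and train_iterator.epoch < epochs:
        batch = next(train_iterator)
        steps.append((train_iterator.epoch + 1, batch))  # (1-based epoch, batch)
    return steps


def run_fixed(batches, epochs, max_steps=100):
    """Hypothetical fix: re-check the epoch after fetching, so a batch that
    already belongs to epoch n+1 is never trained on."""
    train_iterator = CycleIterator(batches)
    steps = []
    while len(steps) < max_steps and train_iterator.epoch < epochs:
        batch = next(train_iterator)
        if train_iterator.epoch >= epochs:
            break  # the iterator just wrapped into epoch n+1
        steps.append((train_iterator.epoch + 1, batch))
    return steps


print(run_original(["b0", "b1"], epochs=1))  # [(1, 'b0'), (1, 'b1'), (2, 'b0')]
print(run_fixed(["b0", "b1"], epochs=1))     # [(1, 'b0'), (1, 'b1')]
```

With two batches and `epochs=1`, the original ordering takes a third step labeled epoch 2, the same pattern as in the logs above, while the re-check stops after exactly one pass over the data.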
What operating system are you using?
Linux
LitGPT Version
Version: 0.4.12
Yes, this is the fix I would suggest as well. It needs to be added in several places, and I want to test it carefully (it probably also requires slight adjustments to some CI tests, because the loss values will change). It's on my list!