
Training on n epochs results in 1 more step in n+1 epoch #1793

Closed
Oseltamivir opened this issue Oct 19, 2024 · 4 comments · Fixed by #1794
Labels: bug (Something isn't working)

Comments

@Oseltamivir

Bug description

Training on n epochs results in 1 more step in n+1 epoch

```python
while state["step_count"] < max_steps and train_iterator.epoch < train.epochs:
    state["iter_num"] += 1
    iter_t0 = time.perf_counter()
    batch = next(train_iterator)
```

The problem is that `next(train_iterator)` is called after the `train_iterator.epoch < train.epochs` check, so the iterator can roll over into epoch n+1 and one extra step executes before the condition is re-evaluated.

```
Epoch 1 | iter 1 step 1 | loss train: 2.112, val: n/a | iter time: 3882.13 ms (step)
Epoch 1 | iter 1 step 1 | loss train: 2.167, val: n/a | iter time: 3887.25 ms (step)
Epoch 1 | iter 2 step 2 | loss train: 1.257, val: n/a | iter time: 11192.95 ms (step)
Epoch 1 | iter 2 step 2 | loss train: 1.258, val: n/a | iter time: 11224.30 ms (step)
Epoch 2 | iter 3 step 3 | loss train: 2.108, val: n/a | iter time: 3683.60 ms (step)
Epoch 2 | iter 3 step 3 | loss train: 2.165, val: n/a | iter time: 3726.28 ms (step)
```

The log above is from a run configured with `epochs = 1`, yet an extra step executes under Epoch 2.
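The off-by-one can be reproduced with a minimal sketch. The `CycleIterator` below is a hypothetical stand-in for litgpt's training iterator, assumed to cycle over the data indefinitely and bump a zero-based `epoch` counter each time it wraps:

```python
class CycleIterator:
    """Hypothetical stand-in for litgpt's train iterator: cycles over the
    data forever and increments a zero-based `epoch` counter on each wrap."""
    def __init__(self, data):
        self.data = data
        self.epoch = 0
        self._it = iter(data)

    def __next__(self):
        try:
            return next(self._it)
        except StopIteration:
            self.epoch += 1
            self._it = iter(self.data)
            return next(self._it)


def train_steps(epochs, max_steps=100):
    it = CycleIterator([0, 1])  # 2 batches per epoch
    step_count = 0
    while step_count < max_steps and it.epoch < epochs:
        # `next` is called AFTER the epoch check, so the batch that wraps
        # into epoch n+1 still runs as a full training step.
        batch = next(it)
        step_count += 1
    return step_count


print(train_steps(epochs=1))  # prints 3, although 1 epoch over 2 batches should be 2 steps
```

With this loop shape, training for n epochs over 2 batches always executes 2n + 1 steps instead of 2n.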

What operating system are you using?

Linux

LitGPT Version

Version: 0.4.12
Oseltamivir added the bug label on Oct 19, 2024
@rasbt
Contributor

rasbt commented Oct 19, 2024

Thanks for reporting! And yes, I can confirm that this has been an issue. I have it on my list of things to address in the next few weeks.

rasbt self-assigned this on Oct 19, 2024
@Oseltamivir
Author

Oseltamivir commented Oct 19, 2024

> Thanks for reporting! And yes, I can confirm that this has been an issue. I have it on my list of things to address in the next few weeks.

May I suggest instead:

```python
while state["step_count"] < max_steps:
    state["iter_num"] += 1
    iter_t0 = time.perf_counter()
    batch = next(train_iterator)
    if train_iterator.epoch >= train.epochs:
        break
```

I don't think this is worth a PR, but if you'd like one, I can create it.
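Applied to the same minimal sketch (again with a hypothetical `CycleIterator` standing in for litgpt's train iterator), the suggested reordering removes the extra step, because the batch that rolls into epoch n+1 is fetched but discarded:

```python
class CycleIterator:
    """Hypothetical stand-in for litgpt's train iterator: cycles over the
    data forever and increments a zero-based `epoch` counter on each wrap."""
    def __init__(self, data):
        self.data = data
        self.epoch = 0
        self._it = iter(data)

    def __next__(self):
        try:
            return next(self._it)
        except StopIteration:
            self.epoch += 1
            self._it = iter(self.data)
            return next(self._it)


def train_steps_fixed(epochs, max_steps=100):
    it = CycleIterator([0, 1])  # 2 batches per epoch
    step_count = 0
    while step_count < max_steps:
        batch = next(it)
        # Check AFTER fetching: if the fetch wrapped into epoch n+1,
        # bail out before counting (or training on) this batch.
        if it.epoch >= epochs:
            break
        step_count += 1
    return step_count


print(train_steps_fixed(epochs=1))  # prints 2: exactly one pass over the data
```

Training for n epochs over 2 batches now executes exactly 2n steps.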

@rasbt
Contributor

rasbt commented Oct 19, 2024

Yes, this is the fix I would suggest as well. It needs to be added in several places, and I just want to test this carefully (it probably also requires slight adjustments to some CI tests, because the loss values will change). It's on my list!

@Oseltamivir
Author

Alright, perfect. Thanks for the quick replies. I'll include this in my patch for this code.
