fix reshape in smooth lddt loss #205
Open
+5
−10
I believe the reshape on line 161 in `smooth_lddt_loss` is incorrect and will misalign `mask` and `eps` if both batch and multiplicity are larger than 1. The reason is that if some `xs` has shape `(batch, d1, d2, ...)`, then `repeat_interleave` has the following property: `xs.repeat_interleave(multiplicity, dim=0).view(batch, multiplicity, d1, d2, ...)[i, j] == xs[i]` for every `j`. That is, since `repeat_interleave` repeats every element of the batch a given number of times, not the whole batch, the reshape should put multiplicity on dim=1. However, on line 161 the current code puts multiplicity on dim=0 and averages over it, which mixes elements from different batch items.
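A minimal sketch of that property (the names `xs`, `ys`, `batch`, `multiplicity`, and `n` are placeholders, not variables from the PR):

```python
import torch

batch, multiplicity, n = 2, 3, 4
xs = torch.randn(batch, n, n)

# repeat_interleave places the copies of each batch element next to each
# other: the rows come out as [xs[0], xs[0], xs[0], xs[1], xs[1], xs[1]].
ys = xs.repeat_interleave(multiplicity, dim=0)  # (batch * multiplicity, n, n)

# So the inverse view puts multiplicity on dim=1 ...
assert torch.equal(ys.view(batch, multiplicity, n, n)[:, 0], xs)

# ... while viewing multiplicity onto dim=0 (as on line 161) groups
# together elements from different batch items before the mean:
wrong = ys.view(multiplicity, batch, n, n).mean(dim=0)
right = ys.view(batch, multiplicity, n, n).mean(dim=1)
assert torch.allclose(right, xs)
assert not torch.allclose(wrong, xs)
```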
Additionally, since the resulting averaged `eps` (shape `(batch, d1, d2, ...)`) doesn't match `mask` (which has shape `(batch * multiplicity, d1, d2, ...)`), `eps` is interleaved again on line 166. If lines 161-162 were changed to `.view(B // multiplicity, multiplicity, N, N).mean(dim=1)`, the `repeat_interleave` on line 166 would simply undo this reshape and average. Thus, I suggest simply removing these lines.
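To illustrate why the round-trip is a no-op for the loss, here is a minimal sketch. It assumes `mask` is identical within each multiplicity group (consistent with it having shape `(batch * multiplicity, ...)` via a `repeat_interleave` of a per-batch mask); all names and shapes are illustrative rather than taken from the PR:

```python
import torch

batch, multiplicity, n = 2, 3, 5
B = batch * multiplicity

# eps differs across multiplicity copies (each comes from a different
# predicted sample); mask is assumed identical within each group, e.g.
# because it is repeat-interleaved from a per-batch ground-truth mask.
eps = torch.rand(B, n, n)
mask = (torch.rand(batch, n, n) > 0.5).float().repeat_interleave(multiplicity, 0)

# Corrected lines 161-162 followed by the line-166 interleave:
averaged = eps.view(B // multiplicity, multiplicity, n, n).mean(dim=1)
restored = averaged.repeat_interleave(multiplicity, 0)

# The masked reduction the loss is built from is unchanged, so the
# average/re-interleave round-trip contributes nothing.
assert torch.allclose((restored * mask).sum(), (eps * mask).sum())
```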