
Migrates nll_loss_backward from TH to ATen (CUDA) #60299

Closed

thomasjpfan (Contributor)

Fixes #24609
ATen umbrella issue: #24507
Related to #59765

Running the following benchmark shows no performance difference between master and this PR:

Benchmark script:

```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)


def _time():
    # Wait for all queued CUDA kernels to finish so the timestamp
    # reflects completed GPU work, not just kernel launches.
    torch.cuda.synchronize()
    MS_PER_SECOND = 1000
    return time.perf_counter() * MS_PER_SECOND


device = "cuda"
C = 30
softmax = nn.LogSoftmax(dim=1)
n_runs = 250

for reduction in ["none", "mean", "sum"]:
    for N in [100_000, 500_000, 1_000_000]:
        elapsed = 0
        for i in range(n_runs):
            data = torch.randn(N, C, device=device, requires_grad=True)
            target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
            loss = nn.NLLLoss(reduction=reduction)
            input = softmax(data)
            result = loss(input, target)

            # The upstream gradient must match the loss shape:
            # (N,) for reduction="none", a scalar otherwise.
            if reduction == "none":
                gradient = torch.randn(N, device=device)
            else:
                gradient = torch.randn(1, device=device).squeeze()

            t1 = _time()
            result.backward(gradient)
            t2 = _time()
            elapsed = elapsed + (t2 - t1)
        elapsed_avg = elapsed / n_runs
        print(
            f"input size({N}, {C}), reduction: {reduction} "
            f"elapsed time is {elapsed_avg:.2f} (ms)"
        )
    print()
```
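For context on what the migrated kernel computes: for unweighted targets, `nll_loss_backward` scatters the (negated) upstream gradient into each row's target class, scaled by `1/N` for the `"mean"` reduction. A minimal pure-Python sketch of that behavior (an illustration only, not the PR's actual CUDA kernel; the function name and signature here are hypothetical):

```python
def nll_loss_backward_sketch(grad_output, input_shape, target, reduction="mean"):
    """Sketch of the NLL-loss input gradient for unweighted targets.

    grad_output: a list of per-sample gradients when reduction="none",
                 otherwise a single scalar.
    """
    N, C = input_shape
    grad_input = [[0.0] * C for _ in range(N)]
    for i, t in enumerate(target):
        if reduction == "none":
            g = grad_output[i]   # per-sample upstream gradient
        elif reduction == "mean":
            g = grad_output / N  # the loss was averaged over N samples
        else:  # "sum"
            g = grad_output
        grad_input[i][t] = -g    # only the target class receives gradient
    return grad_input
```

For example, with two samples, three classes, and `reduction="sum"`, `nll_loss_backward_sketch(1.0, (2, 3), [0, 2], "sum")` places `-1.0` at column 0 of row 0 and column 2 of row 1, with zeros elsewhere; every sample touches exactly one entry, which is why the `"none"` case above scales linearly with N in the benchmark.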

master

input size(100000, 30), reduction: none elapsed time is 0.19 (ms)
input size(500000, 30), reduction: none elapsed time is 0.83 (ms)
input size(1000000, 30), reduction: none elapsed time is 1.66 (ms)

input size(100000, 30), reduction: mean elapsed time is 1.50 (ms)
input size(500000, 30), reduction: mean elapsed time is 7.19 (ms)
input size(1000000, 30), reduction: mean elapsed time is 14.35 (ms)

input size(100000, 30), reduction: sum elapsed time is 1.49 (ms)
input size(500000, 30), reduction: sum elapsed time is 7.17 (ms)
input size(1000000, 30), reduction: sum elapsed time is 14.21 (ms)

this PR

input size(100000, 30), reduction: none elapsed time is 0.19 (ms)
input size(500000, 30), reduction: none elapsed time is 0.83 (ms)
input size(1000000, 30), reduction: none elapsed time is 1.66 (ms)

input size(100000, 30), reduction: mean elapsed time is 1.48 (ms)
input size(500000, 30), reduction: mean elapsed time is 7.16 (ms)
input size(1000000, 30), reduction: mean elapsed time is 14.29 (ms)

input size(100000, 30), reduction: sum elapsed time is 1.49 (ms)
input size(500000, 30), reduction: sum elapsed time is 7.15 (ms)
input size(1000000, 30), reduction: sum elapsed time is 14.18 (ms)

@thomasjpfan thomasjpfan added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 18, 2021
@thomasjpfan thomasjpfan requested a review from ezyang as a code owner June 18, 2021 20:45
facebook-github-bot (Contributor) commented Jun 18, 2021

💊 CI failures summary and remediations

As of commit 3eccb41 (more details on the Dr. CI page and at hud.pytorch.org/pr/60299):


  • 4/4 failures possibly* introduced in this PR
    • 2/4 non-scanned failure(s)

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_bionic_rocm3_9_py3_6_build (1/2)

Step: "Spin up environment" (full log | diagnosis details | 🔁 rerun)

Waiting for a VM assignment: .......................................................................
Build-agent version 1.0.74137-e7d5cf4b (2021-06-21T13:20:20+0000)
Creating a dedicated VM with ubuntu-2004:202104-01 image
Waiting for a VM assignment: ............................................................................................................................................................................................................................................................................................................

We timed out preparing a VM for this build, potentially due to our infrastructure or cloud provider.  Please retry the build in a few minutes

Unexpected capacity error: error caused by capacity

See CircleCI build pytorch_macos_10_13_py3_test (2/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Jun 21 22:55:44 ======================================================================
Jun 21 22:55:44 ERROR [0.004s]: test_poisson_sample (__main__.TestDistributions)
Jun 21 22:55:44 ----------------------------------------------------------------------
Jun 21 22:55:44 Traceback (most recent call last):
Jun 21 22:55:44   File "distributions/test_distributions.py", line 1333, in test_poisson_sample
Jun 21 22:55:44     failure_rate=1e-3)
Jun 21 22:55:44   File "distributions/test_distributions.py", line 805, in _check_sampler_discrete
Jun 21 22:55:44     chisq, p = scipy.stats.chisquare(counts[msk], pmf[msk] * num_samples)
Jun 21 22:55:44   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/scipy/stats/stats.py", line 6853, in chisquare
Jun 21 22:55:44     lambda_="pearson")
Jun 21 22:55:44   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/scipy/stats/stats.py", line 6694, in power_divergence
Jun 21 22:55:44     raise ValueError(msg)
Jun 21 22:55:44 ValueError: For each axis slice, the sum of the observed frequencies must agree with the sum of the expected frequencies to a relative tolerance of 1e-08, but the percent differences are:
Jun 21 22:55:44 0.008265582255680495

ci.pytorch.org: 1 failed


@thomasjpfan thomasjpfan marked this pull request as ready for review June 21, 2021 03:19
@thomasjpfan thomasjpfan requested a review from ngimel June 21, 2021 15:34
@ezyang ezyang removed their request for review June 21, 2021 22:16
@ngimel (Collaborator) left a comment

This looks good, thank you!

@facebook-github-bot (Contributor)
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)
@ngimel merged this pull request in 99ca2c5.

facebook-github-bot pushed a commit that referenced this pull request Jun 23, 2021
Summary:
Addresses a part of #59765

This PR adds byte support for nll_loss on the CPU for `input.dim() == 2`.

CUDA support will be implemented when `nll_loss` migration to CUDA is completed in #60299 and #60097

Pull Request resolved: #60308

Reviewed By: VitalyFedyunin

Differential Revision: D29329458

Pulled By: jbschlosser

fbshipit-source-id: d3585c4966030bc61e451f8aa817406a8a3acf47
Successfully merging this pull request may close this issue:

Migrate nll_loss_backward from TH to ATen (CUDA)