From f120d522798d73db6e999d9e5d465244e3c1b48a Mon Sep 17 00:00:00 2001 From: Joel Verhagen Date: Thu, 19 Apr 2018 12:44:06 -0700 Subject: [PATCH] Fix collection modified exception which causes rare HTTP 500s (#5830) Add unit tests to verify casing changes and repro of bug Use ConcurrentDictionary instead of Dictionary Address https://github.com/NuGet/NuGetGallery/issues/5779 --- .../Services/CloudDownloadCountService.cs | 140 +++++++------ .../NuGetGallery.Facts.csproj | 1 + .../CloudDownloadCountServiceFacts.cs | 191 ++++++++++++++++++ 3 files changed, 270 insertions(+), 62 deletions(-) create mode 100644 tests/NuGetGallery.Facts/Services/CloudDownloadCountServiceFacts.cs diff --git a/src/NuGetGallery/Services/CloudDownloadCountService.cs b/src/NuGetGallery/Services/CloudDownloadCountService.cs index 8d019752d0..cbe2d423bb 100644 --- a/src/NuGetGallery/Services/CloudDownloadCountService.cs +++ b/src/NuGetGallery/Services/CloudDownloadCountService.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -27,9 +28,8 @@ public class CloudDownloadCountService : IDownloadCountService private readonly object _refreshLock = new object(); private bool _isRefreshing; - private readonly IDictionary> _downloadCounts = new Dictionary>(StringComparer.OrdinalIgnoreCase); - - public DateTime LastRefresh { get; protected set; } + private readonly ConcurrentDictionary> _downloadCounts + = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase); public CloudDownloadCountService(ITelemetryClient telemetryClient, string connectionString, bool readAccessGeoRedundant) { @@ -38,7 +38,7 @@ public CloudDownloadCountService(ITelemetryClient telemetryClient, string connec _connectionString = connectionString; _readAccessGeoRedundant = readAccessGeoRedundant; } - + public bool TryGetDownloadCountForPackageRegistration(string id, out int downloadCount) { if (string.IsNullOrEmpty(id)) @@ -46,11 +46,9 @@ public bool TryGetDownloadCountForPackageRegistration(string id, out int downloa throw new ArgumentNullException(nameof(id)); } - id = id.ToLowerInvariant(); - - if (_downloadCounts.ContainsKey(id)) + if (_downloadCounts.TryGetValue(id, out var versions)) { - downloadCount = _downloadCounts[id].Sum(kvp => kvp.Value); + downloadCount = CalculateSum(versions); return true; } @@ -70,16 +68,10 @@ public bool TryGetDownloadCountForPackage(string id, string version, out int dow throw new ArgumentNullException(nameof(version)); } - id = id.ToLowerInvariant(); - version = version.ToLowerInvariant(); - - if (_downloadCounts.ContainsKey(id)) + if (_downloadCounts.TryGetValue(id, out var versions) + && versions.TryGetValue(version, out downloadCount)) { - if (_downloadCounts[id].ContainsKey(version)) - { - downloadCount = _downloadCounts[id][version]; - return true; - } + return true; } downloadCount = 0; @@ -127,78 +119,102 @@ public void Refresh() finally { _isRefreshing = false; - LastRefresh = DateTime.UtcNow; } } } + /// + /// This method is added for unit testing purposes. + /// + protected virtual int CalculateSum(ConcurrentDictionary versions) + { + return versions.Sum(kvp => kvp.Value); + } + + /// + /// This method is added for unit testing purposes. It can return a null stream if the blob does not exist + /// and assumes the caller will properly dispose of the returned stream. + /// + protected virtual Stream GetBlobStream() + { + var blob = GetBlobReference(); + if (blob == null) + { + return null; + } + + return blob.OpenRead(); + } + private void RefreshCore() { try { - var blob = GetBlobReference(); - if (blob == null) - { - return; - } - // The data in downloads.v1.json will be an array of Package records - which has Id, Array of Versions and download count. // Sample.json : [["AutofacContrib.NSubstitute",["2.4.3.700",406],["2.5.0",137]],["Assman.Core",["2.0.7",138]].... - using (var jsonReader = new JsonTextReader(new StreamReader(blob.OpenRead()))) + using (var blobStream = GetBlobStream()) { - try + if (blobStream == null) { - jsonReader.Read(); + return; + } - while (jsonReader.Read()) + using (var jsonReader = new JsonTextReader(new StreamReader(blobStream))) + { + try { - try + jsonReader.Read(); + + while (jsonReader.Read()) { - if (jsonReader.TokenType == JsonToken.StartArray) + try { - JToken record = JToken.ReadFrom(jsonReader); - string id = record[0].ToString().ToLowerInvariant(); - - // The second entry in each record should be an array of versions, if not move on to next entry. - // This is a check to safe guard against invalid entries. - if (record.Count() == 2 && record[1].Type != JTokenType.Array) + if (jsonReader.TokenType == JsonToken.StartArray) { - continue; - } + JToken record = JToken.ReadFrom(jsonReader); + string id = record[0].ToString().ToLowerInvariant(); - if (!_downloadCounts.ContainsKey(id)) - { - _downloadCounts.Add(id, new Dictionary(StringComparer.OrdinalIgnoreCase)); - } - var versions = _downloadCounts[id]; + // The second entry in each record should be an array of versions, if not move on to next entry. + // This is a check to safe guard against invalid entries. + if (record.Count() == 2 && record[1].Type != JTokenType.Array) + { + continue; + } - foreach (JToken token in record) - { - if (token != null && token.Count() == 2) + var versions = _downloadCounts.GetOrAdd( + id, + _ => new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase)); + + foreach (JToken token in record) { - string version = token[0].ToString().ToLowerInvariant(); - versions[version] = token[1].ToObject(); + if (token != null && token.Count() == 2) + { + var version = token[0].ToString(); + var downloadCount = token[1].ToObject(); + + versions.AddOrSet(version, downloadCount); + } } } } - } - catch (JsonReaderException ex) - { - _telemetryClient.TrackException(ex, new Dictionary + catch (JsonReaderException ex) { - { "Origin", TelemetryOriginForRefreshMethod }, - { "AdditionalInfo", "Invalid entry found in downloads.v1.json." } - }); + _telemetryClient.TrackException(ex, new Dictionary + { + { "Origin", TelemetryOriginForRefreshMethod }, + { "AdditionalInfo", "Invalid entry found in downloads.v1.json." } + }); + } } } - } - catch (JsonReaderException ex) - { - _telemetryClient.TrackException(ex, new Dictionary + catch (JsonReaderException ex) { - { "Origin", TelemetryOriginForRefreshMethod }, - { "AdditionalInfo", "Data present in downloads.v1.json is invalid. Couldn't get download data." } - }); + _telemetryClient.TrackException(ex, new Dictionary + { + { "Origin", TelemetryOriginForRefreshMethod }, + { "AdditionalInfo", "Data present in downloads.v1.json is invalid. Couldn't get download data." } + }); + } } } } diff --git a/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj b/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj index ca62132d1c..0a778c1c09 100644 --- a/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj +++ b/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj @@ -445,6 +445,7 @@ + diff --git a/tests/NuGetGallery.Facts/Services/CloudDownloadCountServiceFacts.cs b/tests/NuGetGallery.Facts/Services/CloudDownloadCountServiceFacts.cs new file mode 100644 index 0000000000..864eda150b --- /dev/null +++ b/tests/NuGetGallery.Facts/Services/CloudDownloadCountServiceFacts.cs @@ -0,0 +1,191 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using Xunit; + +namespace NuGetGallery +{ + public class CloudDownloadCountServiceFacts + { + public class TheTryGetDownloadCountForPackageRegistrationMethod : BaseFacts + { + [Theory] + [InlineData("NuGet.Versioning", "NuGet.Frameworks")] + [InlineData("NuGet.Versioning", " NuGet.Versioning ")] + [InlineData("NuGet.Versioning", "NuGet.Versioning ")] + [InlineData("NuGet.Versioning", " NuGet.Versioning")] + [InlineData("İ", "ı")] + [InlineData("İ", "i")] + [InlineData("İ", "I")] + [InlineData("ı", "İ")] + [InlineData("ı", "i")] + [InlineData("ı", "I")] + [InlineData("i", "İ")] + [InlineData("i", "ı")] + [InlineData("I", "İ")] + [InlineData("I", "ı")] + public void ReturnsZeroWhenIdDoesNotExist(string inputId, string contentId) + { + // Arrange + _content = $"[[\"{contentId}\",[\"4.6.0\",23],[\"4.6.2\",42]]"; + _target.Refresh(); + + // Act + var found = _target.TryGetDownloadCountForPackageRegistration(inputId, out var actual); + + // Assert + Assert.Equal(0, actual); + Assert.False(found, "The package ID should not have been found."); + } + + [Theory] + [InlineData("NuGet.Versioning", "NuGet.Versioning")] + [InlineData("NUGET.VERSIONING", "nuget.versioning")] + [InlineData("nuget.versioning", "NUGET.VERSIONING")] + [InlineData("İ", "İ")] + [InlineData("ı", "ı")] + public void ReturnsSumOfVersionsWhenIdExists(string inputId, string contentId) + { + // Arrange + _content = $"[[\"{contentId}\",[\"4.6.0\",23],[\"4.6.2\",42]]"; + _target.Refresh(); + + // Act + var found = _target.TryGetDownloadCountForPackageRegistration(inputId, out var actual); + + // Assert + Assert.Equal(23 + 42, actual); + Assert.True(found, "The package ID should have been found."); + } + + [Fact] + public async Task CanCalculatedDownloadCountsDuringRefresh() + { + // Arrange + var id = "NuGet.Versioning"; + var duration = TimeSpan.FromSeconds(3); + var refreshTask = LoadNewVersionsAsync(id, TimeSpan.FromSeconds(1)); + var getDownloadCountsTask = GetDownloadCountsAsync(id, duration); + _calculateSum = v => v.Sum(kvp => { Thread.Sleep(1); return kvp.Value; }); + + // Act & Assert + // We are just ensuring that no exception is thrown. + await Task.WhenAll(refreshTask, getDownloadCountsTask); + } + + private async Task LoadNewVersionsAsync(string id, TimeSpan duration) + { + await Task.Yield(); + var stopwatch = Stopwatch.StartNew(); + var iteration = 0; + while (stopwatch.Elapsed < duration) + { + iteration++; + var version = $"0.0.0-beta{iteration}"; + _content = $"[[\"{id}\",[\"{version}\",1]]"; + _target.Refresh(); + await Task.Delay(5); + } + } + + private async Task GetDownloadCountsAsync(string id, TimeSpan duration) + { + await Task.Yield(); + var stopwatch = Stopwatch.StartNew(); + while (stopwatch.Elapsed < duration) + { + _target.TryGetDownloadCountForPackageRegistration(id, out var downloadCount); + await Task.Delay(5); + } + } + } + + public class TheTryGetDownloadCountForPackageMethod : BaseFacts + { + [Fact] + public void ReturnsZeroWhenVersionDoesNotExist() + { + // Arrange + _target.Refresh(); + + // Act + var found = _target.TryGetDownloadCountForPackage("NuGet.Versioning", "9.9.9", out var actual); + + // Assert + Assert.Equal(0, actual); + Assert.False(found, "The package version should not have been found."); + } + + [Fact] + public void ReturnsCountWhenVersionExists() + { + // Arrange + _target.Refresh(); + + // Act + var found = _target.TryGetDownloadCountForPackage("NuGet.Versioning", "4.6.0", out var actual); + + // Assert + Assert.Equal(23, actual); + Assert.True(found, "The package version should have been found."); + } + } + + public class BaseFacts + { + internal readonly Mock _telemetryService; + internal string _content; + internal Func, int> _calculateSum; + internal TestableCloudDownloadCountService _target; + + public BaseFacts() + { + _telemetryService = new Mock(); + _content = "[[\"NuGet.Versioning\",[\"4.6.0\",23],[\"4.6.2\",42]]"; + _calculateSum = null; + _target = new TestableCloudDownloadCountService(this); + } + } + + public class TestableCloudDownloadCountService : CloudDownloadCountService + { + private readonly BaseFacts _baseFacts; + + public TestableCloudDownloadCountService(BaseFacts baseFacts) + : base(baseFacts._telemetryService.Object, "UseDevelopmentStorage=true", readAccessGeoRedundant: true) + { + _baseFacts = baseFacts; + } + + protected override int CalculateSum(ConcurrentDictionary versions) + { + if (_baseFacts._calculateSum == null) + { + return base.CalculateSum(versions); + } + + return _baseFacts._calculateSum(versions); + } + + protected override Stream GetBlobStream() + { + if (_baseFacts._content == null) + { + return null; + } + + return new MemoryStream(Encoding.UTF8.GetBytes(_baseFacts._content)); + } + } + } +}