diff --git a/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs b/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs index bf476355511abd..688b81e82b4d53 100644 --- a/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs @@ -73,6 +73,17 @@ public override bool MoveNext() return false; } + public override void Dispose() + { + if (_enumerator is { } e) + { + _enumerator = null; + e.Dispose(); + } + + base.Dispose(); + } + public override TResult[] ToArray() { TResult[] array = new TResult[_source.Count]; diff --git a/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs b/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs index ebb94853261d81..67bbb347a69965 100644 --- a/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs @@ -51,6 +51,17 @@ public override bool MoveNext() return false; } + public override void Dispose() + { + if (_enumerator is { } e) + { + _enumerator = null; + e.Dispose(); + } + + base.Dispose(); + } + public override IEnumerable Where(Func predicate) => new SizeOptIListWhereIterator(_source, Utilities.CombinePredicates(_predicate, predicate)); diff --git a/src/libraries/System.Linq/tests/EnumerableTests.cs b/src/libraries/System.Linq/tests/EnumerableTests.cs index 66bc25861e080e..7e237f80a01300 100644 --- a/src/libraries/System.Linq/tests/EnumerableTests.cs +++ b/src/libraries/System.Linq/tests/EnumerableTests.cs @@ -439,5 +439,64 @@ public DelegateIterator( void IDisposable.Dispose() => _dispose(); } + + protected sealed class DisposeTrackingList : IList, IReadOnlyList + { + private readonly List _list; + private int _disposeCalls; + + public DisposeTrackingList(T[] items) + { + _list = [.. items]; + } + + public int DisposeCalls => _disposeCalls; + + public T this[int index] + { + get => _list[index]; + set => _list[index] = value; + } + + public int Count => _list.Count; + + public bool IsReadOnly => false; + + public void Add(T item) => _list.Add(item); + public void Clear() => _list.Clear(); + public bool Contains(T item) => _list.Contains(item); + public void CopyTo(T[] array, int arrayIndex) => _list.CopyTo(array, arrayIndex); + public int IndexOf(T item) => _list.IndexOf(item); + public void Insert(int index, T item) => _list.Insert(index, item); + public bool Remove(T item) => _list.Remove(item); + public void RemoveAt(int index) => _list.RemoveAt(index); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public IEnumerator GetEnumerator() => new DisposeTrackingEnumerator(this); + + private sealed class DisposeTrackingEnumerator : IEnumerator + { + private readonly DisposeTrackingList _parent; + private readonly IEnumerator _enumerator; + + public DisposeTrackingEnumerator(DisposeTrackingList parent) + { + _parent = parent; + _enumerator = parent._list.GetEnumerator(); + } + + public T Current => _enumerator.Current; + object? IEnumerator.Current => Current; + public bool MoveNext() => _enumerator.MoveNext(); + public void Reset() => throw new NotSupportedException(); + + public void Dispose() + { + _parent._disposeCalls++; + _enumerator.Dispose(); + } + } + } } } diff --git a/src/libraries/System.Linq/tests/SelectTests.cs b/src/libraries/System.Linq/tests/SelectTests.cs index 5fdae70d06df38..6427f2758c3bb2 100644 --- a/src/libraries/System.Linq/tests/SelectTests.cs +++ b/src/libraries/System.Linq/tests/SelectTests.cs @@ -1261,5 +1261,30 @@ public static IEnumerable RunSelectorDuringCountData() } } } + + [Fact] + public void Select_SourceIsIList_EnumeratorDisposedOnComplete() + { + var source = new DisposeTrackingList([1, 2, 3, 4, 5]); + + foreach (int item in source.Select(i => i * 2)) + { + } + + Assert.Equal(1, source.DisposeCalls); + } + + [Fact] + public void Select_SourceIsIList_EnumeratorDisposedOnExplicitDispose() + { + var source = new DisposeTrackingList([1, 2, 3, 4, 5]); + + using (var enumerator = source.Select(i => i * 2).GetEnumerator()) + { + enumerator.MoveNext(); + } + + Assert.Equal(1, source.DisposeCalls); + } } } diff --git a/src/libraries/System.Linq/tests/WhereTests.cs b/src/libraries/System.Linq/tests/WhereTests.cs index f877e7bbaba4a6..e2f86b3336ad3f 100644 --- a/src/libraries/System.Linq/tests/WhereTests.cs +++ b/src/libraries/System.Linq/tests/WhereTests.cs @@ -1163,5 +1163,30 @@ private static IEnumerable GenerateRandomSequnce(uint seed, int count) yield return random.Next(int.MinValue, int.MaxValue); } } + + [Fact] + public void Where_SourceIsIList_EnumeratorDisposedOnComplete() + { + var source = new DisposeTrackingList([1, 2, 3, 4, 5]); + + foreach (int item in source.Where(i => i % 2 == 0)) + { + } + + Assert.Equal(1, source.DisposeCalls); + } + + [Fact] + public void Where_SourceIsIList_EnumeratorDisposedOnExplicitDispose() + { + var source = new DisposeTrackingList([1, 2, 3, 4, 5]); + + using (var enumerator = source.Where(i => i % 2 == 0).GetEnumerator()) + { + enumerator.MoveNext(); + } + + Assert.Equal(1, source.DisposeCalls); + } } }