Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ public override bool MoveNext()
return false;
}

public override void Dispose()
{
if (_enumerator is { } e)
{
_enumerator = null;
e.Dispose();
}

base.Dispose();
}

public override TResult[] ToArray()
{
TResult[] array = new TResult[_source.Count];
Expand Down
11 changes: 11 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ public override bool MoveNext()
return false;
}

public override void Dispose()
{
if (_enumerator is { } e)
{
_enumerator = null;
e.Dispose();
}

base.Dispose();
}

public override IEnumerable<TSource> Where(Func<TSource, bool> predicate) =>
new SizeOptIListWhereIterator<TSource>(_source, Utilities.CombinePredicates(_predicate, predicate));

Expand Down
59 changes: 59 additions & 0 deletions src/libraries/System.Linq/tests/EnumerableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -439,5 +439,64 @@ public DelegateIterator(

void IDisposable.Dispose() => _dispose();
}

protected sealed class DisposeTrackingList<T> : IList<T>, IReadOnlyList<T>
{
private readonly List<T> _list;
private int _disposeCalls;

public DisposeTrackingList(T[] items)
{
_list = [.. items];
}

public int DisposeCalls => _disposeCalls;

public T this[int index]
{
get => _list[index];
set => _list[index] = value;
}

public int Count => _list.Count;

public bool IsReadOnly => false;

public void Add(T item) => _list.Add(item);
public void Clear() => _list.Clear();
public bool Contains(T item) => _list.Contains(item);
public void CopyTo(T[] array, int arrayIndex) => _list.CopyTo(array, arrayIndex);
public int IndexOf(T item) => _list.IndexOf(item);
public void Insert(int index, T item) => _list.Insert(index, item);
public bool Remove(T item) => _list.Remove(item);
public void RemoveAt(int index) => _list.RemoveAt(index);

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

public IEnumerator<T> GetEnumerator() => new DisposeTrackingEnumerator(this);

private sealed class DisposeTrackingEnumerator : IEnumerator<T>
{
private readonly DisposeTrackingList<T> _parent;
private readonly IEnumerator<T> _enumerator;

public DisposeTrackingEnumerator(DisposeTrackingList<T> parent)
{
_parent = parent;
_enumerator = parent._list.GetEnumerator();
}

public T Current => _enumerator.Current;
object? IEnumerator.Current => Current;
public bool MoveNext() => _enumerator.MoveNext();
public void Reset() => throw new NotSupportedException();

public void Dispose()
{
_parent._disposeCalls++;
_enumerator.Dispose();
}
}
}
}
}
25 changes: 25 additions & 0 deletions src/libraries/System.Linq/tests/SelectTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1261,5 +1261,30 @@ public static IEnumerable<object[]> RunSelectorDuringCountData()
}
}
}

[Fact]
public void Select_SourceIsIList_EnumeratorDisposedOnComplete()
{
var source = new DisposeTrackingList<int>([1, 2, 3, 4, 5]);

foreach (int item in source.Select(i => i * 2))
{
}

Assert.Equal(1, source.DisposeCalls);
}

[Fact]
public void Select_SourceIsIList_EnumeratorDisposedOnExplicitDispose()
{
var source = new DisposeTrackingList<int>([1, 2, 3, 4, 5]);

using (var enumerator = source.Select(i => i * 2).GetEnumerator())
{
enumerator.MoveNext();
}

Assert.Equal(1, source.DisposeCalls);
}
}
}
25 changes: 25 additions & 0 deletions src/libraries/System.Linq/tests/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,5 +1163,30 @@ private static IEnumerable<int> GenerateRandomSequnce(uint seed, int count)
yield return random.Next(int.MinValue, int.MaxValue);
}
}

[Fact]
public void Where_SourceIsIList_EnumeratorDisposedOnComplete()
{
var source = new DisposeTrackingList<int>([1, 2, 3, 4, 5]);

foreach (int item in source.Where(i => i % 2 == 0))
{
}

Assert.Equal(1, source.DisposeCalls);
}

[Fact]
public void Where_SourceIsIList_EnumeratorDisposedOnExplicitDispose()
{
var source = new DisposeTrackingList<int>([1, 2, 3, 4, 5]);

using (var enumerator = source.Where(i => i % 2 == 0).GetEnumerator())
{
enumerator.MoveNext();
}

Assert.Equal(1, source.DisposeCalls);
}
}
}