diff --git a/src/Weaviate.Client.Tests/Integration/TestIterator.cs b/src/Weaviate.Client.Tests/Integration/TestIterator.cs
index f63a2796..e516f6cc 100644
--- a/src/Weaviate.Client.Tests/Integration/TestIterator.cs
+++ b/src/Weaviate.Client.Tests/Integration/TestIterator.cs
@@ -35,6 +35,39 @@ var obj in collection.Iterator(cancellationToken: TestContext.Current.Cancellati
         Assert.Contains("Name 2", names);
     }
 
+    ///
+    /// Tests that the iterator returns only the objects matching the supplied filter
+    ///
+    [Fact]
+    public async Task Test_Iterator_With_Filter()
+    {
+        var collection = await CollectionFactory(
+            properties: [Property.Text("name"), Property.Bool("isActive")],
+            vectorConfig: Configure.Vector("custom", v => v.SelfProvided())
+        );
+
+        await collection.Data.InsertMany(
+            BatchInsertRequest.Create(
+                Enumerable.Range(1, 200).Select(i => new { Name = $"Name {i}", IsActive = i == 2 })
+            ),
+            TestContext.Current.CancellationToken
+        );
+
+        var names = new List();
+        await foreach (
+            var obj in collection.Iterator(
+                filter: Filter.Property("isActive").IsEqual(true),
+                cacheSize: 10,
+                cancellationToken: TestContext.Current.CancellationToken
+            )
+        )
+        {
+            obj.Do(o => names.Add(o.Name));
+        }
+
+        Assert.Equal("Name 2", Assert.Single(names));
+    }
+
     ///
     /// Tests that test iterator arguments
     ///
diff --git a/src/Weaviate.Client/CollectionClient.cs b/src/Weaviate.Client/CollectionClient.cs
index 22d0d149..b033dde0 100644
--- a/src/Weaviate.Client/CollectionClient.cs
+++ b/src/Weaviate.Client/CollectionClient.cs
@@ -130,11 +130,13 @@ internal CollectionClient(
     /// Metadata to include in the response.
     /// Vector configuration for returned objects.
     /// Properties to return in the response.
+    /// Filter to apply to the objects.
     /// Cross-references to return.
     /// Cancellation token.
     /// An async enumerable of WeaviateObject instances.
     public async IAsyncEnumerable Iterator(
         Guid? after = null,
+        Filter? filter = null,
         uint cacheSize = ITERATOR_CACHE_SIZE,
         MetadataQuery? returnMetadata = null,
         VectorQuery? includeVectors = null,
@@ -145,15 +147,18 @@ public async IAsyncEnumerable Iterator(
     {
         await _client.EnsureInitializedAsync();
         Guid? cursor = after;
+        IDictionary? shardCursors = null;
 
         while (true)
         {
             cancellationToken.ThrowIfCancellationRequested();
 
-            WeaviateResult page = await _client.GrpcClient.FetchObjects(
+            var reply = await _client.GrpcClient.FetchObjects(
                 Name,
                 limit: cacheSize,
                 after: cursor,
+                filters: filter,
+                shardCursors: shardCursors,
                 returnMetadata: returnMetadata,
                 includeVectors: includeVectors,
                 returnProperties: returnProperties,
@@ -162,15 +167,36 @@ public async IAsyncEnumerable Iterator(
                 tenant: Tenant
             );
 
-            if (!page.Objects.Any())
+            WeaviateResult page = reply;
+
+            if (filter is null)
             {
-                yield break;
+                if (!page.Objects.Any())
+                {
+                    yield break;
+                }
+
+                foreach (var c in page.Objects)
+                {
+                    cursor = c.UUID;
+                    yield return c;
+                }
             }
-
-            foreach (var c in page.Objects)
+            else
             {
-                cursor = c.UUID;
-                yield return c;
+                foreach (var c in page.Objects)
+                {
+                    yield return c;
+                }
+
+                // Stop when no cursors come back or every shard reports uuid.Nil,
+                // which search_get.proto documents as "shard is exhausted".
+                if (reply.ShardCursors.Values.All(v => v == Guid.Empty.ToString()))
+                {
+                    yield break;
+                }
+
+                shardCursors = reply.ShardCursors;
             }
         }
     }
diff --git a/src/Weaviate.Client/Typed/TypedCollectionClient.cs b/src/Weaviate.Client/Typed/TypedCollectionClient.cs
index 2e26fbff..7d91b095 100644
--- a/src/Weaviate.Client/Typed/TypedCollectionClient.cs
+++ b/src/Weaviate.Client/Typed/TypedCollectionClient.cs
@@ -142,6 +142,7 @@ public TypedCollectionClient WithConsistencyLevel(ConsistencyLevels consisten
     /// Uses cursor-based pagination for efficient iteration over large collections.
     ///
     /// Start iteration after this object ID.
+    /// Filter to apply to the objects.
     /// Number of objects to fetch per page.
     /// Metadata to include in results.
     /// Whether to include vectors.
@@ -151,6 +152,7 @@ public TypedCollectionClient WithConsistencyLevel(ConsistencyLevels consisten
     /// An async enumerable of strongly-typed objects.
     public async IAsyncEnumerable> Iterator(
         Guid? after = null,
+        Filter? filter = null,
         uint cacheSize = CollectionClient.ITERATOR_CACHE_SIZE,
         MetadataQuery? returnMetadata = null,
         VectorQuery? includeVectors = null,
@@ -162,6 +164,7 @@ public async IAsyncEnumerable> Iterator(
         await foreach (
             var obj in _collectionClient.Iterator(
                 after,
+                filter,
                 cacheSize,
                 returnMetadata,
                 includeVectors,
diff --git a/src/Weaviate.Client/gRPC/Search.cs b/src/Weaviate.Client/gRPC/Search.cs
index 296f5455..66f50e5c 100644
--- a/src/Weaviate.Client/gRPC/Search.cs
+++ b/src/Weaviate.Client/gRPC/Search.cs
@@ -56,6 +56,7 @@ internal partial class WeaviateGrpcClient
     /// The return references
     /// The return metadata
     /// The include vectors
+    /// Per-shard cursors for filtered iterator continuation
     /// The cancellation token
     /// A task containing the search reply
     internal async Task FetchObjects(
@@ -75,6 +76,7 @@ internal partial class WeaviateGrpcClient
         IList? returnReferences = null,
         MetadataQuery? returnMetadata = null,
         VectorQuery? includeVectors = null,
+        IDictionary? shardCursors = null,
         CancellationToken cancellationToken = default
     )
     {
@@ -97,6 +99,19 @@ internal partial class WeaviateGrpcClient
             includeVectors: includeVectors
         );
 
+        // NOTE(review): presumably an explicitly-present (even empty) 'after' together with
+        // filters is what selects the server's filtered iterator mode. This runs for every
+        // FetchObjects call that passes filters, not only Iterator - confirm that is intended.
+        if (filters is not null && !req.HasAfter)
+        {
+            req.After = "";
+        }
+
+        if (shardCursors is { Count: > 0 })
+        {
+            req.ShardCursors.Add(shardCursors);
+        }
+
         return await Search(req, cancellationToken);
     }
 
diff --git a/src/Weaviate.Client/gRPC/proto/v1/search_get.proto b/src/Weaviate.Client/gRPC/proto/v1/search_get.proto
index 47de19f1..fbb73736 100644
--- a/src/Weaviate.Client/gRPC/proto/v1/search_get.proto
+++ b/src/Weaviate.Client/gRPC/proto/v1/search_get.proto
@@ -29,9 +29,14 @@ message SearchRequest {
   uint32 limit = 30;
   uint32 offset = 31;
   uint32 autocut = 32;
-  string after = 33;
+  optional string after = 33;
   // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED
   repeated SortBy sort_by = 34;
+  // Per-shard cursor continuation state for filtered iterator mode
+  // Key: shard name, Value: UUID to start after for that shard
+  // ONLY used when BOTH 'after' AND 'filters' are set
+  // Value of uuid.Nil ("00000000-0000-0000-0000-000000000000") indicates shard is exhausted
+  map shard_cursors = 35;
 
   // matches/searches for objects
   optional Filters filters = 40;
@@ -52,7 +57,7 @@ message SearchRequest {
 
   bool uses_123_api = 100 [deprecated = true];
   bool uses_125_api = 101 [deprecated = true];
-  bool uses_127_api = 102;
+  bool uses_127_api = 102;
 }
 
 message GroupBy {
@@ -117,6 +122,10 @@ message SearchReply {
   optional string generative_grouped_result = 3 [deprecated = true];
   repeated GroupByResult group_by_results = 4;
   optional GenerativeResult generative_grouped_results = 5;
+  // Per-shard cursor state for pagination continuation (filtered iterator mode)
+  // Key: shard name, Value: UUID to start after for that shard on next request
+  // Value of uuid.Nil indicates shard is exhausted (no more results to scan)
+  map shard_cursors = 6;
 }
 
 message RerankReply {