diff --git a/stream.go b/stream.go index ae468abda..64247c5c7 100644 --- a/stream.go +++ b/stream.go @@ -75,11 +75,11 @@ type Stream struct { // // Note: Calls to KeyToList are concurrent. KeyToList func(key []byte, itr *Iterator) (*pb.KVList, error) - // UseKeyToListWithThreadId is used to indicate that KeyToListWithThreadId should be used - // instead of KeyToList. This is a new api that can be used to figure out parallelism - // of the stream. Each threadId would be run serially. KeyToList being concurrent makes you - // take care of concurrency in KeyToList. Here threadId could be used to do some things serially. - // Once a thread finishes FinishThread() would be called. + // UseKeyToListWithThreadId indicates that KeyToListWithThreadId should be used + // instead of KeyToList to express stream parallelism. Entries with the same + // threadId run serially; different threadIds may run in parallel. This avoids + // handling concurrency inside KeyToList. We call FinishThread() when a thread + // completes. UseKeyToListWithThreadId bool KeyToListWithThreadId func(key []byte, itr *Iterator, threadId int) (*pb.KVList, error) FinishThread func(threadId int) (*pb.KVList, error) @@ -190,6 +190,12 @@ func (st *Stream) produceKVs(ctx context.Context, threadId int) error { _ = outList.Release() }() + if st.FinishThread == nil { + st.FinishThread = func(threadId int) (*pb.KVList, error) { + return &pb.KVList{}, nil + } + } + iterate := func(kr keyRange) error { iterOpts := DefaultIteratorOptions iterOpts.AllVersions = true @@ -267,19 +273,17 @@ func (st *Stream) produceKVs(ctx context.Context, threadId int) error { } } - if st.UseKeyToListWithThreadId { - if kvs, err := st.FinishThread(threadId); err != nil { - return err - } else { - for _, kv := range kvs.Kv { - kv.StreamId = streamId - KVToBuffer(kv, outList) - if outList.LenNoPadding() < batchSize { - continue - } - if err := sendIt(); err != nil { - return err - } + if kvs, err := st.FinishThread(threadId); err != nil { + return err + } else { + for _, kv := range kvs.Kv { + kv.StreamId = streamId + KVToBuffer(kv, outList) + if outList.LenNoPadding() < batchSize { + continue + } + if err := sendIt(); err != nil { + return err } } } diff --git a/stream_test.go b/stream_test.go index a3dd053e2..6fd897833 100644 --- a/stream_test.go +++ b/stream_test.go @@ -26,6 +26,87 @@ func keyWithPrefix(prefix string, k int) []byte { return []byte(fmt.Sprintf("%s-%d", prefix, k)) } +func TestStreamKeyToListWithThreadId_MapReduceWordFreq(t *testing.T) { + dir, err := os.MkdirTemp("", "badger-test") + require.NoError(t, err) + defer removeDir(dir) + + db, err := OpenManaged(DefaultOptions(dir)) + require.NoError(t, err) + + // Seed dataset with words: beta, gamma, alpha + // Mapping: i%3==0 -> beta (33), i%3==1 -> gamma (34), i%3==2 -> alpha (33) per 1..100 + words := []string{"beta", "gamma", "alpha"} + for _, prefix := range []string{"p0", "p1", "p2"} { + txn := db.NewTransactionAt(math.MaxUint64, true) + for i := 1; i <= 100; i++ { + w := words[i%3] + require.NoError(t, txn.SetEntry(NewEntry(keyWithPrefix(prefix, i), []byte(w)))) + } + require.NoError(t, txn.CommitAt(5, nil)) + } + + stream := db.NewStreamAt(math.MaxUint64) + stream.LogPrefix = "Testing" + stream.NumGo = 16 + stream.UseKeyToListWithThreadId = true + + // Accumulate per-thread word counts to be emitted in FinishThread. + // Use a slice indexed by threadId. Each threadId only writes to its own slot. + toCounts := make([]map[string]int, stream.NumGo) + + stream.KeyToListWithThreadId = func(key []byte, itr *Iterator, threadId int) (*bpb.KVList, error) { + item := itr.Item() + val, err := item.ValueCopy(nil) + if err != nil { + return nil, err + } + w := string(val) + if toCounts[threadId] == nil { + toCounts[threadId] = make(map[string]int) + } + toCounts[threadId][w]++ + // Return nothing here; flushing happens in FinishThread. + return &bpb.KVList{}, nil + } + + stream.FinishThread = func(threadId int) (*bpb.KVList, error) { + counts := toCounts[threadId] + if len(counts) == 0 { + return &bpb.KVList{}, nil + } + out := make([]*bpb.KV, 0, len(counts)) + for w, c := range counts { + out = append(out, &bpb.KV{Key: []byte("wc-" + w), Value: []byte(fmt.Sprintf("%d", c))}) + } + return &bpb.KVList{Kv: out}, nil + } + + // Use a sink, but expect no data as KeyToListWithThreadId returns nothing and FinishThread returns empty list + c := &collector{} + stream.Send = c.Send + + require.NoError(t, stream.Orchestrate(ctxb)) + + // Reduce: aggregate per-word totals from partial outputs + totals := map[string]int{} + for _, kv := range c.kv { + if strings.HasPrefix(string(kv.Key), "wc-") { + w := strings.TrimPrefix(string(kv.Key), "wc-") + n, err := strconv.Atoi(string(kv.Value)) + require.NoError(t, err) + totals[w] += n + } + } + + // Expected totals across 3 prefixes x 100 items with distribution defined above + require.Equal(t, 99, totals["alpha"]) // 33 per prefix * 3 + require.Equal(t, 99, totals["beta"]) // 33 per prefix * 3 + require.Equal(t, 102, totals["gamma"]) // 34 per prefix * 3 + + require.NoError(t, db.Close()) +} + func keyToInt(k []byte) (string, int) { splits := strings.Split(string(k), "-") key, err := strconv.Atoi(splits[1]) @@ -160,6 +241,62 @@ func TestStream(t *testing.T) { require.NoError(t, db.Close()) } +func TestStreamKeyToListWithThreadId(t *testing.T) { + dir, err := os.MkdirTemp("", "badger-test") + require.NoError(t, err) + defer removeDir(dir) + + db, err := OpenManaged(DefaultOptions(dir)) + require.NoError(t, err) + + // Seed small dataset + for _, prefix := range []string{"p0", "p1", "p2"} { + txn := db.NewTransactionAt(math.MaxUint64, true) + for i := 1; i <= 100; i++ { + require.NoError(t, txn.SetEntry(NewEntry(keyWithPrefix(prefix, i), value(i)))) + } + require.NoError(t, txn.CommitAt(5, nil)) + } + + stream := db.NewStreamAt(math.MaxUint64) + stream.LogPrefix = "Testing" + stream.NumGo = 4 // fix number of threads for deterministic assertions + stream.UseKeyToListWithThreadId = true + + // Ensure threadId passed to KeyToListWithThreadId matches iterator's ThreadId + stream.KeyToListWithThreadId = func(key []byte, itr *Iterator, threadId int) (*bpb.KVList, error) { + require.Equal(t, threadId, itr.ThreadId) + return stream.ToList(key, itr) + } + + // Emit a per-thread marker to verify FinishThread is invoked once per thread + stream.FinishThread = func(threadId int) (*bpb.KVList, error) { + kv := &bpb.KV{Key: []byte(fmt.Sprintf("done-%d", threadId))} + return &bpb.KVList{Kv: []*bpb.KV{kv}}, nil + } + + c := &collector{} + stream.Send = c.Send + + err = stream.Orchestrate(ctxb) + require.NoError(t, err) + + // Verify presence of FinishThread markers and totals + markers := make(map[string]struct{}) + for _, kv := range c.kv { + if strings.HasPrefix(string(kv.Key), "done-") { + markers[string(kv.Key)] = struct{}{} + } + } + // Total should be data KVs plus marker count + require.Equal(t, 300+len(markers), len(c.kv)) + // We expect at least one marker and at most NumGo markers (ranges may be fewer than NumGo) + require.GreaterOrEqual(t, len(markers), 1) + require.LessOrEqual(t, len(markers), stream.NumGo) + + require.NoError(t, db.Close()) +} + func TestStreamMaxSize(t *testing.T) { if !*manual { t.Skip("Skipping test meant to be run manually.")