Skip to content

Commit 0fea299

Browse files
committed
[MINOR] OOC Bugfix Cache Reference Management + Return Right BlockKey on Externally Managed Grouped Callbacks
1 parent 28ebb6c commit 0fea299

4 files changed

Lines changed: 102 additions & 5 deletions

File tree

src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ public CachingStream(OOCStream<IndexedMatrixValue> source, long streamId) {
130130
boolean ownsEntry = true;
131131
if(tmp instanceof OOCCacheManager.CachedGroupCallback<?> cachedGroup) {
132132
baseKey = cachedGroup.getBlockKey();
133+
ensureReferencedOrRematerialize(baseKey, cachedGroup);
133134
ownsEntry = false;
134135
if(mSubscribers != null && mSubscribers.length > 0)
135136
mCallback = tmp.keepOpen();
@@ -183,12 +184,14 @@ public CachingStream(OOCStream<IndexedMatrixValue> source, long streamId) {
183184

184185
if(tmp instanceof OOCCacheManager.CachedQueueCallback<?> cachedQueue) {
185186
blockKey = cachedQueue.getBlockKey();
187+
ensureReferencedOrRematerialize(blockKey, task);
186188
ownsEntry = false;
187189
if(mSubscribers != null && mSubscribers.length > 0)
188190
mCallback = tmp.keepOpen();
189191
}
190192
else if(tmp instanceof OOCCacheManager.CachedSubCallback<?> cachedSub) {
191193
BlockKey parent = cachedSub.getParent().getBlockKey();
194+
ensureReferencedOrRematerialize(parent, cachedSub.getParent());
192195
blockKey = new GroupedBlockKey(parent.getStreamId(), (int) parent.getSequenceNumber(),
193196
cachedSub.getGroupIndex());
194197
ownsEntry = false;
@@ -297,6 +300,49 @@ else if(tmp instanceof OOCCacheManager.CachedSubCallback<?> cachedSub) {
297300
});
298301
}
299302

303+
304+
private void ensureReferencedOrRematerialize(BlockKey key, IndexedMatrixValue value) {
305+
try {
306+
OOCCacheManager.getCache().addReference(key);
307+
}
308+
catch(IllegalArgumentException ex) {
309+
try {
310+
OOCCacheManager.putRaw(key, value, ((MatrixBlock) value.getValue()).getExactSerializedSize());
311+
}
312+
catch(IllegalStateException putEx) {
313+
// Another downstream stream may have re-materialized the same entry first.
314+
OOCCacheManager.getCache().addReference(key);
315+
}
316+
}
317+
}
318+
319+
private void ensureReferencedOrRematerialize(BlockKey key, OOCCacheManager.CachedGroupCallback<?> group) {
320+
try {
321+
OOCCacheManager.getCache().addReference(key);
322+
}
323+
catch(IllegalArgumentException ex) {
324+
try {
325+
List<IndexedMatrixValue> values = new ArrayList<>(group.size());
326+
long totalSize = 0;
327+
for(int gi = 0; gi < group.size(); gi++) {
328+
@SuppressWarnings("unchecked")
329+
OOCStream.QueueCallback<IndexedMatrixValue> sub =
330+
(OOCStream.QueueCallback<IndexedMatrixValue>) group.getCallback(gi);
331+
try(sub) {
332+
IndexedMatrixValue imv = sub.get();
333+
values.add(imv);
334+
totalSize += ((MatrixBlock) imv.getValue()).getExactSerializedSize();
335+
}
336+
}
337+
OOCCacheManager.putRaw(key, values, totalSize);
338+
}
339+
catch(IllegalStateException putEx) {
340+
// Another downstream stream may have re-materialized the same entry first.
341+
OOCCacheManager.getCache().addReference(key);
342+
}
343+
}
344+
}
345+
300346
private String getCtxMsg() {
301347
StackTraceElement[] st = new Exception().getStackTrace();
302348
// Skip the first few frames (constructor, createWritableStream, etc.)
@@ -687,7 +733,7 @@ public void setSubscriber(Consumer<OOCStream.QueueCallback<IndexedMatrixValue>>
687733
if(groupIdx > 0)
688734
continue; // only replay grouped blocks once at the base index
689735

690-
BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ? new BlockKey(_streamId, idx) : getBlockKey(i);
736+
BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ? getEntryBlockKey(idx) : getBlockKey(i);
691737
OOCCacheManager.requestBlock(replayKey).whenComplete((cb, r) -> {
692738
if(r != null) {
693739
subscriber.accept(OOCStream.eos(DMLRuntimeException.of(r)));
@@ -697,7 +743,6 @@ public void setSubscriber(Consumer<OOCStream.QueueCallback<IndexedMatrixValue>>
697743
synchronized(CachingStream.this) {
698744
if(_index != null) {
699745
if(cb instanceof OOCStream.GroupQueueCallback<?> && groupSize > 1) {
700-
@SuppressWarnings("unchecked")
701746
OOCStream.GroupQueueCallback<IndexedMatrixValue> group =
702747
(OOCStream.GroupQueueCallback<IndexedMatrixValue>) cb;
703748
for(int gi = 0; gi < groupSize; gi++) {

src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public final class BlockEntry {
3030
private volatile BlockState _state;
3131
private Object _data;
3232
private int _retainHintCount;
33+
private int _referenceCount; // The number of references from different managing instances (e.g. CachingStream)
3334

3435
BlockEntry(BlockKey key, long size, Object data) {
3536
this._key = key;
@@ -38,6 +39,7 @@ public final class BlockEntry {
3839
this._state = BlockState.HOT;
3940
this._data = data;
4041
this._retainHintCount = 0;
42+
this._referenceCount = 1;
4143
}
4244

4345
public BlockKey getKey() {
@@ -84,6 +86,14 @@ public boolean isPinned() {
8486
return _pinCount > 0;
8587
}
8688

89+
synchronized int addReference() {
90+
return ++_referenceCount;
91+
}
92+
93+
synchronized int forget() {
94+
return --_referenceCount;
95+
}
96+
8797
synchronized void setState(BlockState state) {
8898
_state = state;
8999
}

src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ public interface OOCCacheScheduler {
103103
BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size,
104104
OOCIOHandler.SourceBlockDescriptor descriptor);
105105

106+
/**
107+
* Notifies the cache that there is another reference to the same block key.
108+
* This prevents forget(key) from removing the block from the cache.
109+
* A block is only forgotten after all referencing instances have called forget(key).
110+
* @param key the key of the block that gains an additional reference
111+
*/
112+
void addReference(BlockKey key);
113+
106114
/**
107115
* Forgets a block from the cache.
108116
* @param key the associated key of the block

src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919

2020
package org.apache.sysds.runtime.ooc.cache;
2121

22+
import org.apache.commons.lang3.mutable.MutableObject;
2223
import org.apache.commons.logging.Log;
2324
import org.apache.commons.logging.LogFactory;
2425
import org.apache.sysds.api.DMLScript;
2526
import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
2627
import org.apache.sysds.utils.Statistics;
2728
import scala.Tuple2;
2829

30+
import java.lang.ref.Reference;
2931
import java.util.ArrayList;
3032
import java.util.ArrayDeque;
3133
import java.util.Collection;
@@ -291,6 +293,18 @@ public BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size, OO
291293
return put(key, data, size, true, descriptor);
292294
}
293295

296+
@Override
297+
public void addReference(BlockKey key) {
298+
synchronized(this) {
299+
BlockEntry entry = _cache.get(key);
300+
if(entry == null)
301+
entry = _evictionCache.get(key);
302+
if(entry == null)
303+
throw new IllegalArgumentException("Could not find requested block with key " + key);
304+
entry.addReference();
305+
}
306+
}
307+
294308
private BlockEntry put(BlockKey key, Object data, long size, boolean pin, OOCIOHandler.SourceBlockDescriptor descriptor) {
295309
if (!this._running)
296310
throw new IllegalStateException();
@@ -324,14 +338,34 @@ private BlockEntry put(BlockKey key, Object data, long size, boolean pin, OOCIOH
324338
public void forget(BlockKey key) {
325339
if (!this._running)
326340
return;
341+
final MutableObject<BlockEntry> mEntry = new MutableObject<>();
327342
BlockEntry entry;
328343
boolean shouldScheduleDeletion = false;
329344
long cacheSizeDelta = 0;
330345
synchronized(this) {
331-
entry = _cache.remove(key);
346+
_cache.compute(key, (k, e) -> {
347+
if(e == null)
348+
return null;
349+
if(e.forget() == 0) {
350+
mEntry.setValue(e);
351+
return null;
352+
}
353+
return e;
354+
});
332355

333-
if (entry == null)
334-
entry = _evictionCache.remove(key);
356+
if (mEntry.getValue() == null) {
357+
_evictionCache.compute(key, (k, e) -> {
358+
if(e == null)
359+
return null;
360+
if(e.forget() == 0) {
361+
mEntry.setValue(e);
362+
return null;
363+
}
364+
return e;
365+
});
366+
}
367+
368+
entry = mEntry.getValue();
335369

336370
if (entry != null) {
337371
synchronized(entry) {

0 commit comments

Comments
 (0)