|
34 | 34 |
|
35 | 35 | package net.imglib2.algorithm.linalg.eigen; |
36 | 36 |
|
37 | | -import java.util.ArrayList; |
38 | | -import java.util.List; |
39 | | -import java.util.concurrent.Callable; |
40 | | -import java.util.concurrent.ExecutionException; |
41 | 37 | import java.util.concurrent.ExecutorService; |
42 | | -import java.util.concurrent.Future; |
43 | 38 |
|
44 | | -import net.imglib2.Cursor; |
45 | | -import net.imglib2.FinalInterval; |
46 | 39 | import net.imglib2.RandomAccessibleInterval; |
47 | 40 | import net.imglib2.img.Img; |
48 | 41 | import net.imglib2.img.ImgFactory; |
| 42 | +import net.imglib2.loops.LoopBuilder; |
| 43 | +import net.imglib2.parallel.Parallelization; |
| 44 | +import net.imglib2.parallel.TaskExecutors; |
49 | 45 | import net.imglib2.type.numeric.ComplexType; |
50 | 46 | import net.imglib2.type.numeric.RealType; |
51 | | -import net.imglib2.view.IntervalView; |
52 | 47 | import net.imglib2.view.Views; |
53 | | -import net.imglib2.view.composite.NumericComposite; |
54 | | -import net.imglib2.view.composite.RealComposite; |
| 48 | +import net.imglib2.view.composite.CompositeIntervalView; |
| 49 | +import net.imglib2.view.composite.GenericComposite; |
55 | 50 |
|
56 | 51 | /** |
57 | 52 | * |
@@ -286,82 +281,26 @@ public static < T extends RealType< T >, U extends ComplexType< U > > RandomAcce |
286 | 281 | final ExecutorService es ) |
287 | 282 | { |
288 | 283 |
|
289 | | - assert nTasks > 0: "Passed nTasks < 1"; |
290 | | - |
291 | | - final int tensorDims = tensor.numDimensions(); |
292 | | - |
293 | | - long dimensionMax = Long.MIN_VALUE; |
294 | | - int dimensionArgMax = -1; |
295 | | - |
296 | | - for ( int d = 0; d < tensorDims - 1; ++d ) |
297 | | - { |
298 | | - final long size = tensor.dimension( d ); |
299 | | - if ( size > dimensionMax ) |
300 | | - { |
301 | | - dimensionMax = size; |
302 | | - dimensionArgMax = d; |
303 | | - } |
304 | | - } |
305 | | - |
306 | | - final long stepSize = Math.max( dimensionMax / nTasks, 1 ); |
307 | | - final long stepSizeMinusOne = stepSize - 1; |
308 | | - final long max = dimensionMax - 1; |
309 | | - |
310 | | - final ArrayList< Callable< RandomAccessibleInterval< U > > > tasks = new ArrayList<>(); |
311 | | - for ( long currentMin = 0; currentMin < dimensionMax; currentMin += stepSize ) |
312 | | - { |
313 | | - final long currentMax = Math.min( currentMin + stepSizeMinusOne, max ); |
314 | | - final long[] minT = new long[ tensorDims ]; |
315 | | - final long[] maxT = new long[ tensorDims ]; |
316 | | - final long[] minE = new long[ tensorDims ]; |
317 | | - final long[] maxE = new long[ tensorDims ]; |
318 | | - tensor.min( minT ); |
319 | | - tensor.max( maxT ); |
320 | | - eigenvalues.min( minE ); |
321 | | - eigenvalues.max( maxE ); |
322 | | - minE[ dimensionArgMax ] = minT[ dimensionArgMax ] = currentMin; |
323 | | - maxE[ dimensionArgMax ] = maxT[ dimensionArgMax ] = currentMax; |
324 | | - final IntervalView< T > currentTensor = Views.interval( tensor, new FinalInterval( minT, maxT ) ); |
325 | | - final IntervalView< U > currentEigenvalues = Views.interval( eigenvalues, new FinalInterval( minE, maxE ) ); |
326 | | - tasks.add( () -> calculateEigenValuesImpl( currentTensor, currentEigenvalues, ev.copy() ) ); |
327 | | - } |
328 | | - |
329 | | - |
330 | | - try |
331 | | - { |
332 | | - final List< Future< RandomAccessibleInterval< U > > > futures = es.invokeAll( tasks ); |
333 | | - for ( final Future< RandomAccessibleInterval< U > > f : futures ) |
334 | | - try |
335 | | - { |
336 | | - f.get(); |
337 | | - } |
338 | | - catch ( final ExecutionException e ) |
339 | | - { |
340 | | - // TODO Auto-generated catch block |
341 | | - e.printStackTrace(); |
342 | | - } |
343 | | - } |
344 | | - catch ( final InterruptedException e ) |
345 | | - { |
346 | | - // TODO Auto-generated catch block |
347 | | - e.printStackTrace(); |
348 | | - } |
349 | | - |
350 | | - return eigenvalues; |
351 | | - |
352 | | - |
| 284 | + assert nTasks > 0 : "Passed nTasks < 1"; |
353 | 285 |
|
| 286 | + return Parallelization.runWithExecutor( TaskExecutors.forExecutorServiceAndNumTasks( es, nTasks ), |
| 287 | + () -> calculateEigenValues( tensor, eigenvalues, ev ) ); |
354 | 288 | } |
355 | 289 |
|
356 | 290 | private static < T extends RealType< T >, U extends ComplexType< U > > RandomAccessibleInterval< U > calculateEigenValuesImpl( |
357 | 291 | final RandomAccessibleInterval< T > tensor, |
358 | 292 | final RandomAccessibleInterval< U > eigenvalues, |
359 | 293 | final EigenValues< T, U > ev ) |
360 | 294 | { |
361 | | - final Cursor< RealComposite< T > > m = Views.iterable( Views.collapseReal( tensor ) ).cursor(); |
362 | | - final Cursor< NumericComposite< U > > e = Views.iterable( Views.collapseNumeric( eigenvalues ) ).cursor(); |
363 | | - while ( m.hasNext() ) |
364 | | - ev.compute( m.next(), e.next() ); |
| 295 | + RandomAccessibleInterval< ? extends GenericComposite< T > > tensorVectors = Views.collapse( tensor ); |
| 296 | + CompositeIntervalView< U, ? extends GenericComposite< U > > eigenvaluesVectors = Views.collapse( eigenvalues ); |
| 297 | + LoopBuilder.setImages( tensorVectors, eigenvaluesVectors ) |
| 298 | + .multiThreaded() |
| 299 | + .forEachChunk( chunk -> { |
| 300 | + EigenValues< T, U > copy = ev.copy(); |
| 301 | + chunk.forEachPixel( copy::compute ); |
| 302 | + return null; |
| 303 | + } ); |
365 | 304 | return eigenvalues; |
366 | 305 | } |
367 | 306 |
|
|
0 commit comments