@@ -237,6 +237,52 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws Run
237237 inputsVector .close ();
238238 inputsVector .deallocate ();
239239 }
240+
241+ /**
242+ * {@inheritDoc}
243+ *
244+ * Run a Pytorch model using JavaCpp on the data provided by the {@link Tensor} input list
245+ * and modifies the output list with the results obtained
246+ *
247+ */
248+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
249+ List <RandomAccessibleInterval <R >> inference (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
250+ if (interprocessing ) {
251+ return runInterprocessing (inputs );
252+ }
253+ IValueVector inputsVector = new IValueVector ();
254+ for (RandomAccessibleInterval <T > tt : inputs ) {
255+ inputsVector .put (new IValue (JavaCPPTensorBuilder .build (tt )));
256+ }
257+ // Run model
258+ model .eval ();
259+ IValue output = model .forward (inputsVector );
260+ TensorVector outputTensorVector = null ;
261+ if (output .isTensorList ()) {
262+ outputTensorVector = output .toTensorVector ();
263+ } else {
264+ outputTensorVector = new TensorVector ();
265+ outputTensorVector .put (output .toTensor ());
266+ }
267+ // Fill the agnostic output tensors list with data from the inference result
268+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
269+ for (int i = 0 ; i < outputTensorVector .size (); i ++) {
270+ rais .add (ImgLib2Builder .build (outputTensorVector .get (i )));
271+ outputTensorVector .get (i ).close ();
272+ outputTensorVector .get (i ).deallocate ();
273+ }
274+ outputTensorVector .close ();
275+ outputTensorVector .deallocate ();
276+ output .close ();
277+ output .deallocate ();
278+ for (int i = 0 ; i < inputsVector .size (); i ++) {
279+ inputsVector .get (i ).close ();
280+ inputsVector .get (i ).deallocate ();
281+ }
282+ inputsVector .close ();
283+ inputsVector .deallocate ();
284+ return rais ;
285+ }
240286
241287 protected void runFromShmas (List <String > inputs , List <String > outputs ) throws IOException {
242288 IValueVector inputsVector = new IValueVector ();
@@ -276,17 +322,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
276322 inputsVector .deallocate ();
277323 }
278324
279- /**
280- * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
281- * to create another process, communicate the model info and tensors to the other
282- * process and then retrieve the results of the other process
283- * @param inputTensors
284- * tensors that are going to be run on the model
285- * @param outputTensors
286- * expected results of the model
287- * @throws RunModelException if there is any issue running the model
288- */
289- public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
325+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
326+ IValueVector inputsVector = new IValueVector ();
327+ for (String ee : inputs ) {
328+ Map <String , Object > decoded = Types .decode (ee );
329+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
330+ org .bytedeco .pytorch .Tensor inT = TensorBuilder .build (shma );
331+ inputsVector .put (new IValue (inT ));
332+ if (PlatformDetection .isWindows ()) shma .close ();
333+ }
334+ // Run model
335+ model .eval ();
336+ IValue output = model .forward (inputsVector );
337+ TensorVector outputTensorVector = null ;
338+ if (output .isTensorList ()) {
339+ outputTensorVector = output .toTensorVector ();
340+ } else {
341+ outputTensorVector = new TensorVector ();
342+ outputTensorVector .put (output .toTensor ());
343+ }
344+
345+ shmaNamesList = new ArrayList <String >();
346+ for (int i = 0 ; i < outputTensorVector .size (); i ++) {
347+ String name = SharedMemoryArray .createShmName ();
348+ ShmBuilder .build (outputTensorVector .get (i ), name , false );
349+ shmaNamesList .add (name );
350+ }
351+ outputTensorVector .close ();
352+ outputTensorVector .deallocate ();
353+ output .close ();
354+ output .deallocate ();
355+ for (int i = 0 ; i < inputsVector .size (); i ++) {
356+ inputsVector .get (i ).close ();
357+ inputsVector .get (i ).deallocate ();
358+ }
359+ inputsVector .close ();
360+ inputsVector .deallocate ();
361+ return shmaNamesList ;
362+ }
363+
364+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
290365 void runInterprocessing (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
291366 shmaInputList = new ArrayList <SharedMemoryArray >();
292367 shmaOutputList = new ArrayList <SharedMemoryArray >();
@@ -297,7 +372,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
297372 args .put ("outputs" , encOuts );
298373
299374 try {
300- Task task = runner .task ("inference " , args );
375+ Task task = runner .task ("run " , args );
301376 task .waitFor ();
302377 if (task .status == TaskStatus .CANCELED )
303378 throw new RuntimeException ();
@@ -328,7 +403,89 @@ else if (task.status == TaskStatus.CRASHED) {
328403 closeShmas ();
329404 }
330405
331- private void closeShmas () {
406+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
407+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
408+ shmaInputList = new ArrayList <SharedMemoryArray >();
409+ List <String > encIns = new ArrayList <String >();
410+ Gson gson = new Gson ();
411+ for (RandomAccessibleInterval <T > tt : inputs ) {
412+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
413+ shmaInputList .add (shma );
414+ HashMap <String , Object > map = new HashMap <String , Object >();
415+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
416+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
417+ map .put (IS_INPUT_KEY , true );
418+ map .put (MEM_NAME_KEY , shma .getName ());
419+ encIns .add (gson .toJson (map ));
420+ }
421+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
422+ args .put ("inputs" , encIns );
423+
424+ try {
425+ Task task = runner .task ("inference" , args );
426+ task .waitFor ();
427+ if (task .status == TaskStatus .CANCELED )
428+ throw new RuntimeException ();
429+ else if (task .status == TaskStatus .FAILED )
430+ throw new RuntimeException (task .error );
431+ else if (task .status == TaskStatus .CRASHED ) {
432+ this .runner .close ();
433+ runner = null ;
434+ throw new RuntimeException (task .error );
435+ } else if (task .outputs == null )
436+ throw new RuntimeException ("No outputs generated" );
437+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
438+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
439+ for (String out : outputs ) {
440+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
441+ SharedMemoryArray shm = SharedMemoryArray .read (name );
442+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
443+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
444+ shm .close ();
445+ }
446+ closeShmas ();
447+ return rais ;
448+ } catch (Exception e ) {
449+ closeShmas ();
450+ if (e instanceof RunModelException )
451+ throw (RunModelException ) e ;
452+ throw new RunModelException (Types .stackTrace (e ));
453+ }
454+ }
455+
456+ private void closeInterprocess () throws RunModelException {
457+ try {
458+ Task task = runner .task ("closeTensors" );
459+ task .waitFor ();
460+ if (task .status == TaskStatus .CANCELED )
461+ throw new RuntimeException ();
462+ else if (task .status == TaskStatus .FAILED )
463+ throw new RuntimeException (task .error );
464+ else if (task .status == TaskStatus .CRASHED ) {
465+ this .runner .close ();
466+ runner = null ;
467+ throw new RuntimeException (task .error );
468+ }
469+ } catch (Exception e ) {
470+ if (e instanceof RunModelException )
471+ throw (RunModelException ) e ;
472+ throw new RunModelException (Types .stackTrace (e ));
473+ }
474+ }
475+
476+ protected void closeFromInterp () {
477+ if (!PlatformDetection .isWindows ())
478+ return ;
479+ this .shmaNamesList .stream ().forEach (nn -> {
480+ try {
481+ SharedMemoryArray .read (nn ).close ();
482+ } catch (IOException e ) {
483+ e .printStackTrace ();
484+ }
485+ });
486+ }
487+
488+ private void closeShmas () throws RunModelException {
332489 shmaInputList .forEach (shm -> {
333490 try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
334491 });
@@ -337,6 +494,8 @@ private void closeShmas() {
337494 try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
338495 });
339496 shmaOutputList = null ;
497+ if (interprocessing )
498+ closeInterprocess ();
340499 }
341500
342501 /**
0 commit comments