@@ -69,19 +69,14 @@ public static void build(Tensor tensor, String memoryName) throws IllegalArgumen
6969 {
7070 if (tensor .dtype ().isScalarType (org .bytedeco .pytorch .global .torch .ScalarType .Byte )
7171 || tensor .dtype ().isScalarType (org .bytedeco .pytorch .global .torch .ScalarType .Char )) {
72- System .out .println ("SSECRET_KEY : BYTE " );
7372 buildFromTensorByte (tensor , memoryName );
7473 } else if (tensor .dtype ().isScalarType (org .bytedeco .pytorch .global .torch .ScalarType .Int )) {
75- System .out .println ("SSECRET_KEY : INT " );
7674 buildFromTensorInt (tensor , memoryName );
7775 } else if (tensor .dtype ().isScalarType (org .bytedeco .pytorch .global .torch .ScalarType .Float )) {
78- System .out .println ("SSECRET_KEY : FLOAT " );
7976 buildFromTensorFloat (tensor , memoryName );
8077 } else if (tensor .dtype ().isScalarType (org .bytedeco .pytorch .global .torch .ScalarType .Double )) {
81- System .out .println ("SSECRET_KEY : SOUBKE " );
8278 buildFromTensorDouble (tensor , memoryName );
8379 } else if (tensor .dtype ().isScalarType (org .bytedeco .pytorch .global .torch .ScalarType .Long )) {
84- System .out .println ("SSECRET_KEY : LONG " );
8580 buildFromTensorLong (tensor , memoryName );
8681 } else {
8782 throw new IllegalArgumentException ("Unsupported tensor type: " + tensor .scalar_type ());
@@ -98,10 +93,9 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
9893 long flatSize = 1 ;
9994 for (long l : arrayShape ) {flatSize *= l ;}
10095 byte [] flat = new byte [(int ) flatSize ];
101- ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize ));
96+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize )). order ( ByteOrder . LITTLE_ENDIAN ) ;
10297 tensor .data_ptr_byte ().get (flat );
10398 byteBuffer .put (flat );
104- byteBuffer .rewind ();
10599 shma .getDataBufferNoHeader ().put (byteBuffer );
106100 if (PlatformDetection .isWindows ()) shma .close ();
107101 }
@@ -116,11 +110,10 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
116110 long flatSize = 1 ;
117111 for (long l : arrayShape ) {flatSize *= l ;}
118112 int [] flat = new int [(int ) flatSize ];
119- ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Integer .BYTES ));
113+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Integer .BYTES )). order ( ByteOrder . LITTLE_ENDIAN ) ;
120114 IntBuffer floatBuffer = byteBuffer .asIntBuffer ();
121115 tensor .data_ptr_int ().get (flat );
122116 floatBuffer .put (flat );
123- byteBuffer .rewind ();
124117 shma .getDataBufferNoHeader ().put (byteBuffer );
125118 if (PlatformDetection .isWindows ()) shma .close ();
126119 }
@@ -140,10 +133,6 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
140133 tensor .data_ptr_float ().get (flat );
141134 floatBuffer .put (flat );
142135 shma .getDataBufferNoHeader ().put (byteBuffer );
143- System .out .println ("equals " + (shma .getDataBufferNoHeader ().get (100 ) == byteBuffer .get (100 )));
144- System .out .println ("equals " + (shma .getDataBufferNoHeader ().get (500 ) == byteBuffer .get (500 )));
145- System .out .println ("equals " + (shma .getDataBufferNoHeader ().get (300 ) == byteBuffer .get (300 )));
146- System .out .println ("equals " + (shma .getDataBufferNoHeader ().get (1000 ) == byteBuffer .get (1000 )));
147136 if (PlatformDetection .isWindows ()) shma .close ();
148137 }
149138
@@ -157,11 +146,10 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
157146 long flatSize = 1 ;
158147 for (long l : arrayShape ) {flatSize *= l ;}
159148 double [] flat = new double [(int ) flatSize ];
160- ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Double .BYTES ));
149+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Double .BYTES )). order ( ByteOrder . LITTLE_ENDIAN ) ;
161150 DoubleBuffer floatBuffer = byteBuffer .asDoubleBuffer ();
162151 tensor .data_ptr_double ().get (flat );
163152 floatBuffer .put (flat );
164- byteBuffer .rewind ();
165153 shma .getDataBufferNoHeader ().put (byteBuffer );
166154 if (PlatformDetection .isWindows ()) shma .close ();
167155 }
@@ -176,11 +164,10 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
176164 long flatSize = 1 ;
177165 for (long l : arrayShape ) {flatSize *= l ;}
178166 long [] flat = new long [(int ) flatSize ];
179- ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Long .BYTES ));
167+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Long .BYTES )). order ( ByteOrder . LITTLE_ENDIAN ) ;
180168 LongBuffer floatBuffer = byteBuffer .asLongBuffer ();
181169 tensor .data_ptr_long ().get (flat );
182170 floatBuffer .put (flat );
183- byteBuffer .rewind ();
184171 shma .getDataBufferNoHeader ().put (byteBuffer );
185172 if (PlatformDetection .isWindows ()) shma .close ();
186173 }
0 commit comments