Skip to content
Open
7 changes: 7 additions & 0 deletions flight/flight-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ under the License.
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-core</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
Expand Down Expand Up @@ -145,6 +146,12 @@ under the License.
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.20.0</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
1 change: 0 additions & 1 deletion flight/flight-core/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
requires com.google.protobuf;
requires com.google.protobuf.util;
requires io.grpc;
requires io.grpc.internal;
requires io.grpc.netty;
requires io.grpc.protobuf;
requires io.grpc.stub;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.io.ByteStreams;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.WireFormat;
import io.grpc.Drainable;
Expand All @@ -40,8 +38,10 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.flight.FlightDataParser.ArrowBufReader;
import org.apache.arrow.flight.FlightDataParser.FlightDataReader;
import org.apache.arrow.flight.FlightDataParser.InputStreamReader;
import org.apache.arrow.flight.grpc.AddWritableBuffer;
import org.apache.arrow.flight.grpc.GetReadableBuffer;
import org.apache.arrow.flight.impl.Flight.FlightData;
import org.apache.arrow.flight.impl.Flight.FlightDescriptor;
import org.apache.arrow.memory.ArrowBuf;
Expand All @@ -55,10 +55,14 @@
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** The in-memory representation of FlightData used to manage a stream of Arrow messages. */
class ArrowMessage implements AutoCloseable {

private static final Logger LOG = LoggerFactory.getLogger(ArrowMessage.class);

// If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer
// instead of copying the data. Defaults to true.
public static final boolean ENABLE_ZERO_COPY_READ;
Expand All @@ -75,19 +79,10 @@ class ArrowMessage implements AutoCloseable {
if (zeroCopyWriteFlag == null) {
zeroCopyWriteFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE");
}
ENABLE_ZERO_COPY_READ = !"false".equalsIgnoreCase(zeroCopyReadFlag);
ENABLE_ZERO_COPY_READ = true; // !"false".equalsIgnoreCase(zeroCopyReadFlag);
ENABLE_ZERO_COPY_WRITE = "true".equalsIgnoreCase(zeroCopyWriteFlag);
}

private static final int DESCRIPTOR_TAG =
(FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
private static final int BODY_TAG =
(FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
private static final int HEADER_TAG =
(FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
private static final int APP_METADATA_TAG =
(FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;

private static final Marshaller<FlightData> NO_BODY_MARSHALLER =
ProtoUtils.marshaller(FlightData.getDefaultInstance());

Expand Down Expand Up @@ -212,7 +207,7 @@ public ArrowMessage(FlightDescriptor descriptor) {
this.tryZeroCopyWrite = false;
}

private ArrowMessage(
ArrowMessage(
FlightDescriptor descriptor,
MessageMetadataResult message,
ArrowBuf appMetadata,
Expand Down Expand Up @@ -280,101 +275,16 @@ public Iterable<ArrowBuf> getBufs() {
}

private static ArrowMessage frame(BufferAllocator allocator, final InputStream stream) {

try {
FlightDescriptor descriptor = null;
MessageMetadataResult header = null;
ArrowBuf body = null;
ArrowBuf appMetadata = null;
while (stream.available() > 0) {
final int tagFirstByte = stream.read();
if (tagFirstByte == -1) {
break;
}
int tag = readRawVarint32(tagFirstByte, stream);
switch (tag) {
case DESCRIPTOR_TAG:
{
int size = readRawVarint32(stream);
byte[] bytes = new byte[size];
ByteStreams.readFully(stream, bytes);
descriptor = FlightDescriptor.parseFrom(bytes);
break;
}
case HEADER_TAG:
{
int size = readRawVarint32(stream);
byte[] bytes = new byte[size];
ByteStreams.readFully(stream, bytes);
header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size);
break;
}
case APP_METADATA_TAG:
{
int size = readRawVarint32(stream);
appMetadata = allocator.buffer(size);
GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ);
break;
}
case BODY_TAG:
if (body != null) {
// only read last body.
body.getReferenceManager().release();
body = null;
}
int size = readRawVarint32(stream);
body = allocator.buffer(size);
GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ);
break;

default:
// ignore unknown fields.
}
FlightDataReader reader;
if (ENABLE_ZERO_COPY_READ) {
reader = ArrowBufReader.tryArrowBufReader(allocator, stream);
if (reader != null) {
return reader.toMessage();
}
// Protobuf implementations can omit empty fields, such as body; for some message types, like
// RecordBatch,
// this will fail later as we still expect an empty buffer. In those cases only, fill in an
// empty buffer here -
// in other cases, like Schema, having an unexpected empty buffer will also cause failures.
// We don't fill in defaults for fields like header, for which there is no reasonable default,
// or for appMetadata
// or descriptor, which are intended to be empty in some cases.
if (header != null) {
switch (HeaderType.getHeader(header.headerType())) {
case SCHEMA:
// Ignore 0-length buffers in case a Protobuf implementation wrote it out
if (body != null && body.capacity() == 0) {
body.close();
body = null;
}
break;
case DICTIONARY_BATCH:
case RECORD_BATCH:
// A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here
if (body == null) {
body = allocator.getEmpty();
}
break;
case NONE:
case TENSOR:
default:
// Do nothing
break;
}
}
return new ArrowMessage(descriptor, header, appMetadata, body);
} catch (Exception ioe) {
throw new RuntimeException(ioe);
}
}

private static int readRawVarint32(InputStream is) throws IOException {
int firstByte = is.read();
return readRawVarint32(firstByte, is);
}

private static int readRawVarint32(int firstByte, InputStream is) throws IOException {
return CodedInputStream.readRawVarint32(firstByte, is);
reader = new InputStreamReader(allocator, stream);
return reader.toMessage();
}

/**
Expand Down
Loading
Loading