88 changes: 88 additions & 0 deletions flow/record/adapter/gcs.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import logging
import re
from fnmatch import fnmatch
from typing import Iterator

from google.cloud.storage.client import Client
from google.cloud.storage.fileio import BlobReader, BlobWriter

from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.base import Record, RecordAdapter
from flow.record.selector import Selector

__usage__ = """
Google Cloud Storage adapter
---
Read usage: rdump gcs://[BUCKET_ID]/path?project=[PROJECT]
Write usage: rdump -w gcs://[BUCKET_ID]/path?project=[PROJECT]

[BUCKET_ID]: Bucket ID
[path]: Path to read from or write to; supports glob-pattern matching when reading

Optional arguments:
[PROJECT]: Google Cloud Project ID. If not passed, falls back to the default project inferred from the environment.
"""

log = logging.getLogger(__name__)

GLOB_CHARACTERS_RE = r"[\[\]\*\?]"


class GcsReader(AbstractReader):
def __init__(self, uri: str, *, project: str | None = None, selector: Selector | None = None, **kwargs):
self.selector = selector
bucket_name, _, path = uri.partition("/")
self.gcs = Client(project=project)
self.bucket = self.gcs.bucket(bucket_name)

# GCS doesn't support iterating blobs with a glob pattern, so we have to do that ourselves. To extract the path
# prefix from the glob pattern, we find the first place where the glob starts.
self.prefix, *glob_pattern = re.split(GLOB_CHARACTERS_RE, path)
self.pattern = path if glob_pattern else None

def __iter__(self) -> Iterator[Record]:
blobs = self.gcs.list_blobs(bucket_or_name=self.bucket, prefix=self.prefix)
for blob in blobs:
if blob.size == 0: # Skip empty files
continue
if self.pattern and not fnmatch(blob.name, self.pattern):
continue
blobreader = BlobReader(blob)

# Give the file-like object to RecordAdapter so it will select the right adapter by peeking into the stream
reader = RecordAdapter(fileobj=blobreader, out=False, selector=self.selector)
for record in reader:
yield record

def close(self) -> None:
self.gcs.close()


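For illustration (not part of this diff): how the prefix/pattern split above behaves for a globbed path, matching what test_gcs_uri_and_path asserts further down.

import re

GLOB_CHARACTERS_RE = r"[\[\]\*\?]"

prefix, *rest = re.split(GLOB_CHARACTERS_RE, "path/to/records/*/*.avro")
# prefix == "path/to/records/"  -> used as the list_blobs prefix
# rest is non-empty             -> the full path is kept as an fnmatch pattern
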
class GcsWriter(AbstractWriter):
def __init__(self, uri: str, *, project: str | None = None, **kwargs):
bucket_name, _, path = uri.partition("/")
self.writer = None

self.gcs = Client(project=project)
self.bucket = self.gcs.bucket(bucket_name)

blob = self.bucket.blob(path)
self.writer = BlobWriter(blob, ignore_flush=True)
self.adapter = RecordAdapter(url=path, fileobj=self.writer, out=True, **kwargs)

def write(self, record: Record) -> None:
self.adapter.write(record)

def flush(self) -> None:
# The underlying adapter may require flushing
self.adapter.flush()

def close(self) -> None:
self.flush()
self.adapter.close()

if self.writer:
self.writer.close()
self.writer = None
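
A minimal usage sketch for the new adapter, mirroring the calls exercised by the tests below; the bucket, project and paths are made up for illustration:

from flow.record import RecordAdapter

# Reading: the glob is expanded client-side (prefix listing + fnmatch) and records are filtered by the selector.
reader = RecordAdapter("gcs://my-bucket/incident/*/results/*.records", project="my-project", selector="r.idx >= 5")
for record in reader:
    print(record)
reader.close()

# Writing: a single blob, gzip-compressed transparently because of the .gz extension.
writer = RecordAdapter("gcs://my-bucket/out/results.records.gz", project="my-project", out=True)
for record in some_records:  # hypothetical iterable of records
    writer.write(record)
writer.close()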
70 changes: 38 additions & 32 deletions flow/record/base.py
@@ -758,6 +758,29 @@ def open_path_or_stream(path: Union[str, Path, BinaryIO], mode: str, clobber: bo
raise ValueError(f"Unsupported path type {path}")


def wrap_in_compression(fp: BinaryIO, mode: str, path: str) -> BinaryIO:
if path.endswith(".gz"):
return gzip.GzipFile(fileobj=fp, mode=mode)
elif path.endswith(".bz2"):
if not HAS_BZ2:
raise RuntimeError("bz2 python module not available")
return bz2.BZ2File(fp, mode)
elif path.endswith(".lz4"):
if not HAS_LZ4:
raise RuntimeError("lz4 python module not available")
return lz4.open(fp, mode)
elif path.endswith((".zstd", ".zst")):
if not HAS_ZSTD:
raise RuntimeError("zstandard python module not available")
if "w" not in mode:
dctx = zstd.ZstdDecompressor()
return dctx.stream_reader(fp)
else:
cctx = zstd.ZstdCompressor()
return cctx.stream_writer(fp)
return fp
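
A quick sketch (illustration only, not part of the change) of what the new helper does for a .gz destination, using an in-memory stream:

import io
from flow.record.base import wrap_in_compression

buf = io.BytesIO()
fp = wrap_in_compression(buf, "wb", "example.records.gz")  # returns a gzip.GzipFile wrapping buf
fp.write(b"hello records")
fp.close()
assert buf.getvalue()[:2] == b"\x1f\x8b"  # gzip magic bytes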


def open_path(path: str, mode: str, clobber: bool = True) -> IO:
"""
Open ``path`` using ``mode`` and returns a file object.
@@ -787,40 +810,18 @@ def open_path(path: str, mode: str, clobber: bool = True) -> IO:
if not is_stdio and not clobber and os.path.exists(path) and out:
raise IOError("Output file {!r} already exists, and clobber=False".format(path))

# check path extension for compression
if path:
if path.endswith(".gz"):
fp = gzip.GzipFile(path, mode)
elif path.endswith(".bz2"):
if not HAS_BZ2:
raise RuntimeError("bz2 python module not available")
fp = bz2.BZ2File(path, mode)
elif path.endswith(".lz4"):
if not HAS_LZ4:
raise RuntimeError("lz4 python module not available")
fp = lz4.open(path, mode)
elif path.endswith((".zstd", ".zst")):
if not HAS_ZSTD:
raise RuntimeError("zstandard python module not available")
if not out:
dctx = zstd.ZstdDecompressor()
fp = dctx.stream_reader(open(path, "rb"))
else:
cctx = zstd.ZstdCompressor()
fp = cctx.stream_writer(open(path, "wb"))

# normal file or stdio for reading or writing
if not fp:
    if is_stdio:
        if binary:
            fp = getattr(sys.stdout, "buffer", sys.stdout) if out else getattr(sys.stdin, "buffer", sys.stdin)
        else:
            fp = sys.stdout if out else sys.stdin
    else:
        fp = io.open(path, mode)
    # check if we are reading a compressed stream
    if not out and binary:
        fp = open_stream(fp, mode)

if is_stdio:
    if binary:
        fp = getattr(sys.stdout, "buffer", sys.stdout) if out else getattr(sys.stdin, "buffer", sys.stdin)
    else:
        fp = sys.stdout if out else sys.stdin
else:
    fp = wrap_in_compression(io.open(path, mode), mode, path)

# check if we are reading a compressed stream
if not out and binary:
    fp = open_stream(fp, mode)
return fp


@@ -863,6 +864,11 @@ def RecordAdapter(
cls_url = p.netloc + p.path
if sub_adapter:
cls_url = sub_adapter + "://" + cls_url

# If the destination path ends with a compression extension, we wrap the fileobj in a transparent compressor.
if out and fileobj is not None:
fileobj = wrap_in_compression(fileobj, "wb", p.path)

if out is False:
if url in ("-", "", None) and fileobj is None:
# For reading stdin, we cannot rely on an extension to know what sort of stream is incoming. Thus, we will
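To make the new fileobj handling concrete: GcsWriter passes its BlobWriter as fileobj, so a destination ending in a compression extension is compressed transparently before the concrete writer adapter is chosen. A sketch with an in-memory buffer standing in for the blob writer (mirroring test_gcs_writer below):

import io
from flow.record import RecordAdapter

buf = io.BytesIO()
# The url ends in .gz, so buf is wrapped in a gzip compressor and the written
# record stream starts with the gzip magic bytes.
adapter = RecordAdapter(url="out/results.records.gz", fileobj=buf, out=True)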
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -49,6 +49,9 @@ geoip = [
avro = [
"fastavro[snappy]",
]
gcs = [
"google-cloud-storage",
]
duckdb = [
"duckdb",
"pytz", # duckdb requires pytz for timezone support
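With this extra in place, the GCS dependency can presumably be installed alongside the package with pip install "flow.record[gcs]" (assuming the published distribution name).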
176 changes: 176 additions & 0 deletions tests/test_gcs_adapter.py
@@ -0,0 +1,176 @@
from __future__ import annotations

import sys
from io import BytesIO
from typing import Any, Generator, Iterator
from unittest.mock import MagicMock, patch

import pytest

from flow.record import Record, RecordAdapter, RecordDescriptor, RecordStreamWriter
from flow.record.base import GZIP_MAGIC


def generate_records(amount) -> Generator[Record, Any, None]:
TestRecordWithFooBar = RecordDescriptor(
"test/record",
[
("string", "name"),
("string", "foo"),
("varint", "idx"),
],
)
for i in range(amount):
yield TestRecordWithFooBar(name=f"record{i}", foo="bar", idx=i)


def clean_up_adapter_import(test_function):
def wrapper(mock_google_sdk):
try:
result = test_function(mock_google_sdk)
finally:
if "flow.record.adapter.gcs" in sys.modules:
del sys.modules["flow.record.adapter.gcs"]
return result

return wrapper


@pytest.fixture
def mock_google_sdk(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]:
with monkeypatch.context() as m:
mock_google_sdk = MagicMock()
m.setitem(sys.modules, "google", mock_google_sdk)
m.setitem(sys.modules, "google.cloud", mock_google_sdk.cloud)
m.setitem(sys.modules, "google.cloud.storage", mock_google_sdk.cloud.storage)
m.setitem(sys.modules, "google.cloud.storage.client", mock_google_sdk.cloud.storage.client)
m.setitem(sys.modules, "google.cloud.storage.fileio", mock_google_sdk.cloud.storage.fileio)

yield mock_google_sdk


@clean_up_adapter_import
def test_gcs_uri_and_path(mock_google_sdk: MagicMock) -> None:
from flow.record.adapter.gcs import GcsReader

mock_client = MagicMock()
mock_google_sdk.cloud.storage.client.Client.return_value = mock_client
adapter_with_glob = RecordAdapter("gcs://test-bucket/path/to/records/*/*.avro", project="test-project")

assert isinstance(adapter_with_glob, GcsReader)

mock_google_sdk.cloud.storage.client.Client.assert_called_with(project="test-project")
mock_client.bucket.assert_called_with("test-bucket")

assert adapter_with_glob.prefix == "path/to/records/"
assert adapter_with_glob.pattern == "path/to/records/*/*.avro"

adapter_without_glob = RecordAdapter("gcs://test-bucket/path/to/records/test-records.rec", project="test-project")
assert isinstance(adapter_without_glob, GcsReader)

assert adapter_without_glob.prefix == "path/to/records/test-records.rec"
assert adapter_without_glob.pattern is None


@clean_up_adapter_import
def test_gcs_reader_glob(mock_google_sdk) -> None:
# Create a mocked record stream
test_records = list(generate_records(10))
mock_blob = BytesIO()
writer = RecordStreamWriter(fp=mock_blob)
for record in test_records:
writer.write(record)
writer.flush()
mock_recordstream = mock_blob.getvalue()
writer.close()

# Create a mocked client that will return the test-bucket
mock_client = MagicMock()
mock_client.bucket.return_value = "test-bucket-returned-from-client"
mock_google_sdk.cloud.storage.client.Client.return_value = mock_client

# Create a mocked instance of the 'Blob' class of google.cloud.storage.fileio
recordsfile_blob_mock = MagicMock()
recordsfile_blob_mock.name = "path/to/records/subfolder/results/tests.records"
recordsfile_blob_mock.data = mock_recordstream
recordsfile_blob_mock.size = len(mock_recordstream)

# As this blob is located in the 'donutselect' folder, it should not match the glob that will be used later
# (which requires /results/ to be present in the path string)
wrong_location_blob = MagicMock()
wrong_location_blob.name = "path/to/records/subfolder/donutselect/tests.records"
wrong_location_blob.size = 0x69
wrong_location_blob.data = b""

# Return one empty file, one file that should match the glob, and one file that shouldn't match the glob
mock_client.list_blobs.return_value = [MagicMock(size=0), recordsfile_blob_mock, wrong_location_blob]

test_read_buf = BytesIO(mock_recordstream)
mock_reader = MagicMock(wraps=test_read_buf, spec=BytesIO)
mock_reader.closed = False
mock_google_sdk.cloud.storage.fileio.BlobReader.return_value = mock_reader
with patch("io.open", MagicMock(return_value=mock_reader)):
adapter = RecordAdapter(
url="gcs://test-bucket/path/to/records/*/results/*.records",
project="test-project",
selector="r.idx >= 5",
)

found_records = list(adapter)
mock_client.bucket.assert_called_with("test-bucket")
mock_client.list_blobs.assert_called_with(
bucket_or_name="test-bucket-returned-from-client",
prefix="path/to/records/",
)

# We expect the GCS Reader to skip over blobs of size 0, as those will inherently not contain records.
# Thus, a BlobReader should only have been initialized once, for the mocked records blob.
mock_google_sdk.cloud.storage.fileio.BlobReader.assert_called_once()

# We expect 5 records rather than 10 because of the selector that we used
assert len(found_records) == 5
for record in found_records:
assert record.foo == "bar"
assert record == test_records[record.idx]

adapter.close()
mock_client.close.assert_called()


@clean_up_adapter_import
def test_gcs_writer(mock_google_sdk) -> None:
from flow.record.adapter.gcs import GcsWriter

test_buf = BytesIO()
mock_writer = MagicMock(wraps=test_buf, spec=BytesIO)
mock_google_sdk.cloud.storage.fileio.BlobWriter.return_value = mock_writer

adapter = RecordAdapter("gcs://test-bucket/test/test.records.gz", project="test-project", out=True)

assert isinstance(adapter, GcsWriter)

# Add mock records
test_records = list(generate_records(10))
for record in test_records:
adapter.write(record)

adapter.flush()
mock_writer.flush.assert_called()

# Grab the bytes before it's too late
written_bytes = test_buf.getvalue()
assert written_bytes.startswith(GZIP_MAGIC)

read_buf = BytesIO(test_buf.getvalue())

# Close the writer and assure the object has been closed
adapter.close()
mock_writer.close.assert_called()
assert test_buf.closed

# Verify that the written record stream is something we can read
reader = RecordAdapter(fileobj=read_buf)
read_records = list(reader)
assert len(read_records) == 10
for idx, record in enumerate(read_records):
assert record == test_records[idx]