diff --git a/flow/record/adapter/gcs.py b/flow/record/adapter/gcs.py
new file mode 100644
index 00000000..376415e9
--- /dev/null
+++ b/flow/record/adapter/gcs.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+import logging
+import re
+from fnmatch import fnmatch
+from typing import Iterator
+
+from google.cloud.storage.client import Client
+from google.cloud.storage.fileio import BlobReader, BlobWriter
+
+from flow.record.adapter import AbstractReader, AbstractWriter
+from flow.record.base import Record, RecordAdapter
+from flow.record.selector import Selector
+
+__usage__ = """
+Google Cloud Storage adapter
+---
+Read usage: rdump gcs://[BUCKET_ID]/path?project=[PROJECT]
+Write usage: rdump -w gcs://[BUCKET_ID]/path?project=[PROJECT]
+
+[BUCKET_ID]: Bucket ID
+[path]: Path to read from or write to, supports glob-pattern matching when reading
+
+Optional arguments:
+    [PROJECT]: Google Cloud Project ID. If not passed, falls back to the default inferred from the environment.
+"""
+
+log = logging.getLogger(__name__)
+
+GLOB_CHARACTERS_RE = r"[\[\]\*\?]"
+
+
+class GcsReader(AbstractReader):
+    def __init__(self, uri: str, *, project: str | None = None, selector: Selector | None = None, **kwargs):
+        self.selector = selector
+        bucket_name, _, path = uri.partition("/")
+        self.gcs = Client(project=project)
+        self.bucket = self.gcs.bucket(bucket_name)
+
+        # GCS doesn't support iterating blobs using a glob pattern, so we have to do that ourselves. To extract
+        # the path prefix from the glob pattern, we have to find the first place where the glob starts.
+        self.prefix, *glob_pattern = re.split(GLOB_CHARACTERS_RE, path)
+        self.pattern = path if glob_pattern else None
+
+    def __iter__(self) -> Iterator[Record]:
+        blobs = self.gcs.list_blobs(bucket_or_name=self.bucket, prefix=self.prefix)
+        for blob in blobs:
+            if blob.size == 0:  # Skip empty files
+                continue
+            if self.pattern and not fnmatch(blob.name, self.pattern):
+                continue
+            blobreader = BlobReader(blob)
+
+            # Hand the file-like object to RecordAdapter so it selects the right adapter by peeking into the stream
+            reader = RecordAdapter(fileobj=blobreader, out=False, selector=self.selector)
+            for record in reader:
+                yield record
+
+    def close(self) -> None:
+        self.gcs.close()
+
+
+class GcsWriter(AbstractWriter):
+    def __init__(self, uri: str, *, project: str | None = None, **kwargs):
+        bucket_name, _, path = uri.partition("/")
+        self.writer = None
+
+        self.gcs = Client(project=project)
+        self.bucket = self.gcs.bucket(bucket_name)
+
+        blob = self.bucket.blob(path)
+        self.writer = BlobWriter(blob, ignore_flush=True)
+        self.adapter = RecordAdapter(url=path, fileobj=self.writer, out=True, **kwargs)
+
+    def write(self, record: Record) -> None:
+        self.adapter.write(record)
+
+    def flush(self) -> None:
+        # The underlying adapter may require flushing
+        self.adapter.flush()
+
+    def close(self) -> None:
+        self.flush()
+        self.adapter.close()
+
+        if self.writer:
+            self.writer.close()
+            self.writer = None
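For context, a minimal usage sketch of the new adapter, both via `rdump` (as documented in `__usage__`) and programmatically through `RecordAdapter`. The bucket, project and blob paths below are placeholders, not values from this change:

```python
from flow.record import RecordAdapter

# Reading: blobs are listed under the derived prefix and matched against the glob pattern.
for record in RecordAdapter("gcs://my-bucket/incident/*.records", project="my-project"):
    print(record)

# Writing: records are streamed into a single blob through a BlobWriter.
writer = RecordAdapter("gcs://my-bucket/incident/out.records", out=True, project="my-project")
```

The equivalent CLI form would be `rdump "gcs://my-bucket/incident/*.records?project=my-project"`.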
diff --git a/flow/record/base.py b/flow/record/base.py
index 6a2625f3..86b3ab5d 100644
--- a/flow/record/base.py
+++ b/flow/record/base.py
@@ -758,6 +758,29 @@ def open_path_or_stream(path: Union[str, Path, BinaryIO], mode: str, clobber: bo
     raise ValueError(f"Unsupported path type {path}")
 
 
+def wrap_in_compression(fp: BinaryIO, mode: str, path: str) -> BinaryIO:
+    if path.endswith(".gz"):
+        return gzip.GzipFile(fileobj=fp, mode=mode)
+    elif path.endswith(".bz2"):
+        if not HAS_BZ2:
+            raise RuntimeError("bz2 python module not available")
+        return bz2.BZ2File(fp, mode)
+    elif path.endswith(".lz4"):
+        if not HAS_LZ4:
+            raise RuntimeError("lz4 python module not available")
+        return lz4.open(fp, mode)
+    elif path.endswith((".zstd", ".zst")):
+        if not HAS_ZSTD:
+            raise RuntimeError("zstandard python module not available")
+        if "w" not in mode:
+            dctx = zstd.ZstdDecompressor()
+            return dctx.stream_reader(fp)
+        else:
+            cctx = zstd.ZstdCompressor()
+            return cctx.stream_writer(fp)
+    return fp
+
+
 def open_path(path: str, mode: str, clobber: bool = True) -> IO:
     """
     Open ``path`` using ``mode`` and returns a file object.
@@ -787,40 +810,18 @@ def open_path(path: str, mode: str, clobber: bool = True) -> IO:
     if not is_stdio and not clobber and os.path.exists(path) and out:
         raise IOError("Output file {!r} already exists, and clobber=False".format(path))
 
-    # check path extension for compression
-    if path:
-        if path.endswith(".gz"):
-            fp = gzip.GzipFile(path, mode)
-        elif path.endswith(".bz2"):
-            if not HAS_BZ2:
-                raise RuntimeError("bz2 python module not available")
-            fp = bz2.BZ2File(path, mode)
-        elif path.endswith(".lz4"):
-            if not HAS_LZ4:
-                raise RuntimeError("lz4 python module not available")
-            fp = lz4.open(path, mode)
-        elif path.endswith((".zstd", ".zst")):
-            if not HAS_ZSTD:
-                raise RuntimeError("zstandard python module not available")
-            if not out:
-                dctx = zstd.ZstdDecompressor()
-                fp = dctx.stream_reader(open(path, "rb"))
-            else:
-                cctx = zstd.ZstdCompressor()
-                fp = cctx.stream_writer(open(path, "wb"))
-    # normal file or stdio for reading or writing
-    if not fp:
-        if is_stdio:
-            if binary:
-                fp = getattr(sys.stdout, "buffer", sys.stdout) if out else getattr(sys.stdin, "buffer", sys.stdin)
-            else:
-                fp = sys.stdout if out else sys.stdin
+    if is_stdio:
+        if binary:
+            fp = getattr(sys.stdout, "buffer", sys.stdout) if out else getattr(sys.stdin, "buffer", sys.stdin)
         else:
-            fp = io.open(path, mode)
-    # check if we are reading a compressed stream
-    if not out and binary:
-        fp = open_stream(fp, mode)
+            fp = sys.stdout if out else sys.stdin
+    else:
+        fp = wrap_in_compression(io.open(path, mode), mode, path)
+
+    # check if we are reading a compressed stream
+    if not out and binary:
+        fp = open_stream(fp, mode)
 
     return fp
@@ -863,6 +864,11 @@
         cls_url = p.netloc + p.path
         if sub_adapter:
            cls_url = sub_adapter + "://" + cls_url
+
+    # If the destination path ends with a compression extension, we wrap the fileobj in a transparent compressor.
+    if out and fileobj is not None:
+        fileobj = wrap_in_compression(fileobj, "wb", p.path)
+
     if out is False:
         if url in ("-", "", None) and fileobj is None:
             # For reading stdin, we cannot rely on an extension to know what sort of stream is incoming. Thus, we will
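Because `RecordAdapter` now routes `out=True` destinations with a compression extension through `wrap_in_compression`, the same transparent wrapping works when writing to an arbitrary file-like object, which is exactly what the GCS writer relies on. A small sketch of that behaviour (the path here is only an extension hint; no file is opened):

```python
from io import BytesIO

from flow.record import RecordAdapter

buf = BytesIO()
# The ".gz" suffix makes the adapter wrap `buf` in a GzipFile before the record stream is written.
adapter = RecordAdapter("some/prefix/output.records.gz", fileobj=buf, out=True)
```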
diff --git a/pyproject.toml b/pyproject.toml
index 7998e85c..3ad78e0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,9 @@ geoip = [
 avro = [
     "fastavro[snappy]",
 ]
+gcs = [
+    "google-cloud-storage",
+]
 duckdb = [
     "duckdb",
     "pytz",  # duckdb requires pytz for timezone support
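The new `gcs` extra keeps `google-cloud-storage` an optional dependency; assuming the current distribution name is kept, it would be pulled in with something like `pip install "flow.record[gcs]"`.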
diff --git a/tests/test_gcs_adapter.py b/tests/test_gcs_adapter.py
new file mode 100644
index 00000000..ede57ed5
--- /dev/null
+++ b/tests/test_gcs_adapter.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+import sys
+from io import BytesIO
+from typing import Any, Generator, Iterator
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from flow.record import Record, RecordAdapter, RecordDescriptor, RecordStreamWriter
+from flow.record.base import GZIP_MAGIC
+
+
+def generate_records(amount) -> Generator[Record, Any, None]:
+    TestRecordWithFooBar = RecordDescriptor(
+        "test/record",
+        [
+            ("string", "name"),
+            ("string", "foo"),
+            ("varint", "idx"),
+        ],
+    )
+    for i in range(amount):
+        yield TestRecordWithFooBar(name=f"record{i}", foo="bar", idx=i)
+
+
+def clean_up_adapter_import(test_function):
+    def wrapper(mock_google_sdk):
+        try:
+            result = test_function(mock_google_sdk)
+        finally:
+            if "flow.record.adapter.gcs" in sys.modules:
+                del sys.modules["flow.record.adapter.gcs"]
+        return result
+
+    return wrapper
+
+
+@pytest.fixture
+def mock_google_sdk(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]:
+    with monkeypatch.context() as m:
+        mock_google_sdk = MagicMock()
+        m.setitem(sys.modules, "google", mock_google_sdk)
+        m.setitem(sys.modules, "google.cloud", mock_google_sdk.cloud)
+        m.setitem(sys.modules, "google.cloud.storage", mock_google_sdk.cloud.storage)
+        m.setitem(sys.modules, "google.cloud.storage.client", mock_google_sdk.cloud.storage.client)
+        m.setitem(sys.modules, "google.cloud.storage.fileio", mock_google_sdk.cloud.storage.fileio)
+
+        yield mock_google_sdk
+
+
+@clean_up_adapter_import
+def test_gcs_uri_and_path(mock_google_sdk: MagicMock) -> None:
+    from flow.record.adapter.gcs import GcsReader
+
+    mock_client = MagicMock()
+    mock_google_sdk.cloud.storage.client.Client.return_value = mock_client
+    adapter_with_glob = RecordAdapter("gcs://test-bucket/path/to/records/*/*.avro", project="test-project")
+
+    assert isinstance(adapter_with_glob, GcsReader)
+
+    mock_google_sdk.cloud.storage.client.Client.assert_called_with(project="test-project")
+    mock_client.bucket.assert_called_with("test-bucket")
+
+    assert adapter_with_glob.prefix == "path/to/records/"
+    assert adapter_with_glob.pattern == "path/to/records/*/*.avro"
+
+    adapter_without_glob = RecordAdapter("gcs://test-bucket/path/to/records/test-records.rec", project="test-project")
+    assert isinstance(adapter_without_glob, GcsReader)
+
+    assert adapter_without_glob.prefix == "path/to/records/test-records.rec"
+    assert adapter_without_glob.pattern is None
+
+
+@clean_up_adapter_import
+def test_gcs_reader_glob(mock_google_sdk) -> None:
+    # Create a mocked record stream
+    test_records = list(generate_records(10))
+    mock_blob = BytesIO()
+    writer = RecordStreamWriter(fp=mock_blob)
+    for record in test_records:
+        writer.write(record)
+    writer.flush()
+    mock_recordstream = mock_blob.getvalue()
+    writer.close()
+
+    # Create a mocked client that will return the test-bucket
+    mock_client = MagicMock()
+    mock_client.bucket.return_value = "test-bucket-returned-from-client"
+    mock_google_sdk.cloud.storage.client.Client.return_value = mock_client
+
+    # Create a mocked instance of the 'Blob' class of google.cloud.storage.fileio
+    recordsfile_blob_mock = MagicMock()
+    recordsfile_blob_mock.name = "path/to/records/subfolder/results/tests.records"
+    recordsfile_blob_mock.data = mock_recordstream
+    recordsfile_blob_mock.size = len(mock_recordstream)
+
+    # As this blob is located in the 'donutselect' folder, it should not match with the glob that will be used later
+    # (which requires /results/ to be present in the path string)
+    wrong_location_blob = MagicMock()
+    wrong_location_blob.name = "path/to/records/subfolder/donutselect/tests.records"
+    wrong_location_blob.size = 0x69
+    wrong_location_blob.data = b""
+
+    # Return one empty file, one file that should match the glob, and one file that shouldn't match the glob
+    mock_client.list_blobs.return_value = [MagicMock(size=0), recordsfile_blob_mock, wrong_location_blob]
+
+    test_read_buf = BytesIO(mock_recordstream)
+    mock_reader = MagicMock(wraps=test_read_buf, spec=BytesIO)
+    mock_reader.closed = False
+    mock_google_sdk.cloud.storage.fileio.BlobReader.return_value = mock_reader
+    with patch("io.open", MagicMock(return_value=mock_reader)):
+        adapter = RecordAdapter(
+            url="gcs://test-bucket/path/to/records/*/results/*.records",
+            project="test-project",
+            selector="r.idx >= 5",
+        )
+
+        found_records = list(adapter)
+
+    mock_client.bucket.assert_called_with("test-bucket")
+    mock_client.list_blobs.assert_called_with(
+        bucket_or_name="test-bucket-returned-from-client",
+        prefix="path/to/records/",
+    )
+
+    # We expect the GCS Reader to skip over blobs of size 0, as those will inherently not contain records.
+    # Thus, a BlobReader should only have been initialized once, for the mocked records blob.
+    mock_google_sdk.cloud.storage.fileio.BlobReader.assert_called_once()
+
+    # We expect 5 records rather than 10 because of the selector that we used
+    assert len(found_records) == 5
+    for record in found_records:
+        assert record.foo == "bar"
+        assert record == test_records[record.idx]
+
+    adapter.close()
+    mock_client.close.assert_called()
+
+
+@clean_up_adapter_import
+def test_gcs_writer(mock_google_sdk) -> None:
+    from flow.record.adapter.gcs import GcsWriter
+
+    test_buf = BytesIO()
+    mock_writer = MagicMock(wraps=test_buf, spec=BytesIO)
+    mock_google_sdk.cloud.storage.fileio.BlobWriter.return_value = mock_writer
+
+    adapter = RecordAdapter("gcs://test-bucket/test/test.records.gz", project="test-project", out=True)
+
+    assert isinstance(adapter, GcsWriter)
+
+    # Add mock records
+    test_records = list(generate_records(10))
+    for record in test_records:
+        adapter.write(record)
+
+    adapter.flush()
+    mock_writer.flush.assert_called()
+
+    # Grab the bytes before it's too late
+    written_bytes = test_buf.getvalue()
+    assert written_bytes.startswith(GZIP_MAGIC)
+
+    read_buf = BytesIO(test_buf.getvalue())
+
+    # Close the writer and assure the object has been closed
+    adapter.close()
+    mock_writer.close.assert_called()
+    assert test_buf.closed
+
+    # Verify if the written record stream is something we can read
+    reader = RecordAdapter(fileobj=read_buf)
+    read_records = list(reader)
+    assert len(read_records) == 10
+    for idx, record in enumerate(read_records):
+        assert record == test_records[idx]
diff --git a/tests/test_record_adapter.py b/tests/test_record_adapter.py
index 310e3cb2..eabef0ee 100644
--- a/tests/test_record_adapter.py
+++ b/tests/test_record_adapter.py
@@ -1,6 +1,7 @@
 import datetime
 import platform
 import sys
+from gzip import GzipFile
 from io import BytesIO
 
 import pytest
@@ -22,6 +23,8 @@
     HAS_LZ4,
     HAS_ZSTD,
     LZ4_MAGIC,
+    RECORDSTREAM_MAGIC,
+    RECORDSTREAM_MAGIC_DEPTH,
     ZSTD_MAGIC,
 )
 from flow.record.selector import CompiledSelector, Selector
@@ -476,12 +479,16 @@ def test_csvfilereader(tmp_path):
         assert rec.count == "2"
 
 
-def test_file_like_writer_reader() -> None:
+@pytest.mark.parametrize("use_gzip", [False, True])
+def test_file_like_writer_reader(use_gzip: bool) -> None:
     test_buf = BytesIO()
 
-    adapter = RecordAdapter(fileobj=test_buf, out=True)
+    url = "nonexistent/path/my_records.gz" if use_gzip else None
+    adapter = RecordAdapter(url, fileobj=test_buf, out=True)
 
     assert isinstance(adapter, StreamWriter)
+    if use_gzip:
+        assert isinstance(adapter.fp, GzipFile)
 
     # Add mock records
     test_records = list(generate_records(10))
@@ -491,7 +498,14 @@ def test_file_like_writer_reader() -> None:
     adapter.flush()
 
     # Grab the bytes before closing the BytesIO object.
-    read_buf = BytesIO(test_buf.getvalue())
+    written_bytes = test_buf.getvalue()
+
+    if use_gzip:
+        assert written_bytes.startswith(GZIP_MAGIC)
+    else:
+        assert written_bytes[:RECORDSTREAM_MAGIC_DEPTH].endswith(RECORDSTREAM_MAGIC)
+
+    read_buf = BytesIO(written_bytes)
 
     # Close the writer and assure the object has been closed
     adapter.close()
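Taken together, the writer test above corresponds to an end-to-end flow like the sketch below; bucket and project names are placeholders, and the read-back works because adapter selection peeks at the stream, so the gzip layer is detected from the content rather than the blob name:

```python
from flow.record import RecordAdapter, RecordDescriptor

HostRecord = RecordDescriptor("example/host", [("string", "hostname")])

# Write gzip-compressed records straight into a GCS blob.
writer = RecordAdapter("gcs://my-bucket/exports/hosts.records.gz", out=True, project="my-project")
writer.write(HostRecord(hostname="web01"))
writer.close()

# Read them back; the glob only selects blobs, compression is handled transparently.
for record in RecordAdapter("gcs://my-bucket/exports/*.records.gz", project="my-project"):
    print(record.hostname)
```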