Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/flyte/_code_bundle/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import logging
import os
import pathlib
import random
import sqlite3
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Type

Expand All @@ -28,6 +31,61 @@
_pickled_file_extension = ".pkl.gz"
_tar_file_extension = ".tar.gz"

_BUNDLE_CACHE_DB = Path.home() / ".flyte" / "cache" / "bundles.db"
_BUNDLE_CACHE_TTL_DAYS = 30


def _get_bundle_cache_db() -> sqlite3.Connection:
    """Return a connection to the SQLite bundle cache, creating the DB file and table on first use."""
    # Ensure ~/.flyte/cache exists before sqlite tries to create the file.
    _BUNDLE_CACHE_DB.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(_BUNDLE_CACHE_DB))
    schema = (
        "CREATE TABLE IF NOT EXISTS bundle_cache "
        "(digest TEXT PRIMARY KEY, hash_digest TEXT NOT NULL, remote_path TEXT NOT NULL, "
        "created_at REAL NOT NULL)"
    )
    connection.execute(schema)
    return connection


def _read_bundle_cache(digest: str) -> tuple[str, str] | None:
    """Look up a previously uploaded bundle by its file digest. Returns (hash_digest, remote_path) or None."""
    try:
        conn = _get_bundle_cache_db()
        try:
            cutoff = time.time() - _BUNDLE_CACHE_TTL_DAYS * 86400
            query = "SELECT hash_digest, remote_path FROM bundle_cache WHERE digest = ? AND created_at > ?"
            row = conn.execute(query, (digest, cutoff)).fetchone()
            # Opportunistically prune expired rows on ~5% of reads so no
            # single read pays the cleanup cost every time.
            if random.random() < 0.05:
                conn.execute("DELETE FROM bundle_cache WHERE created_at <= ?", (cutoff,))
                conn.commit()
            if row is not None:
                return row[0], row[1]
        finally:
            conn.close()
    except (OSError, sqlite3.Error) as e:
        # Cache failures are non-fatal: fall through to a normal upload.
        logger.debug(f"Failed to read bundle cache: {e}")
    return None


def _write_bundle_cache(digest: str, hash_digest: str, remote_path: str) -> None:
    """Persist a successfully uploaded bundle to the SQLite cache."""
    try:
        conn = _get_bundle_cache_db()
        try:
            # `with conn` commits on success and rolls back on error.
            with conn:
                statement = (
                    "INSERT OR REPLACE INTO bundle_cache (digest, hash_digest, remote_path, created_at) "
                    "VALUES (?, ?, ?, ?)"
                )
                conn.execute(statement, (digest, hash_digest, remote_path, time.time()))
        finally:
            conn.close()
    except (OSError, sqlite3.Error) as e:
        # Best-effort: a failed cache write must not break the upload path.
        logger.debug(f"Failed to write bundle cache: {e}")


class _PklCache:
_pkl_cache: ClassVar[AsyncLRUCache[str, str]] = AsyncLRUCache[str, str](maxsize=100)
Expand Down Expand Up @@ -141,8 +199,6 @@ async def build_code_bundle(
if copy_style == "none":
raise ValueError("If copy_style is 'none', just don't make a code bundle")

status.step("Bundling code...")
logger.debug("Building code bundle.")
from flyte.remote import upload_file

if not ignore:
Expand All @@ -163,6 +219,16 @@ async def build_code_bundle(
if logger.getEffectiveLevel() <= logging.INFO:
print_ls_tree(from_dir, files)

# Check persistent cache before creating the tar bundle to avoid unnecessary work
if not dryrun:
cached = _read_bundle_cache(digest)
if cached:
hash_digest, remote_path = cached
status.success("Code bundle found in cache, skipping upload")
logger.debug(f"Code bundle cache hit: {remote_path}")
return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest, files=files)

status.step("Bundling code...")
logger.debug("Building code bundle.")
with tempfile.TemporaryDirectory() as tmp_dir:
bundle_path, tar_size, archive_size = create_bundle(
Expand All @@ -173,6 +239,7 @@ async def build_code_bundle(
status.step("Uploading code bundle...")
hash_digest, remote_path = await upload_file.aio(bundle_path)
logger.debug(f"Code bundle uploaded to {remote_path}")
_write_bundle_cache(digest, hash_digest, remote_path)
else:
if copy_bundle_to:
remote_path = str(copy_bundle_to / bundle_path.name)
Expand Down Expand Up @@ -218,6 +285,15 @@ async def build_code_bundle_from_relative_paths(
if logger.getEffectiveLevel() <= logging.INFO:
print_ls_tree(from_dir, files)

# Check persistent cache before creating the tar bundle to avoid unnecessary work
if not dryrun:
cached = _read_bundle_cache(digest)
if cached:
hash_digest, remote_path = cached
status.success("Code bundle found in cache, skipping upload")
logger.debug(f"Code bundle cache hit: {remote_path}")
return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest, files=files)

logger.debug("Building code bundle.")
with tempfile.TemporaryDirectory() as tmp_dir:
bundle_path, tar_size, archive_size = create_bundle(from_dir, pathlib.Path(tmp_dir), files, digest)
Expand All @@ -226,6 +302,7 @@ async def build_code_bundle_from_relative_paths(
status.step("Uploading code bundle...")
hash_digest, remote_path = await upload_file.aio(bundle_path)
logger.debug(f"Code bundle uploaded to {remote_path}")
_write_bundle_cache(digest, hash_digest, remote_path)
else:
remote_path = "na"
if copy_bundle_to:
Expand Down
8 changes: 7 additions & 1 deletion src/flyte/_internal/imagebuild/docker_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ImageChecker,
LocalDockerCommandImageChecker,
LocalPodmanCommandImageChecker,
PersistentCacheImageChecker,
)
from flyte._internal.imagebuild.utils import (
copy_files_to_context,
Expand Down Expand Up @@ -575,7 +576,12 @@ class DockerImageBuilder(ImageBuilder):

def get_checkers(self) -> Optional[typing.List[typing.Type[ImageChecker]]]:
    """Return image-existence checkers in priority order.

    The persistent SQLite cache is consulted first (near-zero cost), then the
    local docker/podman CLIs, and finally the registry HTTP API.
    """
    # Can get a public token for docker.io but ghcr requires a pat, so harder to get the manifest anonymously
    return [
        PersistentCacheImageChecker,
        LocalDockerCommandImageChecker,
        LocalPodmanCommandImageChecker,
        DockerAPIImageChecker,
    ]

async def build_image(
self, image: Image, dry_run: bool = False, wait: bool = True, force: bool = False
Expand Down
100 changes: 97 additions & 3 deletions src/flyte/_internal/imagebuild/image_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import asyncio
import hashlib
import json
import random
import sqlite3
import time
import typing
from importlib.metadata import entry_points
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Tuple

from async_lru import alru_cache
Expand All @@ -15,6 +20,9 @@
from flyte._logging import logger
from flyte._status import status

_IMAGE_CACHE_DB = Path.home() / ".flyte" / "cache" / "images.db"
_IMAGE_CACHE_TTL_DAYS = 30

if TYPE_CHECKING:
from flyte._build import ImageBuild

Expand All @@ -36,7 +44,15 @@ class ImageChecker(Protocol):
@classmethod
async def image_exists(
    cls, repository: str, tag: str, arch: Tuple[Architecture, ...] = ("linux/amd64",)
) -> Optional[str]:
    """
    Check whether an image exists in a registry or cache.

    Returns the image URI if found, or None if the image definitively does not exist.
    Raise an exception if existence cannot be determined (e.g. cache miss, network failure)
    so the next checker in the chain gets a chance.
    """
    ...


class DockerAPIImageChecker(ImageChecker):
Expand Down Expand Up @@ -93,6 +109,79 @@ async def image_exists(
return None


def _cache_key(repository: str, tag: str, arch: Tuple[str, ...]) -> str:
"""Return a stable cache key for an image."""
raw = f"{repository}:{tag}:{','.join(sorted(arch))}"
return hashlib.sha256(raw.encode()).hexdigest()


def _get_cache_db() -> sqlite3.Connection:
    """Return a connection to the SQLite image cache, creating the DB file and table on first use."""
    # Ensure ~/.flyte/cache exists before sqlite tries to create the file.
    _IMAGE_CACHE_DB.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(_IMAGE_CACHE_DB))
    schema = (
        "CREATE TABLE IF NOT EXISTS image_cache "
        "(key TEXT PRIMARY KEY, image_uri TEXT NOT NULL, "
        "created_at REAL NOT NULL)"
    )
    connection.execute(schema)
    return connection


def _read_image_cache(repository: str, tag: str, arch: Tuple[str, ...]) -> Optional[str]:
    """Look up a previously verified image URI by repository, tag, and arch. Returns image_uri or None."""
    try:
        conn = _get_cache_db()
        try:
            cutoff = time.time() - _IMAGE_CACHE_TTL_DAYS * 86400
            query = "SELECT image_uri FROM image_cache WHERE key = ? AND created_at > ?"
            row = conn.execute(query, (_cache_key(repository, tag, arch), cutoff)).fetchone()
            # Opportunistically prune expired rows on ~5% of reads so no
            # single read pays the cleanup cost every time.
            if random.random() < 0.05:
                conn.execute("DELETE FROM image_cache WHERE created_at <= ?", (cutoff,))
                conn.commit()
            if row is not None:
                return row[0]
        finally:
            conn.close()
    except (OSError, sqlite3.Error) as e:
        # Cache failures are non-fatal: treat as a miss.
        logger.debug(f"Failed to read image cache: {e}")
    return None


def _write_image_cache(repository: str, tag: str, arch: Tuple[str, ...], image_uri: str) -> None:
    """Persist a verified image URI to the SQLite cache."""
    try:
        conn = _get_cache_db()
        try:
            # `with conn` commits on success and rolls back on error.
            with conn:
                conn.execute(
                    "INSERT OR REPLACE INTO image_cache (key, image_uri, created_at) VALUES (?, ?, ?)",
                    (_cache_key(repository, tag, arch), image_uri, time.time()),
                )
        finally:
            conn.close()
    except (OSError, sqlite3.Error) as e:
        # Best-effort: a failed cache write must not break the build path.
        logger.debug(f"Failed to write image cache: {e}")


class PersistentCacheImageChecker(ImageChecker):
    """Check if image was previously verified and cached in SQLite (~0ms)."""

    @classmethod
    async def image_exists(
        cls, repository: str, tag: str, arch: Tuple[Architecture, ...] = ("linux/amd64",)
    ) -> Optional[str]:
        uri = _read_image_cache(repository, tag, arch)
        if not uri:
            # Cache miss — raise so the next checker in the chain gets a chance.
            # Returning None would mean "image definitely doesn't exist".
            raise LookupError(f"Image {repository}:{tag} not found in persistent cache")
        logger.debug(f"Image {uri} found in persistent cache")
        return uri


class LocalDockerCommandImageChecker(ImageChecker):
command_name: ClassVar[str] = "docker"

Expand Down Expand Up @@ -170,12 +259,17 @@ async def image_exists(image: Image) -> Optional[str]:
image_uri = await checker.image_exists(repository, tag, tuple(image.platform))
if image_uri:
logger.debug(f"Image {image_uri} in registry")
return image_uri
# Persist to disk so future process invocations skip network checks
if checker is not PersistentCacheImageChecker:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the problems is that the persistent cache never invalidates. If an image is deleted from the registry, the cache will still say it exists, and the build will be skipped.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but it makes the UX much better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can solve this by having a very short TTL on the cache — this is why I am suggesting using SQLite. Anyway, the data is tiny, and one row is enough.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Short TTL sounds good to me. I'll update it to use sqlite

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Added a TTL and a cache for the code bundle.

_write_image_cache(repository, tag, tuple(image.platform), image_uri)
return image_uri
# Checker ran successfully and returned None — image not found
return None
except Exception as e:
logger.debug(f"Error checking image existence with {checker.__name__}: {e}")
continue

# If all checkers fail, then assume the image exists. This is current flytekit behavior
# All checkers raised exceptions (e.g. network failures) — assume image exists
status.info(f"All checkers failed to check existence of {image.uri}, assuming it exists")
return image.uri

Expand Down
45 changes: 39 additions & 6 deletions tests/flyte/imagebuild/test_image_build_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,59 @@
DockerAPIImageChecker,
ImageBuildEngine,
LocalDockerCommandImageChecker,
PersistentCacheImageChecker,
)


@mock.patch("flyte._internal.imagebuild.image_builder.DockerAPIImageChecker.image_exists")
@mock.patch("flyte._internal.imagebuild.image_builder.LocalDockerCommandImageChecker.image_exists")
@mock.patch("flyte._internal.imagebuild.image_builder.PersistentCacheImageChecker.image_exists")
@pytest.mark.asyncio
async def test_cached(mock_checker_cli, mock_checker_api):
# Simulate that the image exists locally
mock_checker_cli.return_value = True
async def test_cached(mock_checker_cache, mock_checker_cli, mock_checker_api):
# Simulate that the image exists via persistent cache
mock_checker_cache.return_value = True

img = Image.from_debian_base()
await ImageBuildEngine.image_exists(img)
await ImageBuildEngine.image_exists(img)

# The local checker should be called once, and its result cached
mock_checker_cli.assert_called_once()
# The API checker should not be called at all
# The persistent cache checker should be called once, and its result cached by alru_cache
mock_checker_cache.assert_called_once()
# All other checkers should not be called
mock_checker_cli.assert_not_called()
mock_checker_api.assert_not_called()


def test_persistent_cache_write_and_read(tmp_path, monkeypatch):
    """PersistentCacheImageChecker reads back what _write_image_cache wrote.

    Note: on a cache miss the checker *raises* LookupError (so the next checker
    in the chain gets a chance); returning None would mean the image definitely
    does not exist. The original version of this test asserted ``result is None``
    on a miss, which contradicts the implementation and its docstring.
    """
    import asyncio

    import flyte._internal.imagebuild.image_builder as ib

    # Redirect the cache DB to a temp dir so the test never touches ~/.flyte.
    monkeypatch.setattr(ib, "_IMAGE_CACHE_DB", tmp_path / "images.db")

    # Initially nothing cached — a miss raises so the chain falls through.
    with pytest.raises(LookupError):
        asyncio.run(PersistentCacheImageChecker.image_exists("myrepo", "v1.0", ("linux/amd64",)))

    # Write to cache
    ib._write_image_cache("myrepo", "v1.0", ("linux/amd64",), "myrepo:v1.0")

    # Now it should be found
    result = asyncio.run(PersistentCacheImageChecker.image_exists("myrepo", "v1.0", ("linux/amd64",)))
    assert result == "myrepo:v1.0"

    # Different arch hashes to a different key, so it misses as well.
    with pytest.raises(LookupError):
        asyncio.run(PersistentCacheImageChecker.image_exists("myrepo", "v1.0", ("linux/arm64",)))


@mock.patch("flyte._internal.imagebuild.image_builder.ImageBuildEngine._get_builder")
@mock.patch("flyte._internal.imagebuild.image_builder.ImageBuildEngine.image_exists", new_callable=mock.AsyncMock)
@pytest.mark.asyncio
Expand Down
Loading