diff --git a/plugins/hitl/src/flyteplugins/hitl/_event.py b/plugins/hitl/src/flyteplugins/hitl/_event.py index e293cd35a..5a7f97820 100644 --- a/plugins/hitl/src/flyteplugins/hitl/_event.py +++ b/plugins/hitl/src/flyteplugins/hitl/_event.py @@ -346,6 +346,7 @@ async def show_form(form_url: str, api_url: str, curl_body: str, type_name: str) await flyte.report.flush.aio() return html_report + @flyte.trace async def wait_for_input_event( name: str, diff --git a/src/flyte/_code_bundle/bundle.py b/src/flyte/_code_bundle/bundle.py index 13fd109d7..02543d3ab 100644 --- a/src/flyte/_code_bundle/bundle.py +++ b/src/flyte/_code_bundle/bundle.py @@ -2,10 +2,14 @@ import asyncio import gzip +import hashlib import logging import os import pathlib +import random +import sqlite3 import tempfile +import time from pathlib import Path from typing import TYPE_CHECKING, ClassVar, Type @@ -28,6 +32,56 @@ _pickled_file_extension = ".pkl.gz" _tar_file_extension = ".tar.gz" +_BUNDLE_CACHE_TTL_DAYS = 1 + + +def _scoped_digest(digest: str) -> str: + """Return a digest scoped to the current endpoint/project/domain.""" + from flyte._persistence._db import _cache_scope + + raw = f"{_cache_scope()}:{digest}" + return hashlib.sha256(raw.encode()).hexdigest() + + +def _read_bundle_cache(digest: str) -> tuple[str, str] | None: + """Look up a previously uploaded bundle by its file digest. Returns (hash_digest, remote_path) or None.""" + from flyte._persistence._db import LocalDB + + try: + conn = LocalDB.get_sync() + cutoff = time.time() - _BUNDLE_CACHE_TTL_DAYS * 86400 + row = conn.execute( + "SELECT hash_digest, remote_path FROM bundle_cache WHERE digest = ? 
AND created_at > ?", + (_scoped_digest(digest), cutoff), + ).fetchone() + # Prune expired entries ~5% of the time to avoid doing it on every read + if random.random() < 0.05: + with LocalDB._write_lock: + conn.execute("DELETE FROM bundle_cache WHERE created_at <= ?", (cutoff,)) + conn.commit() + if row: + return row[0], row[1] + except (OSError, sqlite3.Error) as e: + logger.debug(f"Failed to read bundle cache: {e}") + return None + + +def _write_bundle_cache(digest: str, hash_digest: str, remote_path: str) -> None: + """Persist a successfully uploaded bundle to the SQLite cache.""" + from flyte._persistence._db import LocalDB + + try: + conn = LocalDB.get_sync() + with LocalDB._write_lock: + conn.execute( + "INSERT OR REPLACE INTO bundle_cache (digest, hash_digest, remote_path, created_at) " + "VALUES (?, ?, ?, ?)", + (_scoped_digest(digest), hash_digest, remote_path, time.time()), + ) + conn.commit() + except (OSError, sqlite3.Error) as e: + logger.debug(f"Failed to write bundle cache: {e}") + class _PklCache: _pkl_cache: ClassVar[AsyncLRUCache[str, str]] = AsyncLRUCache[str, str](maxsize=100) @@ -125,6 +179,7 @@ async def build_code_bundle( dryrun: bool = False, copy_bundle_to: pathlib.Path | None = None, copy_style: CopyFiles = "loaded_modules", + skip_cache: bool = False, ) -> CodeBundle: """ Build the code bundle for the current environment. @@ -135,14 +190,13 @@ async def build_code_bundle( :param dryrun: If dryrun is enabled, files will not be uploaded to the control plane. :param copy_bundle_to: If set, the bundle will be copied to this path. This is used for testing purposes. :param copy_style: What to put into the tarball. (either all, or loaded_modules. if none, skip this function) + :param skip_cache: If true, skip the persistent SQLite cache lookup and always rebuild/re-upload. :return: The code bundle, which contains the path where the code was zipped to. 
""" if copy_style == "none": raise ValueError("If copy_style is 'none', just don't make a code bundle") - status.step("Bundling code...") - logger.debug("Building code bundle.") from flyte.remote import upload_file if not ignore: @@ -163,6 +217,16 @@ async def build_code_bundle( if logger.getEffectiveLevel() <= logging.INFO: print_ls_tree(from_dir, files) + # Check persistent cache before creating the tar bundle to avoid unnecessary work + if not dryrun and not skip_cache: + cached = _read_bundle_cache(digest) + if cached: + hash_digest, remote_path = cached + status.success("Code bundle found in cache, skipping upload") + logger.debug(f"Code bundle cache hit: {remote_path}") + return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest, files=files) + + status.step("Bundling code...") logger.debug("Building code bundle.") with tempfile.TemporaryDirectory() as tmp_dir: bundle_path, tar_size, archive_size = create_bundle( @@ -173,6 +237,7 @@ async def build_code_bundle( status.step("Uploading code bundle...") hash_digest, remote_path = await upload_file.aio(bundle_path) logger.debug(f"Code bundle uploaded to {remote_path}") + _write_bundle_cache(digest, hash_digest, remote_path) else: if copy_bundle_to: remote_path = str(copy_bundle_to / bundle_path.name) @@ -198,6 +263,7 @@ async def build_code_bundle_from_relative_paths( extract_dir: str = ".", dryrun: bool = False, copy_bundle_to: pathlib.Path | None = None, + skip_cache: bool = False, ) -> CodeBundle: """ Build a code bundle from a list of relative paths. @@ -207,6 +273,6 @@ async def build_code_bundle_from_relative_paths( working directory. :param dryrun: If dryrun is enabled, files will not be uploaded to the control plane. :param copy_bundle_to: If set, the bundle will be copied to this path. This is used for testing purposes. + :param skip_cache: If true, skip the persistent SQLite cache lookup and always rebuild/re-upload. 
:return: The code bundle, which contains the path where the code was zipped to. """ - status.step("Bundling code...") @@ -218,6 +284,16 @@ async def build_code_bundle_from_relative_paths( if logger.getEffectiveLevel() <= logging.INFO: print_ls_tree(from_dir, files) + # Check persistent cache before creating the tar bundle to avoid unnecessary work + if not dryrun and not skip_cache: + cached = _read_bundle_cache(digest) + if cached: + hash_digest, remote_path = cached + status.success("Code bundle found in cache, skipping upload") + logger.debug(f"Code bundle cache hit: {remote_path}") + return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest, files=files) + + status.step("Bundling code...") logger.debug("Building code bundle.") with tempfile.TemporaryDirectory() as tmp_dir: bundle_path, tar_size, archive_size = create_bundle(from_dir, pathlib.Path(tmp_dir), files, digest) @@ -226,6 +302,7 @@ async def build_code_bundle_from_relative_paths( status.step("Uploading code bundle...") hash_digest, remote_path = await upload_file.aio(bundle_path) logger.debug(f"Code bundle uploaded to {remote_path}") + _write_bundle_cache(digest, hash_digest, remote_path) else: remote_path = "na" if copy_bundle_to: diff --git a/src/flyte/_internal/imagebuild/docker_builder.py b/src/flyte/_internal/imagebuild/docker_builder.py index ce07622dc..b8d1f9ea1 100644 --- a/src/flyte/_internal/imagebuild/docker_builder.py +++ b/src/flyte/_internal/imagebuild/docker_builder.py @@ -40,6 +40,7 @@ ImageChecker, LocalDockerCommandImageChecker, LocalPodmanCommandImageChecker, + PersistentCacheImageChecker, ) from flyte._internal.imagebuild.utils import ( copy_files_to_context, @@ -583,7 +584,12 @@ class DockerImageBuilder(ImageBuilder): def get_checkers(self) -> Optional[typing.List[typing.Type[ImageChecker]]]: # Can get a public token for docker.io but ghcr requires a pat, so harder to get the manifest anonymously - return [LocalDockerCommandImageChecker, LocalPodmanCommandImageChecker, 
DockerAPIImageChecker] + return [ + PersistentCacheImageChecker, + LocalDockerCommandImageChecker, + LocalPodmanCommandImageChecker, + DockerAPIImageChecker, + ] async def build_image( self, image: Image, dry_run: bool = False, wait: bool = True, force: bool = False diff --git a/src/flyte/_internal/imagebuild/image_builder.py b/src/flyte/_internal/imagebuild/image_builder.py index 8eeb9562b..450a5649c 100644 --- a/src/flyte/_internal/imagebuild/image_builder.py +++ b/src/flyte/_internal/imagebuild/image_builder.py @@ -1,7 +1,11 @@ from __future__ import annotations import asyncio +import hashlib import json +import random +import sqlite3 +import time import typing from importlib.metadata import entry_points from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Tuple @@ -13,8 +17,11 @@ from flyte._image import Architecture, Image from flyte._initialize import _get_init_config from flyte._logging import logger +from flyte._persistence._db import LocalDB from flyte._status import status +_IMAGE_CACHE_TTL_DAYS = 30 + if TYPE_CHECKING: from flyte._build import ImageBuild @@ -36,7 +43,15 @@ class ImageChecker(Protocol): @classmethod async def image_exists( cls, repository: str, tag: str, arch: Tuple[Architecture, ...] = ("linux/amd64",) - ) -> Optional[str]: ... + ) -> Optional[str]: + """ + Check whether an image exists in a registry or cache. + + Returns the image URI if found, or None if the image definitively does not exist. + Raise an exception if existence cannot be determined (e.g. cache miss, network failure) + so the next checker in the chain gets a chance. + """ + ... 
class DockerAPIImageChecker(ImageChecker): @@ -93,6 +108,65 @@ async def image_exists( return None +def _cache_key(repository: str, tag: str, arch: Tuple[str, ...]) -> str: + """Return a stable cache key for an image, scoped to the current endpoint/project/domain.""" + from flyte._persistence._db import _cache_scope + + raw = f"{_cache_scope()}:{repository}:{tag}:{','.join(sorted(arch))}" + return hashlib.sha256(raw.encode()).hexdigest() + + +def _read_image_cache(repository: str, tag: str, arch: Tuple[str, ...]) -> Optional[str]: + """Look up a previously verified image URI by repository, tag, and arch. Returns image_uri or None.""" + try: + conn = LocalDB.get_sync() + cutoff = time.time() - _IMAGE_CACHE_TTL_DAYS * 86400 + row = conn.execute( + "SELECT image_uri FROM image_cache WHERE key = ? AND created_at > ?", + (_cache_key(repository, tag, arch), cutoff), + ).fetchone() + # Prune expired entries ~5% of the time to avoid doing it on every read + if random.random() < 0.05: + with LocalDB._write_lock: + conn.execute("DELETE FROM image_cache WHERE created_at <= ?", (cutoff,)) + conn.commit() + if row: + return row[0] + except (OSError, sqlite3.Error) as e: + logger.debug(f"Failed to read image cache: {e}") + return None + + +def _write_image_cache(repository: str, tag: str, arch: Tuple[str, ...], image_uri: str) -> None: + """Persist a verified image URI to the SQLite cache.""" + try: + conn = LocalDB.get_sync() + with LocalDB._write_lock: + conn.execute( + "INSERT OR REPLACE INTO image_cache (key, image_uri, created_at) VALUES (?, ?, ?)", + (_cache_key(repository, tag, arch), image_uri, time.time()), + ) + conn.commit() + except (OSError, sqlite3.Error) as e: + logger.debug(f"Failed to write image cache: {e}") + + +class PersistentCacheImageChecker(ImageChecker): + """Return the URI of an image recorded in the SQLite cache; raise LookupError on a cache miss.""" + + @classmethod + async def image_exists( + cls, repository: str, tag: str, arch: Tuple[Architecture, ...] 
= ("linux/amd64",) + ) -> Optional[str]: + uri = _read_image_cache(repository, tag, arch) + if uri: + logger.debug(f"Image {uri} found in persistent cache") + return uri + # Cache miss — raise so the next checker in the chain gets a chance. + # Returning None would mean "image definitely doesn't exist". + raise LookupError(f"Image {repository}:{tag} not found in persistent cache") + + class LocalDockerCommandImageChecker(ImageChecker): command_name: ClassVar[str] = "docker" @@ -170,12 +244,17 @@ async def image_exists(image: Image) -> Optional[str]: image_uri = await checker.image_exists(repository, tag, tuple(image.platform)) if image_uri: logger.debug(f"Image {image_uri} in registry") - return image_uri + # Persist to disk so future process invocations skip network checks + if checker is not PersistentCacheImageChecker: + _write_image_cache(repository, tag, tuple(image.platform), image_uri) + return image_uri + # Checker ran successfully and returned None — image not found + return None except Exception as e: logger.debug(f"Error checking image existence with {checker.__name__}: {e}") continue - # If all checkers fail, then assume the image exists. This is current flytekit behavior + # All checkers raised exceptions (e.g. network failures) — assume image exists status.info(f"All checkers failed to check existence of {image.uri}, assuming it exists") return image.uri diff --git a/src/flyte/_persistence/_db.py b/src/flyte/_persistence/_db.py index 699aa19e3..ebf60a444 100644 --- a/src/flyte/_persistence/_db.py +++ b/src/flyte/_persistence/_db.py @@ -15,6 +15,20 @@ DEFAULT_CACHE_DIR = "~/.flyte" CACHE_LOCATION = "local-cache/cache.db" + +def _cache_scope() -> str: + """Return a stable string identifying the current endpoint+project+domain. + + Used to scope image/bundle cache entries so that different environments + don't collide. 
+ """ + config = auto() + endpoint = config.platform.endpoint or "" + project = config.task.project or "" + domain = config.task.domain or "" + return f"{endpoint}:{project}:{domain}" + + _TASK_CACHE_DDL = """ CREATE TABLE IF NOT EXISTS task_cache ( key TEXT PRIMARY KEY, @@ -22,6 +36,23 @@ ) """ +_IMAGE_CACHE_DDL = """ +CREATE TABLE IF NOT EXISTS image_cache ( + key TEXT PRIMARY KEY, + image_uri TEXT NOT NULL, + created_at REAL NOT NULL +) +""" + +_BUNDLE_CACHE_DDL = """ +CREATE TABLE IF NOT EXISTS bundle_cache ( + digest TEXT PRIMARY KEY, + hash_digest TEXT NOT NULL, + remote_path TEXT NOT NULL, + created_at REAL NOT NULL +) +""" + _RUNS_DDL = """ CREATE TABLE IF NOT EXISTS runs ( run_name TEXT NOT NULL, @@ -49,6 +80,8 @@ """ +_ALL_TABLE_DDLS = [_TASK_CACHE_DDL, _RUNS_DDL, _IMAGE_CACHE_DDL, _BUNDLE_CACHE_DDL] + _RUNS_INDEXES = [ "CREATE INDEX IF NOT EXISTS idx_runs_action_start ON runs (action_name, start_time)", "CREATE INDEX IF NOT EXISTS idx_runs_status_start ON runs (status, start_time)", @@ -114,16 +147,16 @@ async def initialize(): async def _initialize_async(): db_path = LocalDB._get_db_path() conn = await aiosqlite.connect(db_path) - await conn.execute(_TASK_CACHE_DDL) - await conn.execute(_RUNS_DDL) + for ddl in _ALL_TABLE_DDLS: + await conn.execute(ddl) for idx_stmt in _RUNS_INDEXES: await conn.execute(idx_stmt) await conn.commit() LocalDB._conn = conn # Also open a sync connection for sync callers sync_conn = sqlite3.connect(db_path, check_same_thread=False) - sync_conn.execute(_TASK_CACHE_DDL) - sync_conn.execute(_RUNS_DDL) + for ddl in _ALL_TABLE_DDLS: + sync_conn.execute(ddl) _migrate_sync(sync_conn) LocalDB._conn_sync = sync_conn LocalDB._initialized = True @@ -140,8 +173,8 @@ def initialize_sync(): def _initialize_sync_inner(): db_path = LocalDB._get_db_path() conn = sqlite3.connect(db_path, check_same_thread=False) - conn.execute(_TASK_CACHE_DDL) - conn.execute(_RUNS_DDL) + for ddl in _ALL_TABLE_DDLS: + conn.execute(ddl) _migrate_sync(conn) 
LocalDB._conn_sync = conn LocalDB._initialized = True diff --git a/src/flyte/_persistence/_task_cache.py b/src/flyte/_persistence/_task_cache.py index 8a472c983..1fb13a979 100644 --- a/src/flyte/_persistence/_task_cache.py +++ b/src/flyte/_persistence/_task_cache.py @@ -24,8 +24,9 @@ async def clear(): await conn.commit() else: conn = LocalDB.get_sync() - conn.execute("DELETE FROM task_cache") - conn.commit() + with LocalDB._write_lock: + conn.execute("DELETE FROM task_cache") + conn.commit() @staticmethod async def get(cache_key: str) -> convert.Outputs | None: @@ -76,8 +77,9 @@ async def _set_async(cache_key: str, value: convert.Outputs) -> None: def _set_sync(cache_key: str, value: convert.Outputs) -> None: conn = LocalDB.get_sync() output_bytes = value.proto_outputs.SerializeToString() - conn.execute("INSERT OR REPLACE INTO task_cache (key, value) VALUES (?, ?)", (cache_key, output_bytes)) - conn.commit() + with LocalDB._write_lock: + conn.execute("INSERT OR REPLACE INTO task_cache (key, value) VALUES (?, ?)", (cache_key, output_bytes)) + conn.commit() @staticmethod async def close(): diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 841d80eca..4a7ec2e43 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -6,7 +6,6 @@ import pathlib import sys import uuid -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast from flyte._context import Context, contextual_run, internal_ctx @@ -40,26 +39,11 @@ from flyte.remote._task import LazyEntity from ._code_bundle import CopyFiles - from ._internal.imagebuild.image_builder import ImageCache Mode = Literal["local", "remote", "hybrid"] CacheLookupScope = Literal["global", "project-domain"] -@dataclass(frozen=True) -class _CacheKey: - obj_id: int - dry_run: bool - - -@dataclass(frozen=True) -class _CacheValue: - code_bundle: CodeBundle | None - image_cache: Optional[ImageCache] - - -_RUN_CACHE: Dict[_CacheKey, _CacheValue] = {} - # ContextVar 
for run mode - thread-safe and coroutine-safe alternative to a global variable. # This allows offloaded types (files, directories, dataframes) to be aware of the run mode # for controlling auto-uploading behavior (only enabled in remote mode). @@ -203,56 +187,47 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar if obj.parent_env is None: raise ValueError("Task is not attached to an environment. Please attach the task to an environment") - if ( - not self._disable_run_cache - and _RUN_CACHE.get(_CacheKey(obj_id=id(obj), dry_run=self._dry_run)) is not None - ): - cached_value = _RUN_CACHE[_CacheKey(obj_id=id(obj), dry_run=self._dry_run)] - code_bundle = cached_value.code_bundle - image_cache = cached_value.image_cache - else: - # Resolve any CodeBundleLayer layers before building images - parent_env = cast(Environment, obj.parent_env()) - from flyte._image import Image, resolve_code_bundle_layer + # Resolve any CodeBundleLayer layers before building images + parent_env = cast(Environment, obj.parent_env()) + from flyte._image import Image, resolve_code_bundle_layer - if isinstance(parent_env.image, Image): - parent_env.image = resolve_code_bundle_layer( - parent_env.image, self._copy_files, pathlib.Path(cfg.root_dir) - ) + if isinstance(parent_env.image, Image): + parent_env.image = resolve_code_bundle_layer( + parent_env.image, self._copy_files, pathlib.Path(cfg.root_dir) + ) - if not self._dry_run: - image_cache = await build_images.aio(parent_env) - else: - image_cache = None + if not self._dry_run: + image_cache = await build_images.aio(parent_env) + else: + image_cache = None - if self._interactive_mode: - code_bundle = await build_pkl_bundle( - obj, - upload_to_controlplane=not self._dry_run, - copy_bundle_to=self._copy_bundle_to, - ) - elif self._copy_files == "custom": - if not self._bundle_relative_paths or not self._bundle_from_dir: - raise ValueError("copy_style='custom' requires _bundle_relative_paths and 
_bundle_from_dir") - code_bundle = await build_code_bundle_from_relative_paths( - self._bundle_relative_paths, - from_dir=self._bundle_from_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - ) - elif self._copy_files != "none": - code_bundle = await build_code_bundle( - from_dir=cfg.root_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - copy_style=self._copy_files, - ) - else: - code_bundle = None - if not self._disable_run_cache: - _RUN_CACHE[_CacheKey(obj_id=id(obj), dry_run=self._dry_run)] = _CacheValue( - code_bundle=code_bundle, image_cache=image_cache + skip_cache = self._disable_run_cache + if self._interactive_mode: + code_bundle = await build_pkl_bundle( + obj, + upload_to_controlplane=not self._dry_run, + copy_bundle_to=self._copy_bundle_to, + ) + elif self._copy_files == "custom": + if not self._bundle_relative_paths or not self._bundle_from_dir: + raise ValueError("copy_style='custom' requires _bundle_relative_paths and _bundle_from_dir") + code_bundle = await build_code_bundle_from_relative_paths( + self._bundle_relative_paths, + from_dir=self._bundle_from_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + skip_cache=skip_cache, ) + elif self._copy_files != "none": + code_bundle = await build_code_bundle( + from_dir=cfg.root_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + copy_style=self._copy_files, + skip_cache=skip_cache, + ) + else: + code_bundle = None version = self._version or ( code_bundle.computed_version if code_bundle and code_bundle.computed_version else None diff --git a/tests/flyte/imagebuild/test_image_build_engine.py b/tests/flyte/imagebuild/test_image_build_engine.py index 6c5fc76f3..b1c5490d7 100644 --- a/tests/flyte/imagebuild/test_image_build_engine.py +++ b/tests/flyte/imagebuild/test_image_build_engine.py @@ -8,26 +8,67 @@ DockerAPIImageChecker, ImageBuildEngine, LocalDockerCommandImageChecker, + PersistentCacheImageChecker, ) 
@mock.patch("flyte._internal.imagebuild.image_builder.DockerAPIImageChecker.image_exists") @mock.patch("flyte._internal.imagebuild.image_builder.LocalDockerCommandImageChecker.image_exists") +@mock.patch("flyte._internal.imagebuild.image_builder.PersistentCacheImageChecker.image_exists") @pytest.mark.asyncio -async def test_cached(mock_checker_cli, mock_checker_api): - # Simulate that the image exists locally - mock_checker_cli.return_value = True +async def test_cached(mock_checker_cache, mock_checker_cli, mock_checker_api): + # Simulate that the image exists via persistent cache + mock_checker_cache.return_value = True img = Image.from_debian_base() await ImageBuildEngine.image_exists(img) await ImageBuildEngine.image_exists(img) - # The local checker should be called once, and its result cached - mock_checker_cli.assert_called_once() - # The API checker should not be called at all + # The persistent cache checker should be called once, and its result cached by alru_cache + mock_checker_cache.assert_called_once() + # All other checkers should not be called + mock_checker_cli.assert_not_called() mock_checker_api.assert_not_called() +def test_persistent_cache_write_and_read(tmp_path, monkeypatch): + """PersistentCacheImageChecker reads back what _write_image_cache wrote.""" + import flyte._internal.imagebuild.image_builder as ib + from flyte._persistence._db import LocalDB + + monkeypatch.setattr(LocalDB, "_get_db_path", staticmethod(lambda: str(tmp_path / "cache.db"))) + monkeypatch.setattr(LocalDB, "_initialized", False) + monkeypatch.setattr(LocalDB, "_conn_sync", None) + monkeypatch.setattr(LocalDB, "_conn", None) + LocalDB.initialize_sync() + + try: + # Initially nothing cached — PersistentCacheImageChecker raises LookupError on miss + import asyncio + + with pytest.raises(LookupError): + asyncio.run( + PersistentCacheImageChecker.image_exists("myrepo", "v1.0", ("linux/amd64",)) + ) + + # Write to cache + 
ib._write_image_cache("myrepo", "v1.0", ("linux/amd64",), "myrepo:v1.0") + + # Now it should be found + result = asyncio.run( + PersistentCacheImageChecker.image_exists("myrepo", "v1.0", ("linux/amd64",)) + ) + assert result == "myrepo:v1.0" + + # Different arch should NOT be found + with pytest.raises(LookupError): + asyncio.run( + PersistentCacheImageChecker.image_exists("myrepo", "v1.0", ("linux/arm64",)) + ) + finally: + LocalDB.close_sync() + + @mock.patch("flyte._internal.imagebuild.image_builder.ImageBuildEngine._get_builder") @mock.patch("flyte._internal.imagebuild.image_builder.ImageBuildEngine.image_exists", new_callable=mock.AsyncMock) @pytest.mark.asyncio