diff --git a/exercise_utils/test.py b/exercise_utils/test.py index bed433c9..04df315f 100644 --- a/exercise_utils/test.py +++ b/exercise_utils/test.py @@ -3,7 +3,19 @@ from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Callable, ContextManager, Dict, Iterator, List, Optional, Self, Tuple +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Iterator, + List, + Literal, + Optional, + Self, + Tuple, + overload, +) from unittest import mock import pytz @@ -45,14 +57,19 @@ def __init__( grade_func: Callable[[GitAutograderExercise], GitAutograderOutput], clone_from: Optional[str] = None, mock_answers: Optional[Dict[str, str]] = None, + include_remote_repo: bool = False, ) -> None: self.exercise_name = exercise_name self.grade_func = grade_func self.clone_from = clone_from self.mock_answers = mock_answers + self.include_remote_repo = include_remote_repo self.__rs: Optional[RepoSmith] = None + self.__rs_remote: Optional[RepoSmith] = None self.__rs_context: Optional[ContextManager[RepoSmith]] = None + self.__rs_remote_context: Optional[ContextManager[RepoSmith]] = None self.__temp_dir: Optional[tempfile.TemporaryDirectory] = None + self.__remote_temp_dir: Optional[tempfile.TemporaryDirectory] = None self.__patches: List[mock._patch] = [] @property @@ -60,6 +77,10 @@ def rs(self) -> RepoSmith: assert self.__rs is not None return self.__rs + @property + def rs_remote(self) -> Optional[RepoSmith]: + return self.__rs_remote + def run(self) -> GitAutograderOutput: output: Optional[GitAutograderOutput] = None started_at = datetime.now(tz=pytz.UTC) @@ -95,7 +116,7 @@ def run(self) -> GitAutograderOutput: assert output is not None return output - def __enter__(self) -> Tuple[Self, RepoSmith]: + def __enter__(self) -> Tuple[Self, RepoSmith, RepoSmith | None]: # We will mock all accesses to the config to avoid reading the file itself # Only the exercise name and repo_name matters, everything else isn't used repo_name = "repo" @@ -165,7 +186,18 @@ def __enter__(self) -> Tuple[Self, RepoSmith]: self.__rs = self.__rs_context.__enter__() self.__rs.add_helper(GitMasteryHelper) - return self, self.rs + if self.include_remote_repo: + self.__remote_temp_dir = tempfile.TemporaryDirectory() + remote_temp_path = Path(self.__remote_temp_dir.name) + remote_repo_path = remote_temp_path / repo_name + os.makedirs(remote_repo_path, exist_ok=True) + self.__rs_remote_context = create_repo_smith( + False, existing_path=remote_repo_path.absolute().as_posix() + ) + self.__rs_remote = self.__rs_remote_context.__enter__() + self.__rs_remote.add_helper(GitMasteryHelper) + + return self, self.rs, self.rs_remote def __exit__( self, @@ -182,6 +214,12 @@ def __exit__( if self.__rs_context is not None: self.__rs_context.__exit__(exc_type, exc_val, None) + if self.__rs_remote_context is not None: + self.__rs_remote_context.__exit__(exc_type, exc_val, None) + + if self.__remote_temp_dir is not None: + self.__remote_temp_dir.cleanup() + class GitAutograderTestLoader: def __init__( @@ -192,20 +230,44 @@ def __init__( self.exercise_name = exercise_name self.grade_func = grade_func + @overload + def start( + self, + clone_from: Optional[str] = None, + mock_answers: Optional[Dict[str, str]] = None, + include_remote_repo: Literal[False] = False, + ) -> ContextManager[Tuple[GitAutograderTest, RepoSmith]]: ... + + @overload + def start( + self, + clone_from: Optional[str] = None, + mock_answers: Optional[Dict[str, str]] = None, + *, + include_remote_repo: Literal[True], + ) -> ContextManager[Tuple[GitAutograderTest, RepoSmith, RepoSmith]]: ... + @contextmanager def start( self, clone_from: Optional[str] = None, mock_answers: Optional[Dict[str, str]] = None, - ) -> Iterator[Tuple[GitAutograderTest, RepoSmith]]: + include_remote_repo: bool = False, + ) -> Iterator[Any]: test = GitAutograderTest( self.exercise_name, self.grade_func, clone_from, mock_answers, + include_remote_repo, ) - with test as (ctx, rs): - yield (ctx, rs) + if include_remote_repo: + with test as (ctx, rs, rs_remote): + yield ctx, rs, rs_remote + else: + # extract only rs if include_remote_repo is False + with test as (ctx, rs, rs_remote): + yield ctx, rs def assert_output(