Skip to content
74 changes: 68 additions & 6 deletions exercise_utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Callable, ContextManager, Dict, Iterator, List, Optional, Self, Tuple
from typing import (
Any,
Callable,
ContextManager,
Dict,
Iterator,
List,
Literal,
Optional,
Self,
Tuple,
overload,
)
from unittest import mock

import pytz
Expand Down Expand Up @@ -45,21 +57,30 @@ def __init__(
grade_func: Callable[[GitAutograderExercise], GitAutograderOutput],
clone_from: Optional[str] = None,
mock_answers: Optional[Dict[str, str]] = None,
include_remote_repo: bool = False,
) -> None:
self.exercise_name = exercise_name
self.grade_func = grade_func
self.clone_from = clone_from
self.mock_answers = mock_answers
self.include_remote_repo = include_remote_repo
self.__rs: Optional[RepoSmith] = None
self.__rs_remote: Optional[RepoSmith] = None
self.__rs_context: Optional[ContextManager[RepoSmith]] = None
self.__rs_remote_context: Optional[ContextManager[RepoSmith]] = None
self.__temp_dir: Optional[tempfile.TemporaryDirectory] = None
self.__remote_temp_dir: Optional[tempfile.TemporaryDirectory] = None
self.__patches: List[mock._patch] = []

@property
def rs(self) -> RepoSmith:
assert self.__rs is not None
return self.__rs

@property
def rs_remote(self) -> Optional[RepoSmith]:
return self.__rs_remote

def run(self) -> GitAutograderOutput:
output: Optional[GitAutograderOutput] = None
started_at = datetime.now(tz=pytz.UTC)
Expand Down Expand Up @@ -95,7 +116,7 @@ def run(self) -> GitAutograderOutput:
assert output is not None
return output

def __enter__(self) -> Tuple[Self, RepoSmith]:
def __enter__(self) -> Tuple[Self, RepoSmith, RepoSmith | None]:
# We will mock all accesses to the config to avoid reading the file itself
# Only the exercise name and repo_name matters, everything else isn't used
repo_name = "repo"
Expand Down Expand Up @@ -165,7 +186,18 @@ def __enter__(self) -> Tuple[Self, RepoSmith]:
self.__rs = self.__rs_context.__enter__()
self.__rs.add_helper(GitMasteryHelper)

return self, self.rs
if self.include_remote_repo:
self.__remote_temp_dir = tempfile.TemporaryDirectory()
remote_temp_path = Path(self.__remote_temp_dir.name)
remote_repo_path = remote_temp_path / repo_name
os.makedirs(remote_repo_path, exist_ok=True)
self.__rs_remote_context = create_repo_smith(
False, existing_path=remote_repo_path.absolute().as_posix()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a small remark, for consistency, existing_path... should be on a new line.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code be auto-formatted by ruff

)
self.__rs_remote = self.__rs_remote_context.__enter__()
self.__rs_remote.add_helper(GitMasteryHelper)

return self, self.rs, self.rs_remote
Copy link

Copilot AI Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __enter__ method now always returns three values (test, rs, rs_remote) regardless of whether include_remote_repo is True or False. When include_remote_repo is False, rs_remote will be None. This design decision means all callers must handle three return values, even when they don't need the remote repo. While this simplifies the implementation, it breaks backward compatibility with existing code that unpacks only two values (e.g., with test as (ctx, rs):). Consider whether this breaking change is acceptable, or if a different design would be better.

Suggested change
return self, self.rs, self.rs_remote
if self.include_remote_repo:
return self, self.rs, self.rs_remote
return self, self.rs

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes type arguments complicated and results in failure in mypy test cases, so we only extract out the relevant variables at start() based on include_remote_repo


def __exit__(
self,
Expand All @@ -182,6 +214,12 @@ def __exit__(
if self.__rs_context is not None:
self.__rs_context.__exit__(exc_type, exc_val, None)

if self.__rs_remote_context is not None:
self.__rs_remote_context.__exit__(exc_type, exc_val, None)

if self.__remote_temp_dir is not None:
self.__remote_temp_dir.cleanup()


class GitAutograderTestLoader:
def __init__(
Expand All @@ -192,20 +230,44 @@ def __init__(
self.exercise_name = exercise_name
self.grade_func = grade_func

@overload
def start(
self,
clone_from: Optional[str] = None,
mock_answers: Optional[Dict[str, str]] = None,
include_remote_repo: Literal[False] = False,
) -> ContextManager[Tuple[GitAutograderTest, RepoSmith]]: ...

@overload
def start(
self,
clone_from: Optional[str] = None,
mock_answers: Optional[Dict[str, str]] = None,
*,
include_remote_repo: Literal[True],
) -> ContextManager[Tuple[GitAutograderTest, RepoSmith, RepoSmith]]: ...

@contextmanager
def start(
self,
clone_from: Optional[str] = None,
mock_answers: Optional[Dict[str, str]] = None,
) -> Iterator[Tuple[GitAutograderTest, RepoSmith]]:
include_remote_repo: bool = False,
) -> Iterator[Any]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, I would avoid using Any as type, but it seems to make the most sense here.

I tried returning Tuple[GitAutograderTest, RepoSmith, RepoSmith] | Tuple[GitAutograderTest, RepoSmith] but it seems that Python's type system can't narrow a union return type based on instance attributes at call time.

Having the @overload here seems to be the best solution for type safety.

Copy link
Copy Markdown
Contributor

@SAN-MUYUN SAN-MUYUN Jan 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, without @overload, users lose the precision of return type when include_remote_repo is set to True/False.

test = GitAutograderTest(
self.exercise_name,
self.grade_func,
clone_from,
mock_answers,
include_remote_repo,
)
with test as (ctx, rs):
yield (ctx, rs)
if include_remote_repo:
with test as (ctx, rs, rs_remote):
yield ctx, rs, rs_remote
else:
# extract only rs if include_remote_repo is False
with test as (ctx, rs, rs_remote):
yield ctx, rs


def assert_output(
Expand Down