diff --git a/problemtools/run/__init__.py b/problemtools/run/__init__.py
index 9915e13d..6f5791dc 100644
--- a/problemtools/run/__init__.py
+++ b/problemtools/run/__init__.py
@@ -11,7 +11,7 @@ from .program import Program
 from .source import SourceCode
 from .viva import Viva
 
-from .tools import get_tool as get_tool
+from .tools import get_tool as get_tool, get_tool_path as get_tool_path
 
 from . import rutil
diff --git a/problemtools/verifyproblem.py b/problemtools/verifyproblem.py
index 524ad28b..6a6bf1d1 100644
--- a/problemtools/verifyproblem.py
+++ b/problemtools/verifyproblem.py
@@ -37,7 +37,7 @@ from .formatversion import FormatVersion, get_format_version
 
 from abc import ABC
-from typing import Any, Callable, ClassVar, Literal, Pattern, Match, ParamSpec, Type, TypeVar
+from typing import Any, Callable, ClassVar, Literal, Pattern, Match, ParamSpec, TypeVar
 
 from pydantic import ValidationError
 
 log = logging.getLogger(__name__)
@@ -196,31 +196,14 @@ class ProblemPart(ProblemAspect):
     """
 
     PART_NAME: ClassVar[str]
 
-    """Should return all classes that need to be initialized before this one. It is sufficient to be
-    a subclass of the classes listed. There should be exactly one subclass of each dependency in the
-    format that the problem-part is included in.
-
-    Note that this will only ensure that the specified classes are initialized before this one, but
-    they might be checked in a different order.
-    """
-
-    @staticmethod
-    def setup_dependencies() -> set[type]:
-        return set()
-
     def __init__(self, problem: Problem) -> None:
         if self.PART_NAME is None:
             raise NotImplementedError('Every problem-part must override PART_NAME')
         super().__init__(f'{problem.shortname}.{self.PART_NAME}', problem)
+        self.setup()
 
-    """Override to setup data about this problem-part. The order in which problem-parts are setup
-    will be decided based on the dependencies that exist.
-
-    Return value is the data made available by initializing this part.
- """ - - def setup(self) -> dict: - return {} + def setup(self) -> None: + pass def start_background_work(self, context: Context) -> None: pass @@ -232,16 +215,16 @@ def check(self, context: Context) -> bool: class TestCase(ProblemAspect): Result = tuple[SubmissionResult, SubmissionResult, SubmissionResult] - def __init__(self, problem: Problem, aspect_name: str, base: str, testcasegroup: TestCaseGroup) -> None: - super().__init__(f'{problem.shortname}.{aspect_name}.{testcasegroup.name}.{os.path.basename(base)}', problem) + def __init__(self, problem: Problem, base: str, testcasegroup: TestCaseGroup) -> None: + super().__init__(f'{problem.shortname}.test.{testcasegroup.name}.{os.path.basename(base)}', problem) self._base = base self.infile = f'{base}.in' self.ansfile = f'{base}.ans' self._problem = problem self.testcasegroup = testcasegroup self.reuse_result_from: TestCase | None = None - self.counter = len(problem.getProblemPart(ProblemTestCases).testcase_by_infile) - problem.getProblemPart(ProblemTestCases).testcase_by_infile[self.infile] = self + self.counter = len(problem.testcase_by_infile) + problem.testcase_by_infile[self.infile] = self def check_newlines(self, filename: str) -> None: with open(filename, 'rb') as f: @@ -279,9 +262,9 @@ def check(self, context: Context) -> bool: self.check_newlines(self.ansfile) self.check_size_limits(self.infile) self.check_size_limits(self.ansfile) - self._problem.getProblemPart(InputValidators).validate(self) + self._problem.input_validators.validate(self) anssize = os.path.getsize(self.ansfile) / 1024.0 / 1024.0 - outputlim = self._problem.getMetadata().limits.output + outputlim = self._problem.metadata.limits.output if anssize > outputlim: self.error( f'Answer file ({anssize:.1f} Mb) is larger than output limit ({outputlim} Mb), you need to increase output limit' @@ -290,8 +273,8 @@ def check(self, context: Context) -> bool: self.warning( f'Answer file ({anssize:.1f} Mb) is within 50% of output limit ({outputlim} Mb), you might want to increase output limit' ) - if not self._problem.getMetadata().is_interactive(): - val_res = self._problem.getProblemPart(OutputValidators).validate(self, self.ansfile) + if not self._problem.metadata.is_interactive(): + val_res = self._problem.output_validators.validate(self, self.ansfile) if val_res.verdict != 'AC': if self.is_in_sample_group(): self.error(f'judge answer file got {val_res}') @@ -310,8 +293,8 @@ def set_symlinks(self) -> None: if not os.path.islink(self.infile): return target = os.path.realpath(self.infile) - if target in self._problem.getProblemPart(ProblemTestCases).testcase_by_infile: - self.reuse_result_from = self._problem.getProblemPart(ProblemTestCases).testcase_by_infile[target] + if target in self._problem.testcase_by_infile: + self.reuse_result_from = self._problem.testcase_by_infile[target] def _check_symlinks(self) -> bool: if not os.path.islink(self.infile): @@ -350,10 +333,8 @@ def run_submission(self, sub, runner: Runner, context: Context) -> Result: def run_submission_real(self, sub, context: Context, timelim: int, timelim_low: int, timelim_high: int) -> Result: # This may be called off-main thread. 
-        if self._problem.getMetadata().is_interactive():
-            res_high = self._problem.getProblemPart(OutputValidators).validate_interactive(
-                self, sub, timelim_high, self._problem.getProblemPart(Submissions)
-            )
+        if self._problem.metadata.is_interactive():
+            res_high = self._problem.output_validators.validate_interactive(self, sub, timelim_high, self._problem.submissions)
         else:
             outfile = os.path.join(self._problem.tmpdir, f'output-{self.counter}')
             errfile = os.path.join(self._problem.tmpdir, f'error-{self.counter}')
@@ -362,7 +343,7 @@ def run_submission_real(self, sub, context: Context, timelim: int, timelim_low:
                 outfile=outfile,
                 errfile=errfile,
                 timelim=timelim_high + 1,
-                memlim=self._problem.getMetadata().limits.memory,
+                memlim=self._problem.metadata.limits.memory,
                 work_dir=sub.path,
             )
             if is_TLE(status) or runtime > timelim_high:
@@ -376,7 +357,7 @@ def run_submission_real(self, sub, context: Context, timelim: int, timelim_low:
                     info = None
                 res_high = SubmissionResult('RTE', additional_info=info)
             else:
-                res_high = self._problem.getProblemPart(OutputValidators).validate(self, outfile)
+                res_high = self._problem.output_validators.validate(self, outfile)
             res_high.runtime = runtime
 
         if res_high.runtime <= timelim_low:
@@ -425,14 +406,14 @@ class TestCaseGroup(ProblemAspect):
     _DEFAULT_CONFIG = config.load_config('testdata.yaml')
     _SCORING_ONLY_KEYS = ['accept_score', 'reject_score', 'range']
 
-    def __init__(self, problem: Problem, aspect_name: str, datadir: str | None = None, parent: TestCaseGroup | None = None):
+    def __init__(self, problem: Problem, datadir: str | None = None, parent: TestCaseGroup | None = None):
         self._parent = parent
         self._problem = problem
         datadir = datadir or os.path.join(problem.probdir, 'data')
         self._datadir = datadir
         self.name = os.path.relpath(os.path.abspath(self._datadir), os.path.abspath(self._problem.probdir)).replace('/', '.')
-        super().__init__(f'{problem.shortname}.{aspect_name}.{self.name}', problem)
+        super().__init__(f'{problem.shortname}.test.{self.name}', problem)
         self._seen_oob_scores = False
         self.debug('Loading test data group %s', datadir)
@@ -455,7 +436,7 @@ def __init__(self, problem: Problem, aspect_name: str, datadir: str | None = Non
 
         # TODO: Decide if these should stay
         # Some deprecated properties are inherited from problem config during a transition period
-        legacy_grading = problem.getMetadata().legacy_grading
+        legacy_grading = problem.metadata.legacy_grading
         for key in ['accept_score', 'reject_score', 'range']:
             if getattr(legacy_grading, key) is not None:
                 self.config[key] = getattr(legacy_grading, key)
@@ -466,7 +447,7 @@ def __init__(self, problem: Problem, aspect_name: str, datadir: str | None = Non
         if problem_on_reject == 'grade':
             self.config['on_reject'] = 'continue'
 
-        if self._problem.getMetadata().is_pass_fail():
+        if self._problem.metadata.is_pass_fail():
             for key in TestCaseGroup._SCORING_ONLY_KEYS:
                 if key not in self.config:
                     self.config[key] = None
@@ -480,11 +461,11 @@ def __init__(self, problem: Problem, aspect_name: str, datadir: str | None = Non
         for filename in sorted(os.listdir(datadir)):
             filename = os.path.join(datadir, filename)
             if os.path.isdir(filename):
-                self._items.append(TestCaseGroup(problem, aspect_name, filename, self))
+                self._items.append(TestCaseGroup(problem, filename, self))
             else:
                 base, ext = os.path.splitext(filename)
                 if ext == '.ans' and os.path.isfile(f'{base}.in'):
-                    self._items.append(TestCase(problem, aspect_name, base, self))
+                    self._items.append(TestCase(problem, base, self))
 
         if not parent:
             self.set_symlinks()
@@ -536,10 +517,10 @@ def check(self, context: Context) -> bool:
         if self.config['grading'] not in ['default', 'custom']:
             self.error('Invalid grading policy in testdata.yaml')
 
-        if self.config['grading'] == 'custom' and len(self._problem.getProblemPart(Graders)._graders) == 0:
-            self._problem.getProblemPart(Graders).fatal(f'{self} has custom grading but no custom graders provided')
+        if self.config['grading'] == 'custom' and len(self._problem.graders._graders) == 0:
+            self._problem.graders.fatal(f'{self} has custom grading but no custom graders provided')
         if self.config['grading'] == 'default' and Graders._default_grader is None:
-            self._problem.getProblemPart(Graders).fatal(f'{self} has default grading but I could not find default grader')
+            self._problem.graders.fatal(f'{self} has default grading but I could not find default grader')
 
         if self.config['grading'] == 'default' and 'ignore_sample' in self.config['grader_flags'].split():
             if self._parent is not None:
@@ -553,7 +534,7 @@ def check(self, context: Context) -> bool:
             if field not in TestCaseGroup._DEFAULT_CONFIG.keys():
                 self.warning(f"Unknown key '{field}' in '{os.path.join(self._datadir, 'testdata.yaml')}'")
 
-        if not self._problem.getMetadata().is_scoring():
+        if not self._problem.metadata.is_scoring():
             for key in TestCaseGroup._SCORING_ONLY_KEYS:
                 if self.config.get(key) is not None:
                     self.error(f"Key '{key}' is only applicable for scoring problems, this is a pass-fail problem")
@@ -561,7 +542,7 @@ def check(self, context: Context) -> bool:
         if self.config['on_reject'] not in ['break', 'continue']:
             self.error(f"Invalid value '{self.config['on_reject']}' for on_reject policy")
 
-        if self._problem.getMetadata().is_scoring():
+        if self._problem.metadata.is_scoring():
             # Check grading
             try:
                 score_range = self.config['range']
@@ -720,11 +701,11 @@ def aggregate_results(self, sub, sub_results: list[SubmissionResult], shadow_res
                 res.additional_info = judge_error.additional_info
                 res.testcase = judge_error.testcase
         else:
-            res.verdict, score = self._problem.getProblemPart(Graders).grade(sub_results, self, shadow_result)
+            res.verdict, score = self._problem.graders.grade(sub_results, self, shadow_result)
             if sub_results:
                 res.testcase = sub_results[-1].testcase
                 res.additional_info = sub_results[-1].additional_info
-        if self._problem.getMetadata().is_scoring():
+        if self._problem.metadata.is_scoring():
             res.score = score
             min_score, max_score = self.get_score_range()
             if score is not None and not (min_score <= score <= max_score) and not self._seen_oob_scores:
@@ -745,15 +726,12 @@ def all_datasets(self) -> list:
 
 
 class ProblemStatement(ProblemPart):
+    statements: dict[str, list[Path]]  # Maps language code -> statement(s)
     PART_NAME = 'statement'
 
     def setup(self):
        self.debug('  Loading problem statement')
-        try:
-            self.statements = statement_util.find_statements(Path(self.problem.probdir), self.problem.format)
-        except OSError as e:
-            self.error(f'Failed locating problem statements: {e}')
-            self.statements = {}
+        self.statements = statement_util.find_statements(Path(self.problem.probdir), self.problem.format)
 
     def check(self, context: Context) -> bool:
         if self._check_res is not None:
             return self._check_res
@@ -776,11 +754,11 @@ def check(self, context: Context) -> bool:
             if len(files) > 1:
                 self.error(f'Found multiple statements in the same language {lang}: {", ".join((file.name for file in files))}')
 
-            if lang not in self.problem.getMetadata().name:
+            if lang not in self.problem.metadata.name:
                 self.error(f'No problem name given in language {lang}')
-            elif not self.problem.getMetadata().name[lang]:
+            elif not self.problem.metadata.name[lang]:
                 self.error(f'Problem name in language {lang} is empty')
-            elif not self.problem.getMetadata().name[lang].strip():
+            elif not self.problem.metadata.name[lang].strip():
                 self.error(f'Problem name in language {lang} contains only whitespace')
 
             for file in files:
@@ -822,18 +800,14 @@ class ProblemConfig(ProblemPart):
 
     def setup(self):
         self.debug('  Loading problem config')
-
         try:
             self._metadata, self._origdata = metadata.load_metadata(Path(self.problem.probdir))
-            self.problem.setMetadata(self._metadata)
+            self.problem._set_metadata(self._metadata)
         except ValidationError as e:
-            # This should likely be a fatal error, but I'm not sure there's a clean way to fail from setup
            error_str = '\n'.join([f' {"->".join((str(loc) for loc in err["loc"]))}: {err["msg"]}' for err in e.errors()])
             self.fatal(f'Failed parsing problem.yaml. Found {len(e.errors())} errors:\n{error_str}')
         except Exception as e:
-            # This should likely be a fatal error, but I'm not sure there's a clean way to fail from setup
             self.fatal(f'Failed loading problem configuration: {e}')
-        return {}
 
     def __str__(self) -> str:
         return 'problem configuration'
@@ -862,7 +836,7 @@ def check(self, context: Context) -> bool:
             self.error('Showing test data groups is only supported for scoring problems, this is a pass-fail problem')
         if (
             not self._metadata.is_pass_fail()
-            and self.problem.get(ProblemTestCases)['root_group'].has_custom_groups()
+            and self.problem.testdata.has_custom_groups()
             and 'show_test_data_groups' not in self._origdata.get('grading', {})
             and self.problem.format is FormatVersion.LEGACY
         ):
@@ -903,45 +877,22 @@ def check(self, context: Context) -> bool:
         return self._check_res
 
 
-class ProblemTestCases(ProblemPart):
-    PART_NAME = 'testdata'
-
-    @staticmethod
-    def setup_dependencies():
-        return {ProblemConfig}  # We need this as the TestCaseGroup constructor reads config
-
-    def setup(self):
-        self.testcase_by_infile = {}
-        return {
-            'root_group': TestCaseGroup(self.problem, self.PART_NAME),
-        }
-
-    def check(self, context: Context) -> bool:
-        return self.problem.get(ProblemTestCases)['root_group'].check(context)
-
-
 class Attachments(ProblemPart):
     """Represents the attachments of a problem.
 
     Attributes:
         attachments: The absolute paths to the attachment files for this problem.
- """ + attachments: list[Path] + PART_NAME = 'attachments' def setup(self): - attachments_path = os.path.join(self.problem.probdir, 'attachments') - self.attachments: list[str] = [] - if os.path.isdir(attachments_path): - self.attachments = [ - os.path.join(attachments_path, attachment_name) for attachment_name in os.listdir(attachments_path) - ] - + attachments_dir = Path(self.problem.probdir) / 'attachments' + self.attachments = [p for p in attachments_dir.iterdir()] if attachments_dir.is_dir() else [] self.debug(f'Adding attachments {str(self.attachments)}') - return {} - def check(self, context: Context) -> bool: if self._check_res is not None: return self._check_res @@ -1048,7 +999,7 @@ def collect_flags(group: TestCaseGroup, flags: set[str]) -> None: for subgroup in group.get_subgroups(): collect_flags(subgroup, flags) - collect_flags(self.problem.get(ProblemTestCases)['root_group'], all_flags) + collect_flags(self.problem.testdata, all_flags) fd, file_name = tempfile.mkstemp() os.close(fd) @@ -1066,7 +1017,7 @@ def collect_flags(group: TestCaseGroup, flags: set[str]) -> None: self.warning(f'No validator rejects {desc} with flags "{" ".join(flags)}"') def modified_input_validates(applicable, modifier): - for testcase in self.problem.get(ProblemTestCases)['root_group'].get_all_testcases(): + for testcase in self.problem.testdata.get_all_testcases(): with open(testcase.infile) as infile: infile_data = infile.read() if not applicable(infile_data): @@ -1140,7 +1091,7 @@ def check(self, context: Context) -> bool: return self._check_res self._check_res = True - if self.problem.getMetadata().is_pass_fail() and len(self._graders) > 0: + if self.problem.metadata.is_pass_fail() and len(self._graders) > 0: self.error('There are grader programs but the problem is pass-fail') for grader in self._graders: @@ -1248,12 +1199,12 @@ def check(self, context: Context) -> bool: if isinstance(v, run.SourceCode) and v.language.lang_id not in recommended_output_validator_languages: self.warning('output validator language %s is not recommended' % v.language.name) - if self.problem.getMetadata().legacy_validation == 'default' and self._validators: + if self.problem.metadata.legacy_validation == 'default' and self._validators: self.error('There are validator programs but problem.yaml has validation = "default"') - elif self.problem.getMetadata().legacy_validation.startswith('custom') and not self._validators: + elif self.problem.metadata.legacy_validation.startswith('custom') and not self._validators: self.fatal('problem.yaml specifies custom validator but no validator programs found') - if self.problem.getMetadata().legacy_validation == 'default' and self._default_validator is None: + if self.problem.metadata.legacy_validation == 'default' and self._default_validator is None: self.fatal('Unable to locate default validator') for val in self._validators[:]: @@ -1266,7 +1217,7 @@ def check(self, context: Context) -> bool: # Only sanity check output validators if they all actually compiled if self._check_res: - flags = self.problem.getMetadata().legacy_validator_flags + flags = self.problem.metadata.legacy_validator_flags fd, file_name = tempfile.mkstemp() os.close(fd) @@ -1275,7 +1226,7 @@ def check(self, context: Context) -> bool: f.write(case) f.close() rejected = False - for testcase in self.problem.get(ProblemTestCases)['root_group'].get_all_testcases(): + for testcase in self.problem.testdata.get_all_testcases(): result = self.validate(testcase, file_name) if result.verdict != 'AC': rejected = True @@ 
-1307,7 +1258,7 @@ def _get_feedback(feedback_dir: str) -> str | None: return None def _parse_validator_results(self, val, status: int, feedbackdir, testcase: TestCase) -> SubmissionResult: - custom_score = self.problem.getMetadata().legacy_custom_score + custom_score = self.problem.metadata.legacy_custom_score score = None # TODO: would be good to have some way of displaying the feedback for debugging uses score_file = os.path.join(feedbackdir, 'score.txt') @@ -1347,9 +1298,7 @@ def _parse_validator_results(self, val, status: int, feedbackdir, testcase: Test def _actual_validators(self) -> list: vals = self._validators - if self.problem.getMetadata().legacy_validation == 'default' or ( - self.problem.format is FormatVersion.V_2023_07 and not vals - ): + if self.problem.metadata.legacy_validation == 'default' or (self.problem.format is FormatVersion.V_2023_07 and not vals): vals = [self._default_validator] return [val for val in vals if val is not None] @@ -1364,9 +1313,9 @@ def validate_interactive(self, testcase: TestCase, submission, timelim: int, err # file descriptor, wall time lim initargs = ['1', str(2 * timelim)] validator_args = [testcase.infile, testcase.ansfile, ''] - submission_args = submission.get_runcmd(memlim=self.problem.getMetadata().limits.memory) + submission_args = submission.get_runcmd(memlim=self.problem.metadata.limits.memory) - val_memlim = self.problem.getMetadata().limits.validation_memory + val_memlim = self.problem.metadata.limits.validation_memory for val in self._actual_validators(): if val.compile()[0]: feedbackdir = tempfile.mkdtemp(prefix='feedback', dir=self.problem.tmpdir) @@ -1422,11 +1371,10 @@ def validate_interactive(self, testcase: TestCase, submission, timelim: int, err def validate(self, testcase: TestCase, submission_output: str) -> SubmissionResult: res = SubmissionResult('JE') - val_timelim = self.problem.getMetadata().limits.validation_time - val_memlim = self.problem.getMetadata().limits.validation_memory + val_timelim = self.problem.metadata.limits.validation_time + val_memlim = self.problem.metadata.limits.validation_memory flags = ( - self.problem.getMetadata().legacy_validator_flags.split() - + testcase.testcasegroup.config['output_validator_flags'].split() + self.problem.metadata.legacy_validator_flags.split() + testcase.testcasegroup.config['output_validator_flags'].split() ) for val in self._actual_validators(): if val.compile()[0]: @@ -1562,7 +1510,7 @@ def _recompute_jobs(self) -> None: with self._lock: seen = set(self._started_jobs) self._remaining_jobs = [] - for testcase in self._gather_testcases(self._problem.get(ProblemTestCases)['root_group']): + for testcase in self._gather_testcases(self._problem.testdata): if testcase not in seen: seen.add(testcase) self._remaining_jobs.append(testcase) @@ -1615,9 +1563,7 @@ def check_submission( timelim_low = timelim with Runner(self.problem, sub, context, timelim, timelim_low, timelim_high) as runner: - result, result_low, result_high = self.problem.get(ProblemTestCases)['root_group'].run_submission( - sub, runner, context - ) + result, result_low, result_high = self.problem.testdata.run_submission(sub, runner, context) if result.verdict == 'AC' and expected_verdict == 'AC' and not partial and result.sample_failures: res = result.sample_failures[0] @@ -1648,21 +1594,21 @@ def check_submission( return result def full_score_finite(self) -> bool: - min_score, max_score = self.problem.get(ProblemTestCases)['root_group'].get_score_range() - if self.problem.getMetadata().legacy_grading.objective 
== 'min': + min_score, max_score = self.problem.testdata.get_score_range() + if self.problem.metadata.legacy_grading.objective == 'min': return min_score != float('-inf') else: return max_score != float('inf') def fully_accepted(self, result: SubmissionResult) -> bool: - min_score, max_score = self.problem.get(ProblemTestCases)['root_group'].get_score_range() - best_score = min_score if self.problem.getMetadata().legacy_grading.objective == 'min' else max_score - return result.verdict == 'AC' and (not self.problem.getMetadata().is_scoring() or result.score == best_score) + min_score, max_score = self.problem.testdata.get_score_range() + best_score = min_score if self.problem.metadata.legacy_grading.objective == 'min' else max_score + return result.verdict == 'AC' and (not self.problem.metadata.is_scoring() or result.score == best_score) def start_background_work(self, context: Context) -> None: # Send off an early background compile job for each submission and # validator, to avoid a bottleneck step at the start of each test run. - self.problem.getProblemPart(OutputValidators).start_background_work(context) + self.problem.output_validators.start_background_work(context) for acr in self._submissions: for sub in self._submissions[acr]: context.submit_background_work(lambda s: s.compile(), sub) @@ -1672,7 +1618,7 @@ def check(self, context: Context) -> bool: return self._check_res self._check_res = True - limits = self.problem.getMetadata().limits + limits = self.problem.metadata.limits time_multiplier = limits.time_multipliers.ac_to_time_limit safety_margin = limits.time_multipliers.time_limit_to_tle @@ -1717,7 +1663,7 @@ def check(self, context: Context) -> bool: max_runtime = max(runtimes) exact_timelim = max_runtime * time_multiplier max_runtime_str = f'{max_runtime:.3f}' - timelim = max(1, int(0.5 + exact_timelim)) + timelim = max(1, int(0.5 + exact_timelim)) # TODO: properly support 2023-07 time limit computation timelim_margin_lo = max(1, min(int(0.5 + exact_timelim / safety_margin), timelim - 1)) timelim_margin = max(timelim + 1, int(0.5 + exact_timelim * safety_margin)) else: @@ -1727,87 +1673,51 @@ def check(self, context: Context) -> bool: f' Solutions give timelim of {timelim} seconds, but will use provided fixed limit of {context.fixed_timelim} seconds instead' ) timelim = context.fixed_timelim - timelim_margin = round(timelim * safety_margin) # TODO: properly support 2023-07 time limit computation + timelim_margin = round(timelim * safety_margin) self.msg( f' Slowest AC runtime: {max_runtime_str}, setting timelim to {timelim} secs, safety margin to {timelim_margin} secs' ) + self.problem._set_timelim(timelim) return self._check_res -PROBLEM_FORMATS: dict[FormatVersion, dict[str, list[Type[ProblemPart]]]] = { - FormatVersion.LEGACY: { - 'config': [ProblemConfig], - 'statement': [ProblemStatement, Attachments], - 'validators': [InputValidators, OutputValidators], - 'graders': [Graders], - 'data': [ProblemTestCases], - 'submissions': [ - OutputValidators, - Submissions, - ], # OutputValidators duplicated to fatal() early if we can't find a validator. We should find a cleaner solution - }, - FormatVersion.V_2023_07: { # TODO: Add all the parts - 'config': [ProblemConfig], - 'statement': [ProblemStatement, Attachments], - 'validators': [InputValidators, OutputValidators], - 'graders': [Graders], - 'data': [ProblemTestCases], - 'submissions': [ - OutputValidators, - Submissions, - ], # OutputValidators duplicated to fatal() early if we can't find a validator. 
We should find a cleaner solution - }, -} - -# parts tested in alphabetical order -PROBLEM_PARTS = [*sorted({part for format in PROBLEM_FORMATS.values() for part in format})] - -_ProblemPartT = TypeVar('_ProblemPartT', bound=ProblemPart) +PROBLEM_PARTS = ['config', 'data', 'graders', 'statement', 'submissions', 'validators'] class Problem(ProblemAspect): """Represents a checkable problem""" - """ - Needs a problem-format in the form of a parts-dictionary, where all classes that verify the - problem are listed. These should all be a subclass of ProblemPart. The dictionary is in the form - of category -> part-types. You could for example have 'validators' -> [InputValidators, OutputValidators]. - """ - - def __init__( - self, probdir: str, args: argparse.Namespace, parts: dict[str, list[type]] = PROBLEM_FORMATS[FormatVersion.LEGACY] - ): - self.part_mapping: dict[str, list[Type[ProblemPart]]] = parts - self.aspects: set[type] = {v for s in parts.values() for v in s} + def __init__(self, probdir: str, args: argparse.Namespace): self.probdir = os.path.realpath(probdir) self.shortname: str = os.path.basename(self.probdir) super().__init__(self.shortname, self) self.language_config = languages.load_language_config() - self.format = get_format_version(Path(self.probdir)) - self._data: dict[str, dict] = {} + self.testcase_by_infile: dict[str, TestCase] = {} + self.loaded = False self._metadata: metadata.Metadata | None = None - self.debug(f'Problem-format: {parts}') self._args = args - self.loaded = False + self._timelim: float | None = None - def get(self, part) -> dict: - if isinstance(part, type) and issubclass(part, ProblemPart): - part = part.PART_NAME - assert part in self._data - return self._data[part] + # Unfortunately must be before metadata, otherwise mypy gets confused about the type metadata.Metadata (feels like a bug) + def _set_metadata(self, metadata: metadata.Metadata) -> None: # Should only be called by ProblemConfig + assert self._metadata is None, 'Attempted to set metadata twice' + self._metadata = metadata - def getMetadata(self) -> metadata.Metadata: - assert self._metadata is not None, 'Attempted to access Config before it was set' + @property + def metadata(self) -> metadata.Metadata: + assert self._metadata is not None, 'Attempted to access config before it was set. load() or check() first.' return self._metadata - def setMetadata(self, metadata: metadata.Metadata) -> None: - assert self._metadata is None, 'Attempted to set Config twice' - self._metadata = metadata + @property + def timelim(self) -> float: + assert self._timelim is not None, 'Attempted to access timelim before it was set. check() first.' + return self._timelim - def getProblemPart(self, part: Type[_ProblemPartT]) -> _ProblemPartT: - return self._classes[part.PART_NAME] # type: ignore + def _set_timelim(self, timelim: float) -> None: # Should only be called by Submissions + assert self._timelim is None, 'Attempted to set timelim twice' + self._timelim = timelim def load(self) -> None: """Parses the problem package statically, loading up information with very little verification. 
@@ -1828,33 +1738,18 @@ def load(self) -> None:
         if not os.path.isdir(self.probdir):
             self.fatal(f"Problem directory '{self.probdir}' not found")
 
-        # Initialize the classes, making sure to resolve dependencies first
-        initialized = set()
-        self._classes: dict[str, ProblemPart] = {}
-
-        def init(_class):
-            if _class.PART_NAME in initialized:
-                return
-
-            # A bit ugly but want to allow for subclasses
-            for dependency in _class.setup_dependencies():
-                cnt = 0
-                for cl in self.aspects:
-                    if issubclass(cl, dependency):
-                        init(cl)
-                        cnt += 1
-                if cnt != 1:
-                    raise NotImplementedError(
-                        f'Part "{_class.PART_NAME}" depends on part "{dependency.PART_NAME}" which showed up {cnt} times in problem-format (should have showed up exactly once)'
-                    )
-            self.debug(f'Initializing {_class.PART_NAME} ({_class})')
-            assert _class.PART_NAME not in initialized
-            self._classes[_class.PART_NAME] = _class(self)
-            self._data[_class.PART_NAME] = self._classes[_class.PART_NAME].setup()
-            initialized.add(_class.PART_NAME)
-
-        for c in self.aspects:
-            init(c)
+        try:
+            self.format = get_format_version(Path(self.probdir))
+        except Exception as e:
+            self.fatal(f'Failed loading problem version: {e}')
+        self.config = ProblemConfig(self)  # Populates self.metadata as a side effect. Needs to run first.
+        self.statement = ProblemStatement(self)
+        self.attachments = Attachments(self)
+        self.input_validators = InputValidators(self)
+        self.output_validators = OutputValidators(self)
+        self.graders = Graders(self)
+        self.testdata = TestCaseGroup(self, os.path.join(self.probdir, 'data'))
+        self.submissions = Submissions(self)
 
     def __enter__(self) -> Problem:
         self.tmpdir = tempfile.mkdtemp(prefix=f'verify-{self.shortname}-')
@@ -1888,6 +1783,16 @@ def check(self) -> tuple[int, int]:
         context = Context(self._args, executor)
 
         try:
+            part_mapping: dict[str, list] = {
+                'config': [self.config],
+                'statement': [self.statement, self.attachments],
+                'validators': [self.input_validators, self.output_validators],
+                'graders': [self.graders],
+                'data': [self.testdata],
+                'submissions': [self.submissions],
+            }
+            assert sorted(part_mapping.keys()) == sorted(PROBLEM_PARTS), 'part_mapping and PROBLEM_PARTS must be kept in sync'
+
             if not re.match('^[a-z0-9]+$', self.shortname):
                 self.error(f"Invalid shortname '{self.shortname}' (must be [a-z0-9]+)")
             if self.format is FormatVersion.V_2023_07:
@@ -1898,25 +1803,24 @@ def check(self) -> tuple[int, int]:
 
             run.limit.check_limit_capabilities(self)
 
-            # Skip any parts that do not belong to the format
-            parts = [part for part in self._args.parts if part in self.part_mapping]
-
+            parts = [
+                part for part in part_mapping if part in self._args.parts
+            ]  # Parts from _args in the order they appear in part_mapping
             if executor:
                 for part in parts:
-                    for item in self.part_mapping[part]:
-                        self._classes[item.PART_NAME].start_background_work(context)
+                    for item in part_mapping[part]:
+                        item.start_background_work(context)
 
             for part in parts:
                 self.msg(f'Checking {part}')
-                for item in self.part_mapping[part]:
-                    self._classes[item.PART_NAME].check(context)
+                for item in part_mapping[part]:
+                    item.check(context)
         except VerifyError:
             pass
         finally:
             # Wait for background work to finish before performing an rmtree on
             # the directory tree it uses.
             context.wait_for_background_work()
-
         return self.errors, self.warnings
 
     def _check_symlinks(self):
@@ -1989,13 +1893,6 @@ def argparser_basic_arguments(parser: argparse.ArgumentParser) -> None:
         default=15,
         help='maximum number of lines of additional info (e.g. compiler output or validator feedback) to display about an error (set to 0 to disable additional info)',
     )
-    parser.add_argument(
-        '-v',
-        '--problem_format',
-        default='automatic',
-        choices=list(PROBLEM_FORMATS.keys()) + ['automatic'],
-        help='which problem format should the package be interpreted as, or "automatic" if it should be figured out from problem.yaml',
-    )
 
 
 def argparser() -> argparse.ArgumentParser:
@@ -2058,18 +1955,8 @@ def main() -> None:
     total_errors = 0
     try:
         for problemdir in args.problemdir:
-            try:
-                if args.problem_format == 'automatic':
-                    formatversion = get_format_version(Path(problemdir))
-                else:
-                    formatversion = FormatVersion(args.problem_format)
-            except Exception as e:
-                total_errors += 1
-                print(f'ERROR: problem version could not be decided for {os.path.basename(os.path.realpath(problemdir))}: {e}')
-                continue
-
-            print(f'Loading problem {os.path.basename(os.path.realpath(problemdir))} with format version {formatversion}')
-            with Problem(problemdir, args, PROBLEM_FORMATS[formatversion]) as prob:
+            print(f'Loading problem {os.path.basename(os.path.realpath(problemdir))}')
+            with Problem(problemdir, args) as prob:
                 errors, warnings = prob.check()
 
                 def p(x: int) -> str:
diff --git a/tests/test_verify_hello.py b/tests/test_verify_hello.py
index fa906c78..9bf5cd43 100644
--- a/tests/test_verify_hello.py
+++ b/tests/test_verify_hello.py
@@ -14,5 +14,5 @@ def test_load_hello():
         p.load()
         assert p.shortname == 'hello'
         # pytest and fork don't go along very well, so just run aspects that work without run
-        assert p.getProblemPart(verify.ProblemConfig).check(context)
-        assert p.getProblemPart(verify.Attachments).check(context)
+        assert p.config.check(context)
+        assert p.attachments.check(context)