diff --git a/src/omegaup_hook_tools/git_tools.py b/src/omegaup_hook_tools/git_tools.py index 3229376..e9cb750 100644 --- a/src/omegaup_hook_tools/git_tools.py +++ b/src/omegaup_hook_tools/git_tools.py @@ -101,6 +101,17 @@ def _validate_args(args: argparse.Namespace, files: Sequence[Text]) -> bool: f'with `commits` or `files`.{COLORS.NORMAL}'), file=sys.stderr) return False + if args.staged: + if args.all_files: + print((f'{COLORS.FAIL}--staged is incompatible ' + f'with --all-files.{COLORS.NORMAL}'), + file=sys.stderr) + return False + if args.commits or files: + print((f'{COLORS.FAIL}--staged is incompatible ' + f'with `commits` or `files`.{COLORS.NORMAL}'), + file=sys.stderr) + return False if len(args.commits) not in (0, 1, 2): # args.commits can never be empty since its default value is ['HEAD'], # but the user can specify zero commits. @@ -122,6 +133,32 @@ def _get_all_files() -> Iterator[bytes]: yield path +def _get_staged_files() -> Iterator[bytes]: + '''Returns the list of files that are staged in the index.''' + cmd = ['/usr/bin/git', 'diff-index', '-z', '--diff-filter=d', + '--cached', 'HEAD'] + tokens = subprocess.run(cmd, + check=True, + stdout=subprocess.PIPE, + cwd=root_dir()).stdout.split(b'\x00') + idx = 0 + while idx < len(tokens) - 1: + match = GIT_DIFF_TREE_PATTERN.match(tokens[idx]) + assert match, tokens[idx] + filemode, status = match.groups() + if filemode == GIT_DIRECTORY_ENTRY_MODE: + idx += 2 + continue + src = tokens[idx + 1] + if status in (b'C', b'R'): + dest = tokens[idx + 2] + idx += 3 + yield dest + else: + idx += 2 + yield src + + def _get_changed_files(commits: List[Text]) -> Iterator[bytes]: ''' Returns the list of files that were modified in the specified range.''' @@ -168,6 +205,8 @@ def _files_to_consider(args: argparse.Namespace) -> List[Text]: # Get all files in the latter commit. if args.all_files: result = _get_all_files() + elif args.staged: + result = _get_staged_files() else: result = _get_changed_files(args.commits) @@ -208,6 +247,13 @@ def prompt(question: Text, default: bool = True) -> bool: def file_contents(args: argparse.Namespace, root: Text, filename: Text) -> bytes: '''Returns contents of |filename| at the revision specified by |args|.''' + if getattr(args, 'staged', False): + # When --staged is used, read the version of the file from the index. + return subprocess.run( + ['/usr/bin/git', 'show', f':{filename}'], + check=True, + stdout=subprocess.PIPE, + cwd=root).stdout if len(args.commits) in (0, 1): # Zero or one commits (where the former is a shorthand for 'HEAD') # always diff against the current contents of the file in the @@ -259,6 +305,11 @@ def parse_arguments( validate_parser.add_argument( '--all-files', action='store_true', help='Considers all files. Incompatible with `commits` and `files`') + validate_parser.add_argument( + '--staged', action='store_true', + help=('Only considers files staged in the index ' + '(git diff --cached). Incompatible with ' + '--all-files, `commits`, and `files`')) validate_parser.add_argument( 'commits', metavar='[commit [commit ...]] [--] [file [file ...]]', @@ -276,6 +327,11 @@ def parse_arguments( '--all-files', action='store_true', help=('Considers all files. ' 'Incompatible with `commits` and `files`')) + fix_parser.add_argument( + '--staged', action='store_true', + help=('Only considers files staged in the index ' + '(git diff --cached). Incompatible with ' + '--all-files, `commits`, and `files`')) fix_parser.add_argument( 'commits', metavar='[commit [commit ...]] [--] [file [file ...]]',