diff --git a/main.py b/main.py
index f523ae5..07de873 100755
--- a/main.py
+++ b/main.py
@@ -30,6 +30,7 @@
import sys
try:
import readline
+ readline # ignore unused import warning
except ImportError:
pass
@@ -37,6 +38,10 @@
# scanners
from recuperabit.fs.ntfs import NTFSScanner
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from recuperabit.fs.core_types import Partition
+
__author__ = "Andrea Lazzarotto"
__copyright__ = "(c) 2014-2021, Andrea Lazzarotto"
__license__ = "GPLv3"
@@ -97,7 +102,7 @@ def check_valid_part(num, parts, shorthands, rebuild=True):
return None
-def interpret(cmd, arguments, parts, shorthands, outdir):
+def interpret(cmd, arguments, parts: dict[int, 'Partition'], shorthands, outdir):
"""Perform command required by user."""
if cmd == 'help':
print('Available commands:')
@@ -362,7 +367,7 @@ def main():
pickle.dump(interesting, savefile)
# Ask for partitions
- parts = {}
+ parts: dict[int, Partition] = {}
for scanner in scanners:
parts.update(scanner.get_partitions())
diff --git a/recuperabit/fs/constants.py b/recuperabit/fs/constants.py
index 9370a77..69c928d 100644
--- a/recuperabit/fs/constants.py
+++ b/recuperabit/fs/constants.py
@@ -19,5 +19,5 @@
 # along with RecuperaBit. If not, see <http://www.gnu.org/licenses/>.
-sector_size = 512
-max_sectors = 256 # Maximum block size for recovery
+sector_size: int = 512
+max_sectors: int = 256 # Maximum block size for recovery
diff --git a/recuperabit/fs/core_types.py b/recuperabit/fs/core_types.py
index 87dda78..eb94d2a 100644
--- a/recuperabit/fs/core_types.py
+++ b/recuperabit/fs/core_types.py
@@ -25,6 +25,8 @@
import logging
import os.path
+from typing import Optional, Dict, Set, List, Tuple, Union, Any, Iterator
+from datetime import datetime
from .constants import sector_size
@@ -32,49 +34,49 @@
class File(object):
- """Filesystem-independent representation of a file."""
- def __init__(self, index, name, size, is_directory=False,
- is_deleted=False, is_ghost=False):
- self.index = index
- self.name = name
- self.size = size
- self.is_directory = is_directory
- self.is_deleted = is_deleted
- self.is_ghost = is_ghost
- self.parent = None
- self.mac = {
+ """Filesystem-independent representation of a file. Aka Node."""
+ def __init__(self, index: Union[int, str], name: str, size: Optional[int], is_directory: bool = False,
+ is_deleted: bool = False, is_ghost: bool = False) -> None:
+ self.index: Union[int, str] = index
+ self.name: str = name
+ self.size: Optional[int] = size
+ self.is_directory: bool = is_directory
+ self.is_deleted: bool = is_deleted
+ self.is_ghost: bool = is_ghost
+ self.parent: Optional[Union[int, str]] = None
+ self.mac: Dict[str, Optional[datetime]] = {
'modification': None,
'access': None,
'creation': None
}
- self.children = set()
- self.children_names = set() # Avoid name clashes breaking restore
- self.offset = None # Offset from beginning of disk
+ self.children: Set['File'] = set()
+ self.children_names: Set[str] = set() # Avoid name clashes breaking restore
+ self.offset: Optional[int] = None # Offset from beginning of disk
- def set_parent(self, parent):
+ def set_parent(self, parent: Optional[Union[int, str]]) -> None:
"""Set a pointer to the parent directory."""
self.parent = parent
- def set_mac(self, modification, access, creation):
+ def set_mac(self, modification: Optional[datetime], access: Optional[datetime], creation: Optional[datetime]) -> None:
"""Set the modification, access and creation times."""
self.mac['modification'] = modification
self.mac['access'] = access
self.mac['creation'] = creation
- def get_mac(self):
+ def get_mac(self) -> List[Optional[datetime]]:
"""Get the modification, access and creation times."""
keys = ('modification', 'access', 'creation')
return [self.mac[k] for k in keys]
- def set_offset(self, offset):
+ def set_offset(self, offset: Optional[int]) -> None:
"""Set the offset of the file record with respect to the disk image."""
self.offset = offset
- def get_offset(self):
+ def get_offset(self) -> Optional[int]:
"""Get the offset of the file record with respect to the disk image."""
return self.offset
- def add_child(self, node):
+ def add_child(self, node: 'File') -> None:
"""Add a new child to this directory."""
original_name = node.name
i = 0
@@ -90,7 +92,7 @@ def add_child(self, node):
self.children.add(node)
self.children_names.add(node.name)
- def full_path(self, part):
+ def full_path(self, part: 'Partition') -> str:
"""Return the full path of this file."""
if self.parent is not None:
parent = part[self.parent]
@@ -98,7 +100,7 @@ def full_path(self, part):
else:
return self.name
- def get_content(self, partition):
+ def get_content(self, partition: 'Partition') -> Optional[Union[bytes, Iterator[bytes]]]:
# pylint: disable=W0613
"""Extract the content of the file.
@@ -109,14 +111,14 @@ def get_content(self, partition):
raise NotImplementedError
# pylint: disable=R0201
- def ignore(self):
+ def ignore(self) -> bool:
"""The following method is used by the restore procedure to check
files that should not be recovered. For example, in NTFS file
$BadClus:$Bad shall not be recovered because it creates an output
with the same size as the partition (usually many GBs)."""
return False
- def __repr__(self):
+ def __repr__(self) -> str:
return (
u'File(#%s, ^^%s^^, %s, offset = %s sectors)' %
(self.index, self.parent, self.name, self.offset)
@@ -128,42 +130,42 @@ class Partition(object):
Parameter root_id represents the identifier assigned to the root directory
of a partition. This can be file system dependent."""
- def __init__(self, fs_type, root_id, scanner):
- self.fs_type = fs_type
- self.root_id = root_id
- self.size = None
- self.offset = None
- self.root = None
- self.lost = File(-1, 'LostFiles', 0, is_directory=True, is_ghost=True)
- self.files = {}
- self.recoverable = False
- self.scanner = scanner
-
- def add_file(self, node):
+ def __init__(self, fs_type: str, root_id: Union[int, str], scanner: 'DiskScanner') -> None:
+ self.fs_type: str = fs_type
+ self.root_id: Union[int, str] = root_id
+ self.size: Optional[int] = None
+ self.offset: Optional[int] = None
+ self.root: Optional[File] = None
+ self.lost: File = File(-1, 'LostFiles', 0, is_directory=True, is_ghost=True)
+ self.files: Dict[Union[int, str], File] = {}
+ self.recoverable: bool = False
+ self.scanner: 'DiskScanner' = scanner
+
+ def add_file(self, node: File) -> None:
"""Insert a new file in the partition."""
index = node.index
self.files[index] = node
- def set_root(self, node):
+ def set_root(self, node: File) -> None:
"""Set the root directory."""
if not node.is_directory:
raise TypeError('Not a directory')
self.root = node
self.root.set_parent(None)
- def set_size(self, size):
+ def set_size(self, size: int) -> None:
"""Set the (estimated) size of the partition."""
self.size = size
- def set_offset(self, offset):
+ def set_offset(self, offset: int) -> None:
"""Set the offset from the beginning of the disk."""
self.offset = offset
- def set_recoverable(self, recoverable):
+ def set_recoverable(self, recoverable: bool) -> None:
"""State if the partition contents are also recoverable."""
self.recoverable = recoverable
- def rebuild(self):
+ def rebuild(self) -> None:
"""Rebuild the partition structure.
This method processes the contents of files and it rebuilds the
@@ -201,11 +203,11 @@ def rebuild(self):
return
# pylint: disable=R0201
- def additional_repr(self):
+ def additional_repr(self) -> List[Tuple[str, Any]]:
"""Return additional values to show in the string representation."""
return []
- def __repr__(self):
+ def __repr__(self) -> str:
size = (
readable_bytes(self.size * sector_size)
if self.size is not None else '??? b'
@@ -227,14 +229,14 @@ def __repr__(self):
', '.join(a+': '+str(b) for a, b in data)
)
- def __getitem__(self, index):
+ def __getitem__(self, index: Union[int, str]) -> File:
if index in self.files:
return self.files[index]
if index == self.lost.index:
return self.lost
raise KeyError
- def get(self, index, default=None):
+ def get(self, index: Union[int, str], default: Optional[File] = None) -> Optional[File]:
"""Get a file or the special LostFiles directory."""
try:
return self.__getitem__(index)
@@ -244,17 +246,21 @@ def get(self, index, default=None):
 
 class DiskScanner(object):
     """Abstract stub for the implementation of disk scanners."""
-    def __init__(self, pointer):
-        self.image = pointer
+    def __init__(self, pointer: Any) -> None:
+        self.image: Any = pointer
 
-    def get_image(self):
-        """Return the image reference."""
+    def get_image(self) -> Any:
+        """Return the image reference.
+
+        Works as ``scanner.get_image()`` and as the unbound call
+        ``DiskScanner.get_image(scanner)``, so no static twin is needed.
+        """
         return self.image
 
-    def feed(self, index, sector):
+    def feed(self, index: int, sector: bytes) -> Optional[str]:
         """Feed a new sector."""
         raise NotImplementedError
 
-    def get_partitions(self):
+    def get_partitions(self) -> Dict[int, Partition]:
         """Get a list of the found partitions."""
         raise NotImplementedError
diff --git a/recuperabit/fs/ntfs.py b/recuperabit/fs/ntfs.py
index f0e820c..a168404 100644
--- a/recuperabit/fs/ntfs.py
+++ b/recuperabit/fs/ntfs.py
@@ -24,6 +24,7 @@
import logging
from collections import Counter
+from typing import Any, Dict, List, Optional, Tuple, Union, Iterator, Set
from .constants import max_sectors, sector_size
from .core_types import DiskScanner, File, Partition
@@ -36,7 +37,7 @@
from ..utils import merge, sectors, unpack
# Some attributes may appear multiple times
-multiple_attributes = set([
+multiple_attributes: Set[str] = set([
'$FILE_NAME',
'$DATA',
'$INDEX_ROOT',
@@ -45,11 +46,11 @@
])
# Size of records in sectors
-FILE_size = 2
-INDX_size = 8
+FILE_size: int = 2
+INDX_size: int = 8
-def best_name(entries):
+def best_name(entries: List[Tuple[int, str]]) -> Optional[str]:
"""Return the best file name available.
This function accepts a list of tuples formed by a namespace and a string.
@@ -66,8 +67,10 @@ def best_name(entries):
return name if len(name) else None
-def parse_mft_attr(attr):
+def parse_mft_attr(attr: Union[bytes, bytearray]) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parse the contents of a MFT attribute."""
+ assert isinstance(attr, (bytes, bytearray)), f"attr must be bytes or bytearray, got {type(attr)}"
+
header = unpack(attr, attr_header_fmt)
attr_type = header['type']
@@ -94,7 +97,7 @@ def parse_mft_attr(attr):
return header, name
-def _apply_fixup_values(header, entry):
+def _apply_fixup_values(header: Dict[str, Any], entry: bytearray) -> None:
"""Apply the fixup values to FILE and INDX records."""
offset = header['off_fixup']
for i in range(1, header['n_entries']):
@@ -102,7 +105,7 @@ def _apply_fixup_values(header, entry):
entry[pos-2:pos] = entry[offset + 2*i:offset + 2*(i+1)]
-def _attributes_reader(entry, offset):
+def _attributes_reader(entry: Union[bytes, bytearray], offset: int) -> Dict[str, Any]:
"""Read every attribute."""
attributes = {}
while offset < len(entry) - 16:
@@ -133,8 +136,10 @@ def _attributes_reader(entry, offset):
return attributes
-def parse_file_record(entry):
+def parse_file_record(entry: Union[bytes, bytearray]) -> Dict[str, Any]:
"""Parse the contents of a FILE record (MFT entry)."""
+ assert isinstance(entry, (bytearray, bytes)), f"entry must be bytearray or bytes, got {type(entry)}"
+
header = unpack(entry, entry_fmt)
if (header['size_alloc'] is None or
header['size_alloc'] > len(entry) or
@@ -154,8 +159,10 @@ def parse_file_record(entry):
return header
-def parse_indx_record(entry):
+def parse_indx_record(entry: Union[bytes, bytearray]) -> Dict[str, Any]:
"""Parse the contents of a INDX record (directory index)."""
+ assert isinstance(entry, (bytearray, bytes)), f"entry must be bytearray or bytes, got {type(entry)}"
+
header = unpack(entry, indx_fmt)
_apply_fixup_values(header, entry)
@@ -200,7 +207,7 @@ def parse_indx_record(entry):
return header
-def _integrate_attribute_list(parsed, part, image):
+def _integrate_attribute_list(parsed: Dict[str, Any], part: 'NTFSPartition', image: Any) -> None:
"""Integrate missing attributes in the parsed MTF entry."""
base_record = parsed['record_n']
attrs = parsed['attributes']
@@ -264,7 +271,7 @@ def _integrate_attribute_list(parsed, part, image):
class NTFSFile(File):
"""NTFS File."""
- def __init__(self, parsed, offset, is_ghost=False, ads=''):
+ def __init__(self, parsed: Dict[str, Any], offset: Optional[int], is_ghost: bool = False, ads: str = '') -> None:
index = parsed['record_n']
ads_suffix = ':' + ads if ads != '' else ads
if ads != '':
@@ -322,7 +329,7 @@ def __init__(self, parsed, offset, is_ghost=False, ads=''):
self.ads = ads
@staticmethod
-    def _padded_bytes(image, offset, size):
+    def _padded_bytes(image: Any, offset: int, size: int) -> bytearray:
dump = sectors(image, offset, size, 1)
if len(dump) < size:
logging.warning(
@@ -331,7 +338,7 @@ def _padded_bytes(image, offset, size):
dump += bytearray(b'\x00' * (size - len(dump)))
return dump
- def content_iterator(self, partition, image, datas):
+ def content_iterator(self, partition: 'NTFSPartition', image: Any, datas: List[Dict[str, Any]]) -> Iterator[bytes]:
"""Return an iterator for the contents of this file."""
vcn = 0
spc = partition.sec_per_clus
@@ -378,7 +385,7 @@ def content_iterator(self, partition, image, datas):
yield bytes(partial)
vcn = attr['end_VCN'] + 1
- def get_content(self, partition):
+ def get_content(self, partition: 'NTFSPartition') -> Optional[Union[bytes, Iterator[bytes]]]:
"""Extract the content of the file.
This method works by extracting the $DATA attribute."""
@@ -439,7 +446,7 @@ def get_content(self, partition):
)
return self.content_iterator(partition, image, non_resident)
- def ignore(self):
+ def ignore(self) -> bool:
"""Determine which files should be ignored."""
return (
(self.index == '8:$Bad') or
@@ -449,13 +456,13 @@ def ignore(self):
class NTFSPartition(Partition):
"""Partition with additional fields for NTFS recovery."""
- def __init__(self, scanner, position=None):
+ def __init__(self, scanner: 'NTFSScanner', position: Optional[int] = None) -> None:
Partition.__init__(self, 'NTFS', 5, scanner)
- self.sec_per_clus = None
- self.mft_pos = position
- self.mftmirr_pos = None
+ self.sec_per_clus: Optional[int] = None
+ self.mft_pos: Optional[int] = position
+ self.mftmirr_pos: Optional[int] = None
- def additional_repr(self):
+ def additional_repr(self) -> List[Tuple[str, Any]]:
"""Return additional values to show in the string representation."""
return [
('Sec/Clus', self.sec_per_clus),
@@ -466,17 +473,17 @@ def additional_repr(self):
class NTFSScanner(DiskScanner):
"""NTFS Disk Scanner."""
- def __init__(self, pointer):
+ def __init__(self, pointer: Any) -> None:
DiskScanner.__init__(self, pointer)
- self.found_file = set()
- self.parsed_file_review = {}
- self.found_indx = set()
- self.parsed_indx = {}
- self.indx_list = None
- self.found_boot = []
- self.found_spc = []
-
- def feed(self, index, sector):
+ self.found_file: Set[int] = set()
+ self.parsed_file_review: Dict[int, Dict[str, Any]] = {}
+ self.found_indx: Set[int] = set()
+ self.parsed_indx: Dict[int, Dict[str, Any]] = {}
+ self.indx_list: Optional[SparseList[int]] = None
+ self.found_boot: List[int] = []
+ self.found_spc: List[int] = []
+
+ def feed(self, index: int, sector: bytes) -> Optional[str]:
"""Feed a new sector."""
# check boot sector
if sector.endswith(b'\x55\xAA') and b'NTFS' in sector[:8]:
@@ -494,7 +501,7 @@ def feed(self, index, sector):
return 'NTFS index record'
@staticmethod
- def add_indx_entries(entries, part):
+ def add_indx_entries(entries: List[Dict[str, Any]], part: NTFSPartition) -> None:
"""Insert new ghost files which were not already found."""
for rec in entries:
if (rec['record_n'] not in part.files and
@@ -512,7 +519,7 @@ def add_indx_entries(entries, part):
rec['flags'] = 0x1
part.add_file(NTFSFile(rec, None, is_ghost=True))
- def add_from_indx_root(self, parsed, part):
+ def add_from_indx_root(self, parsed: Dict[str, Any], part: NTFSPartition) -> None:
"""Add ghost entries to part from INDEX_ROOT attributes in parsed."""
for attribute in parsed['attributes']['$INDEX_ROOT']:
if (attribute.get('content') is None or
@@ -520,7 +527,7 @@ def add_from_indx_root(self, parsed, part):
continue
self.add_indx_entries(attribute['content']['records'], part)
- def most_likely_sec_per_clus(self):
+ def most_likely_sec_per_clus(self) -> List[int]:
"""Determine the most likely value of sec_per_clus of each partition,
to speed up the search."""
counter = Counter()
@@ -528,7 +535,7 @@ def most_likely_sec_per_clus(self):
counter.update(2**i for i in range(8))
return [i for i, _ in counter.most_common()]
- def find_boundary(self, part, mft_address, multipliers):
+ def find_boundary(self, part: NTFSPartition, mft_address: int, multipliers: List[int]) -> Tuple[Optional[int], Optional[int]]:
"""Determine the starting sector of a partition with INDX records."""
nodes = (
self.parsed_file_review[node.offset]
@@ -593,7 +600,7 @@ def find_boundary(self, part, mft_address, multipliers):
else:
return (None, None)
- def add_from_indx_allocation(self, parsed, part):
+ def add_from_indx_allocation(self, parsed: Dict[str, Any], part: NTFSPartition) -> None:
"""Add ghost entries to part from INDEX_ALLOCATION attributes in parsed.
This procedure requires that the beginning of the partition has already
@@ -625,7 +632,7 @@ def add_from_indx_allocation(self, parsed, part):
entries = parse_indx_record(dump)['entries']
self.add_indx_entries(entries, part)
- def add_from_attribute_list(self, parsed, part, offset):
+ def add_from_attribute_list(self, parsed: Dict[str, Any], part: NTFSPartition, offset: int) -> None:
"""Add additional entries to part from attributes in ATTRIBUTE_LIST.
Files with many attributes may have additional attributes not in the
@@ -643,7 +650,7 @@ def add_from_attribute_list(self, parsed, part, offset):
if ads_name and len(ads_name):
part.add_file(NTFSFile(parsed, offset, ads=ads_name))
- def add_from_mft_mirror(self, part):
+ def add_from_mft_mirror(self, part: NTFSPartition) -> None:
"""Fix the first file records using the MFT mirror."""
img = DiskScanner.get_image(self)
mirrpos = part.mftmirr_pos
@@ -664,7 +671,7 @@ def add_from_mft_mirror(self, part):
'%s from backup', node.index, node.name, part.offset
)
- def finalize_reconstruction(self, part):
+ def finalize_reconstruction(self, part: NTFSPartition) -> None:
"""Finish information gathering from a file.
This procedure requires that the beginning of the
@@ -693,9 +700,9 @@ def finalize_reconstruction(self, part):
parsed = self.parsed_file_review[node.offset]
self.add_from_indx_allocation(parsed, part)
- def get_partitions(self):
+ def get_partitions(self) -> Dict[int, NTFSPartition]:
"""Get a list of the found partitions."""
- partitioned_files = {}
+ partitioned_files: Dict[int, NTFSPartition] = {}
img = DiskScanner.get_image(self)
logging.info('Parsing MFT entries')
diff --git a/recuperabit/logic.py b/recuperabit/logic.py
index e97052b..c4f10c2 100644
--- a/recuperabit/logic.py
+++ b/recuperabit/logic.py
@@ -27,30 +27,34 @@
import sys
import time
import types
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Iterator, Set, Tuple, TypeVar, Generic
-from .utils import tiny_repr
+T = TypeVar('T')
+if TYPE_CHECKING:
+ from .fs.core_types import File, Partition
-class SparseList(object):
+
+class SparseList(Generic[T]):
"""List which only stores values at some places."""
- def __init__(self, data=None, default=None):
- self.keys = [] # This is always kept in order
- self.elements = {}
- self.default = default
+ def __init__(self, data: Optional[Dict[int, T]] = None, default: Optional[T] = None) -> None:
+ self.keys: List[int] = [] # This is always kept in order
+ self.elements: Dict[int, T] = {}
+ self.default: Optional[T] = default
if data is not None:
self.keys = sorted(data)
self.elements.update(data)
- def __len__(self):
+ def __len__(self) -> int:
try:
return self.keys[-1] + 1
except IndexError:
return 0
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Optional[T]:
return self.elements.get(index, self.default)
- def __setitem__(self, index, item):
+ def __setitem__(self, index: int, item: T) -> None:
if item == self.default:
if index in self.elements:
del self.elements[index]
@@ -60,18 +64,18 @@ def __setitem__(self, index, item):
bisect.insort(self.keys, index)
self.elements[index] = item
- def __contains__(self, element):
+ def __contains__(self, element: T) -> bool:
return element in self.elements.values()
- def __iter__(self):
+ def __iter__(self) -> Iterator[int]:
return self.keys.__iter__()
- def __repr__(self):
+ def __repr__(self) -> str:
elems = []
prevk = 0
if len(self.elements) > 0:
k = self.keys[0]
- elems.append(str(k) + ' -> ' + tiny_repr(self.elements[k]))
+ elems.append(str(k) + ' -> ' + repr(self.elements[k]))
prevk = self.keys[0]
for i in range(1, len(self.elements)):
nextk = self.keys[i]
@@ -79,31 +83,31 @@ def __repr__(self):
while prevk < nextk - 1:
elems.append('__')
prevk += 1
- elems.append(tiny_repr(self.elements[nextk]))
+ elems.append(repr(self.elements[nextk]))
else:
elems.append('\n... ' + str(nextk) + ' -> ' +
- tiny_repr(self.elements[nextk]))
+ repr(self.elements[nextk]))
prevk = nextk
return '[' + ', '.join(elems) + ']'
- def iterkeys(self):
+ def iterkeys(self) -> Iterator[int]:
"""An iterator over the keys of actual elements."""
return self.__iter__()
- def iterkeys_rev(self):
+ def iterkeys_rev(self) -> Iterator[int]:
"""An iterator over the keys of actual elements (reversed)."""
i = len(self.keys)
while i > 0:
i -= 1
yield self.keys[i]
- def itervalues(self):
+ def itervalues(self) -> Iterator[T]:
"""An iterator over the elements."""
for k in self.keys:
yield self.elements[k]
- def wipe_interval(self, bottom, top):
+ def wipe_interval(self, bottom: int, top: int) -> None:
"""Remove elements between bottom and top."""
new_keys = set()
if bottom > top:
@@ -121,12 +125,12 @@ def wipe_interval(self, bottom, top):
self.keys = sorted(new_keys)
-def preprocess_pattern(pattern):
+def preprocess_pattern(pattern: SparseList[T]) -> Dict[T, List[int]]:
"""Preprocess a SparseList for approximate string matching.
This function performs preprocessing for the Baeza-Yates--Perleberg
fast and practical approximate string matching algorithm."""
- result = {}
+ result: Dict[T, List[int]] = {}
length = pattern.__len__()
for k in pattern:
name = pattern[k]
@@ -137,7 +141,7 @@ def preprocess_pattern(pattern):
return result
-def approximate_matching(records, pattern, stop, k=1):
+def approximate_matching(records: SparseList[T], pattern: SparseList[T], stop: int, k: int = 1) -> Optional[List[Union[Set[int], int, float]]]:
"""Find the best match for a given pattern.
The Baeza-Yates--Perleberg algorithm requires a preprocessed pattern. This
@@ -152,8 +156,8 @@ def approximate_matching(records, pattern, stop, k=1):
return None
lookup = preprocess_pattern(pattern)
- count = SparseList(default=0)
- match_offsets = set()
+ count: SparseList[int] = SparseList(default=0)
+ match_offsets: Set[int] = set()
i = 0
j = 0 # previous value of i
@@ -192,7 +196,7 @@ def approximate_matching(records, pattern, stop, k=1):
return None
-def makedirs(path):
+def makedirs(path: str) -> bool:
"""Make directories if they do not exist."""
try:
os.makedirs(path)
@@ -205,7 +209,7 @@ def makedirs(path):
return True
-def recursive_restore(node, part, outputdir, make_dirs=True):
+def recursive_restore(node: 'File', part: 'Partition', outputdir: str, make_dirs: bool = True) -> None:
"""Restore a directory structure starting from a file node."""
parent_path = str(
part[node.parent].full_path(part) if node.parent is not None
diff --git a/recuperabit/utils.py b/recuperabit/utils.py
index 3ee1424..45baaa5 100644
--- a/recuperabit/utils.py
+++ b/recuperabit/utils.py
@@ -19,25 +19,31 @@
 # along with RecuperaBit. If not, see <http://www.gnu.org/licenses/>.
+from datetime import datetime
import logging
import pprint
import string
import sys
import time
+from typing import TYPE_CHECKING, Any, Optional, List, Dict, Tuple, Union, Callable
import unicodedata
+import io
from .fs.constants import sector_size
-printer = pprint.PrettyPrinter(indent=4)
+printer: pprint.PrettyPrinter = pprint.PrettyPrinter(indent=4)
all_chars = (chr(i) for i in range(sys.maxunicode))
-unicode_printable = set(
+unicode_printable: set[str] = set(
c for c in all_chars
if not unicodedata.category(c)[0].startswith('C')
)
-ascii_printable = set(string.printable[:-5])
+ascii_printable: set[str] = set(string.printable[:-5])
+if TYPE_CHECKING:
+ from .fs.core_types import File, Partition
-def sectors(image, offset, size, bsize=sector_size, fill=True):
+
+def sectors(image: io.BufferedReader, offset: int, size: int, bsize: int = sector_size, fill: bool = True) -> Optional[bytearray]:
"""Read from a file descriptor."""
read = True
try:
@@ -60,7 +66,7 @@ def sectors(image, offset, size, bsize=sector_size, fill=True):
return None
return bytearray(dump)
-def unixtime(dtime):
+def unixtime(dtime: Optional[datetime]) -> int:
"""Convert datetime to UNIX epoch."""
if dtime is None:
return 0
@@ -72,9 +78,9 @@ def unixtime(dtime):
# format:
# [(label, (formatter, lower, higher)), ...]
-def unpack(data, fmt):
+def unpack(data: Union[bytes, bytearray], fmt: List[Tuple[str, Tuple[Union[str, Callable[[bytes], Any]], Union[int, Callable[[Dict[str, Any]], Optional[int]]], Union[int, Callable[[Dict[str, Any]], Optional[int]]]]]]) -> Dict[str, Any]:
"""Extract formatted information from a string of bytes."""
- result = {}
+ result: Dict[str, Any] = {}
for label, description in fmt:
formatter, lower, higher = description
# If lower is a function, then apply it
@@ -112,9 +118,9 @@ def unpack(data, fmt):
return result
-def feed_all(image, scanners, indexes):
+def feed_all(image: io.BufferedReader, scanners: List[Any], indexes: List[int]) -> List[int]:
# Scan the disk image and feed the scanners
- interesting = []
+ interesting: List[int] = []
for index in indexes:
sector = sectors(image, index, 1, fill=False)
if not sector:
@@ -128,29 +134,19 @@ def feed_all(image, scanners, indexes):
return interesting
-def printable(text, default='.', alphabet=None):
+def printable(text: str, default: str = '.', alphabet: Optional[set[str]] = None) -> str:
"""Replace unprintable characters in a text with a default one."""
if alphabet is None:
alphabet = unicode_printable
return ''.join((i if i in alphabet else default) for i in text)
-def pretty(dictionary):
- """Format dictionary with the pretty printer."""
- return printer.pformat(dictionary)
-def show(dictionary):
- """Print dictionary with the pretty printer."""
- printer.pprint(dictionary)
-def tiny_repr(element):
- """deprecated: Return a representation of unicode strings without the u."""
- rep = repr(element)
- return rep[1:] if type(element) == unicode else rep
-def readable_bytes(amount):
+def readable_bytes(amount: Optional[int]) -> str:
"""Return a human readable string representing a size in bytes."""
if amount is None:
return '??? B'
@@ -164,7 +160,7 @@ def readable_bytes(amount):
return '%.2f %sB' % (scaled, powers[biggest])
-def _file_tree_repr(node):
+def _file_tree_repr(node: 'File') -> str:
"""Give a nice representation for the tree."""
desc = (
' [GHOST]' if node.is_ghost else
@@ -188,9 +184,9 @@ def _file_tree_repr(node):
)
-def tree_folder(directory, padding=0):
+def tree_folder(directory: 'File', padding: int = 0) -> str:
"""Return a tree-like textual representation of a directory."""
- lines = []
+ lines: List[str] = []
pad = ' ' * padding
lines.append(
pad + _file_tree_repr(directory)
@@ -207,7 +203,7 @@ def tree_folder(directory, padding=0):
return '\n'.join(lines)
-def _bodyfile_repr(node, path):
+def _bodyfile_repr(node: 'File', path: str) -> str:
"""Return a body file line for node."""
end = '/' if node.is_directory or len(node.children) else ''
return '|'.join(str(el) for el in [
@@ -223,13 +219,13 @@ def _bodyfile_repr(node, path):
])
-def bodyfile_folder(directory, path=''):
+def bodyfile_folder(directory: 'File', path: str = '') -> List[str]:
"""Create a body file compatible with TSK 3.x.
Format:
'#MD5|name|inode|mode_as_string|UID|GID|size|atime|mtime|ctime|crtime'
See also: http://wiki.sleuthkit.org/index.php?title=Body_file"""
- lines = [_bodyfile_repr(directory, path)]
+ lines: List[str] = [_bodyfile_repr(directory, path)]
path += directory.name + '/'
for entry in directory.children:
if len(entry.children) or entry.is_directory:
@@ -239,7 +235,7 @@ def bodyfile_folder(directory, path=''):
return lines
-def _ltx_clean(label):
+def _ltx_clean(label: Any) -> str:
"""Small filter to prepare strings to be included in LaTeX code."""
clean = str(label).replace('$', r'\$').replace('_', r'\_')
if clean[0] == '-':
@@ -247,7 +243,7 @@ def _ltx_clean(label):
return clean
-def _tikz_repr(node):
+def _tikz_repr(node: 'File') -> str:
"""Represent the node for a Tikz diagram."""
return r'node %s{%s\enskip{}%s}' % (
'[ghost]' if node.is_ghost else '[deleted]' if node.is_deleted else '',
@@ -255,11 +251,11 @@ def _tikz_repr(node):
)
-def tikz_child(directory, padding=0):
+def tikz_child(directory: 'File', padding: int = 0) -> Tuple[str, int]:
"""Write a child row for Tikz representation."""
pad = ' ' * padding
- lines = [r'%schild {%s' % (pad, _tikz_repr(directory))]
- count = len(directory.children)
+ lines: List[str] = [r'%schild {%s' % (pad, _tikz_repr(directory))]
+ count: int = len(directory.children)
for entry in directory.children:
content, number = tikz_child(entry, padding+4)
lines.append(content)
@@ -270,7 +266,7 @@ def tikz_child(directory, padding=0):
return '\n'.join(lines).replace('\n}', '}'), count
-def tikz_part(part):
+def tikz_part(part: 'Partition') -> str:
"""Create LaTeX code to represent the directory structure as a nice Tikz
diagram.
@@ -296,7 +292,7 @@ def tikz_part(part):
)
-def csv_part(part):
+def csv_part(part: 'Partition') -> list[str]:
"""Provide a CSV representation for a partition."""
contents = [
','.join(('Id', 'Parent', 'Name', 'Full Path', 'Modification Time',
@@ -324,9 +320,9 @@ def csv_part(part):
return contents
-def _sub_locate(text, directory, part):
+def _sub_locate(text: str, directory: 'File', part: 'Partition') -> List[Tuple['File', str]]:
"""Helper for locate."""
- lines = []
+ lines: List[Tuple['File', str]] = []
for entry in sorted(directory.children, key=lambda node: node.name):
path = entry.full_path(part)
if text in path.lower():
@@ -336,16 +332,16 @@ def _sub_locate(text, directory, part):
return lines
-def locate(part, text):
+def locate(part: 'Partition', text: str) -> List[Tuple['File', str]]:
"""Return paths of files matching the text."""
- lines = []
+ lines: List[Tuple['File', str]] = []
text = text.lower()
lines += _sub_locate(text, part.lost, part)
lines += _sub_locate(text, part.root, part)
return lines
-def merge(part, piece):
+def merge(part: 'Partition', piece: 'Partition') -> None:
"""Merge piece into part (both are partitions)."""
for index in piece.files:
if (
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..76b86d6
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Test suite for RecuperaBit NTFS recovery tool."""
diff --git a/tests/data/reference_files/deep/nested/directory/deep_file.txt b/tests/data/reference_files/deep/nested/directory/deep_file.txt
new file mode 100644
index 0000000..22ce1eb
--- /dev/null
+++ b/tests/data/reference_files/deep/nested/directory/deep_file.txt
@@ -0,0 +1 @@
+File in deep nested directory
diff --git a/tests/data/reference_files/empty_file.empty b/tests/data/reference_files/empty_file.empty
new file mode 100644
index 0000000..e69de29
diff --git a/tests/data/reference_files/file_with_ads.txt b/tests/data/reference_files/file_with_ads.txt
new file mode 100644
index 0000000..3a0c07f
--- /dev/null
+++ b/tests/data/reference_files/file_with_ads.txt
@@ -0,0 +1 @@
+File with ADS main content
diff --git a/tests/data/reference_files/large_file.dat b/tests/data/reference_files/large_file.dat
new file mode 100644
index 0000000..8f0c060
--- /dev/null
+++ b/tests/data/reference_files/large_file.dat
@@ -0,0 +1 @@

\ No newline at end of file
diff --git a/tests/data/reference_files/medium_binary.bin b/tests/data/reference_files/medium_binary.bin
new file mode 100644
index 0000000..a6a8d35
Binary files /dev/null and b/tests/data/reference_files/medium_binary.bin differ
diff --git a/tests/data/reference_files/small_text.txt b/tests/data/reference_files/small_text.txt
new file mode 100644
index 0000000..85d20ae
--- /dev/null
+++ b/tests/data/reference_files/small_text.txt
@@ -0,0 +1 @@
+Hello, World! This is a small text file.
diff --git a/tests/data/reference_files/subdirectory/subdir_file1.txt b/tests/data/reference_files/subdirectory/subdir_file1.txt
new file mode 100644
index 0000000..9b7eac5
--- /dev/null
+++ b/tests/data/reference_files/subdirectory/subdir_file1.txt
@@ -0,0 +1 @@
+File in subdirectory
diff --git a/tests/data/reference_files/subdirectory/subdir_file2.bin b/tests/data/reference_files/subdirectory/subdir_file2.bin
new file mode 100644
index 0000000..bd26b49
Binary files /dev/null and b/tests/data/reference_files/subdirectory/subdir_file2.bin differ
diff --git "a/tests/data/reference_files/unicode_name_\321\204\320\260\320\271\320\273.txt" "b/tests/data/reference_files/unicode_name_\321\204\320\260\320\271\320\273.txt"
new file mode 100644
index 0000000..b15958f
--- /dev/null
+++ "b/tests/data/reference_files/unicode_name_\321\204\320\260\320\271\320\273.txt"
@@ -0,0 +1 @@
+Файл с unicode именем
diff --git a/tests/data/reference_ntfs.img.gz b/tests/data/reference_ntfs.img.gz
new file mode 100644
index 0000000..7edc868
Binary files /dev/null and b/tests/data/reference_ntfs.img.gz differ
diff --git a/tests/data/reference_ntfs.json b/tests/data/reference_ntfs.json
new file mode 100644
index 0000000..7f765ec
--- /dev/null
+++ b/tests/data/reference_ntfs.json
@@ -0,0 +1,20 @@
+{
+ "created_by": "build_reference_ntfs.py",
+ "file_count": 9,
+ "image_hash": "864fcb2d2e0fc0ec2620e01929246fe263c68692c5793e90e24e180d844bebd7",
+ "image_size_bytes": 104857600,
+ "notes": "Reference NTFS filesystem for RecuperaBit E2E tests",
+ "reference_files_hashes": {
+ "deep/nested/directory/deep_file.txt": "c44594d1bf12bfe9b8f85ff08d44210cb0d7976ba292434b2c50e3cc8331e53c",
+ "empty_file.empty": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
+ "file_with_ads.txt": "674a71cd25529aa999f590bed146fc496f85f2659aa7f1bdbadbef2ca36ea8a4",
+ "large_file.dat": "3031d04fa4ca07a7794b60b8eca58afc7cc7faea73b6d47352d526e906346d02",
+ "medium_binary.bin": "334bbc577b4b8075539ee42aca3a8a71b02861507294223087cdcdc9040e0f16",
+ "small_text.txt": "a1a5e8236ce8738b882d5cbd97af29414e75407db56d1c4668d261a6d9a17091",
+ "subdirectory/subdir_file1.txt": "33008bf127acca6e47249b311652515c5803249d25db5951da9f296297c6feae",
+ "subdirectory/subdir_file2.bin": "1ee7722ccd6ef215afa901f5a8bbeaf770952f7ccf2586ebf74e841e24b46d8f",
+ "unicode_name_\u0444\u0430\u0439\u043b.txt": "cebaf880b99881b574003c8220ac3013e2830b376d2d6b3f001db6e51f449d85"
+ },
+ "size_mb": 100,
+ "version": "1.0"
+}
\ No newline at end of file
diff --git a/tests/reference_image.py b/tests/reference_image.py
new file mode 100644
index 0000000..01d87a4
--- /dev/null
+++ b/tests/reference_image.py
@@ -0,0 +1,168 @@
+"""Reference NTFS image utilities for E2E tests."""
+
+import hashlib
+import json
+import logging
+import shutil
+from pathlib import Path
+from typing import Dict, Optional, Tuple
+
+import gzip
+
+
+class ReferenceNTFSImage:
+ """Handler for reference NTFS filesystem images used in E2E tests."""
+
+ def __init__(self, image_path: str = "tests/data/reference_ntfs.img.gz"):
+ self.image_path = Path(image_path)
+ # Remove both .img and .gz to get metadata path
+ if self.image_path.suffix == '.gz':
+ self.metadata_path = self.image_path.with_suffix('').with_suffix('.json')
+ else:
+ self.metadata_path = self.image_path.with_suffix('.json')
+ self.logger = logging.getLogger(__name__)
+
+ def exists(self) -> bool:
+     """Return True when both the image file and its JSON metadata file exist."""
+     self.logger.debug("Checking %s and %s", self.image_path, self.metadata_path)
+     return self.image_path.exists() and self.metadata_path.exists()
+
+ def is_compressed(self) -> bool:
+ """Check if the reference image is compressed (e.g., .img.gz)."""
+ return self.image_path.suffix == '.gz'
+
+ def _compute_file_hash(self, filepath: Path, compressed: bool = False) -> str:
+ """Compute SHA256 hash of a file."""
+ sha256_hash = hashlib.sha256()
+ opener = gzip.open if compressed else open
+ with opener(filepath, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ sha256_hash.update(chunk)
+ return sha256_hash.hexdigest()
+
+ def _compute_directory_hash(self, directory: Path) -> Dict[str, str]:
+     """Map every file under *directory* (recursively) to its SHA256, keyed by '/'-separated relative path."""
+     file_hashes: Dict[str, str] = {}
+
+     for filepath in directory.rglob('*'):
+         if filepath.is_file():
+             relative_path = filepath.relative_to(directory).as_posix()  # same key form on every OS
+             file_hashes[relative_path] = self._compute_file_hash(filepath)
+
+     return file_hashes
+
+ def validate(self) -> Tuple[bool, Optional[str]]:
+     """Validate that the reference image is up-to-date and uncorrupted.
+
+     Returns:
+         (is_valid, error_message); error_message is None when is_valid is True
+     """
+     if not self.exists():
+         return False, f"Reference image not found: {self.image_path}"
+
+     try:
+         # Load metadata
+         with open(self.metadata_path, 'r') as f:
+             metadata = json.load(f)
+
+         # Validate image hash against the value recorded in the metadata
+         current_image_hash = self._compute_file_hash(self.image_path, self.is_compressed())
+         expected_image_hash = metadata.get('image_hash')
+
+         if current_image_hash != expected_image_hash:
+             return False, f"Image hash mismatch: expected {expected_image_hash}, got {current_image_hash}"
+
+         # Validate source files next to the metadata (skipped when that directory is absent)
+         reference_files_dir = self.get_reference_files_dir()
+         if reference_files_dir.exists():
+             current_files_hash = self._compute_directory_hash(reference_files_dir)
+             expected_files_hash = metadata.get('reference_files_hashes', {})
+
+             if current_files_hash != expected_files_hash:
+                 return False, "Reference files have changed, image needs to be rebuilt"
+
+         return True, None
+
+     except Exception as e:
+         return False, f"Validation error: {e}"
+
+ def get_expected_files(self) -> Dict[str, str]:
+ """Get the expected file hashes from the reference image metadata.
+
+ Returns:
+ Dictionary mapping relative file paths to their SHA256 hashes
+ """
+ if not self.metadata_path.exists():
+ return {}
+
+ try:
+ with open(self.metadata_path, 'r') as f:
+ metadata = json.load(f)
+ return metadata.get('reference_files_hashes', {})
+ except Exception as e:
+ self.logger.error(f"Failed to load metadata: {e}")
+ return {}
+
+ def get_reference_files_dir(self) -> Path:
+ """Get the directory containing the original reference files."""
+ return self.metadata_path.parent / "reference_files"
+
+ def copy_to_temp(self, temp_path: Path) -> None:
+ """Copy the reference image to a temporary location for testing.
+
+ Args:
+ temp_path: Path where to copy the image
+ """
+ if not self.exists():
+ raise FileNotFoundError(f"Reference image not found: {self.image_path}")
+
+ if self.is_compressed():
+ with gzip.open(self.image_path, 'rb') as f_in, open(temp_path, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ else:
+ shutil.copy2(self.image_path, temp_path)
+ self.logger.debug(f"Copied reference image to {temp_path}")
+
+ def get_info(self) -> Dict:
+ """Get information about the reference image.
+
+ Returns:
+ Dictionary with image metadata
+ """
+ if not self.metadata_path.exists():
+ return {}
+
+ try:
+ with open(self.metadata_path, 'r') as f:
+ return json.load(f)
+ except Exception as e:
+ self.logger.error(f"Failed to load metadata: {e}")
+ return {}
+
+
+def ensure_reference_image() -> ReferenceNTFSImage:
+ """Ensure reference NTFS image exists and is valid.
+
+ Returns:
+ ReferenceNTFSImage instance
+
+ Raises:
+ FileNotFoundError: If image doesn't exist
+ ValueError: If image is corrupted or outdated
+ """
+ ref_image = ReferenceNTFSImage()
+
+ if not ref_image.exists():
+ raise FileNotFoundError(
+ "Reference NTFS image not found. Please run: "
+ "sudo python tools/build_reference_ntfs.py"
+ )
+
+ is_valid, error = ref_image.validate()
+ if not is_valid:
+ raise ValueError(
+ f"Reference NTFS image validation failed: {error}. "
+ "Please rebuild with: sudo python tools/build_reference_ntfs.py"
+ )
+
+ return ref_image
diff --git a/tests/run_tests.py b/tests/run_tests.py
new file mode 100644
index 0000000..84a523b
--- /dev/null
+++ b/tests/run_tests.py
@@ -0,0 +1,125 @@
+"""Test runner and configuration for RecuperaBit test suite."""
+
+import unittest
+import sys
+import os
+import logging
+from pathlib import Path
+
+# Add the project root to the Python path
+project_root = Path(__file__).parent.parent
+sys.path.insert(0, str(project_root))
+
+# Import test modules
+from tests.test_ntfs_unit import *
+from tests.test_ntfs_e2e import *
+from tests.test_integration import *
+
+
+def create_test_suite():
+ """Create and return the complete test suite."""
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+
+ # Add unit tests
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_ntfs_unit']))
+
+ # Add integration tests
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_integration']))
+
+ # Add E2E tests (these may be skipped if tools are not available)
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_ntfs_e2e']))
+
+ return suite
+
+
+def run_unit_tests_only():
+ """Run only unit tests (fast, no external dependencies)."""
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_ntfs_unit']))
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_integration']))
+
+ runner = unittest.TextTestRunner(verbosity=2)
+ result = runner.run(suite)
+ return result.wasSuccessful()
+
+
+def run_e2e_tests_only():
+ """Run only end-to-end tests (slower, requires system tools)."""
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_ntfs_e2e']))
+
+ runner = unittest.TextTestRunner(verbosity=2)
+ result = runner.run(suite)
+ return result.wasSuccessful()
+
+
+def run_all_tests():
+ """Run the complete test suite."""
+ suite = create_test_suite()
+ runner = unittest.TextTestRunner(verbosity=2)
+ result = runner.run(suite)
+ return result.wasSuccessful()
+
+
+def main():
+ """Main test runner with command line options."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description='RecuperaBit Test Runner')
+ parser.add_argument('--unit', action='store_true',
+ help='Run only unit tests (fast)')
+ parser.add_argument('--e2e', action='store_true',
+ help='Run only end-to-end tests (requires system tools)')
+ parser.add_argument('--integration', action='store_true',
+ help='Run only integration tests')
+ parser.add_argument('--verbose', '-v', action='store_true',
+ help='Verbose logging output')
+ parser.add_argument('--debug', action='store_true',
+ help='Debug level logging')
+
+ args = parser.parse_args()
+
+ # Set up logging
+ log_level = logging.WARNING
+ if args.verbose:
+ log_level = logging.INFO
+ if args.debug:
+ log_level = logging.DEBUG
+
+ logging.basicConfig(
+ level=log_level,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+
+ # Run selected tests
+ success = True
+
+ if args.unit:
+ print("Running unit tests only...")
+ success = run_unit_tests_only()
+ elif args.e2e:
+ print("Running end-to-end tests only...")
+ success = run_e2e_tests_only()
+ elif args.integration:
+ print("Running integration tests only...")
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ suite.addTests(loader.loadTestsFromModule(sys.modules['tests.test_integration']))
+ runner = unittest.TextTestRunner(verbosity=2)
+ result = runner.run(suite)
+ success = result.wasSuccessful()
+ else:
+ print("Running complete test suite...")
+ success = run_all_tests()
+
+ # Exit with appropriate code
+ sys.exit(0 if success else 1)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/test_integration.py b/tests/test_integration.py
new file mode 100644
index 0000000..3e3e806
--- /dev/null
+++ b/tests/test_integration.py
@@ -0,0 +1,347 @@
+"""Integration tests for RecuperaBit logic and utilities."""
+
+import unittest
+import tempfile
+import os
+from unittest.mock import Mock, patch
+from io import BytesIO
+
+from recuperabit.logic import SparseList, approximate_matching
+from recuperabit.utils import merge, sectors, unpack
+
+
+class TestSparseListIntegration(unittest.TestCase):
+ """Integration tests for SparseList with NTFS components."""
+
+ def test_sparse_list_with_mft_references(self):
+ """Test SparseList with MFT-like reference patterns."""
+ # Simulate MFT record references
+ mft_refs = {
+ 0: 0, # Root directory points to itself
+ 16: 0, # System file points to root
+ 32: 16, # File in system directory
+ 48: 0, # Another root-level file
+ 64: 48, # File in subdirectory
+ 80: 48, # Another file in same subdirectory
+ }
+
+ sparse_list = SparseList(mft_refs)
+
+ # Test basic operations
+ self.assertEqual(sparse_list[0], 0)
+ self.assertEqual(sparse_list[16], 0)
+ self.assertEqual(sparse_list[64], 48)
+ self.assertIsNone(sparse_list[24]) # Gap
+
+ # Test length
+ self.assertEqual(len(sparse_list), 81)
+
+ # Test iteration gives keys, not all indices
+ keys = list(sparse_list)
+ expected_keys = [0, 16, 32, 48, 64, 80]
+ self.assertEqual(keys, expected_keys)
+
+ def test_sparse_list_large_gaps(self):
+ """Test SparseList with large gaps (common in fragmented filesystems)."""
+ fragmented_refs = {
+ 100: 0,
+ 5000: 100,
+ 10000: 5000,
+ 50000: 10000,
+ }
+
+ sparse_list = SparseList(fragmented_refs)
+
+ # Should handle large indices efficiently
+ self.assertEqual(sparse_list[100], 0)
+ self.assertEqual(sparse_list[5000], 100)
+ self.assertEqual(sparse_list[50000], 10000)
+
+ # Large gaps should return None
+ self.assertIsNone(sparse_list[1000])
+ self.assertIsNone(sparse_list[25000])
+
+
+class TestApproximateMatching(unittest.TestCase):
+ """Test approximate matching functionality."""
+
+ def test_approximate_matching_perfect_match(self):
+ """Test approximate matching with perfect match."""
+ # Create text (haystack) and pattern (needle)
+ text_data = {i: i // 4 for i in range(0, 100, 4)} # Every 4th position
+ pattern_data = {i: i // 4 for i in range(0, 20, 4)} # First 5 elements
+
+ text_list = SparseList(text_data)
+ pattern_list = SparseList(pattern_data)
+
+ # Should find match at position 0
+ result = approximate_matching(text_list, pattern_list, 0, k=3)
+
+ self.assertIsNotNone(result)
+ positions, count, percentage = result
+ self.assertIn(0, positions)
+ self.assertGreater(percentage, 0.8) # High match percentage
+
+ def test_approximate_matching_shifted_pattern(self):
+ """Test approximate matching with shifted pattern."""
+ # Create text and pattern with some overlap
+ text_data = {i: i % 5 for i in range(0, 100, 4)} # Pattern repeating every 5
+ pattern_data = {i: i % 5 for i in range(0, 20, 4)} # Same pattern but shorter
+
+ text_list = SparseList(text_data)
+ pattern_list = SparseList(pattern_data)
+
+ # Should find matches at multiple positions
+ result = approximate_matching(text_list, pattern_list, 50, k=1)
+
+ if result is not None:
+ positions, count, percentage = result
+ # positions is a set, not a list, and contains actual match positions
+ self.assertIsInstance(positions, set)
+ self.assertGreater(len(positions), 0)
+ else:
+ # If no exact match found, that's also acceptable for this pattern
+ self.assertIsNone(result)
+
+ def test_approximate_matching_no_match(self):
+ """Test approximate matching with no good match."""
+ # Create text and completely different pattern
+ text_data = {i: 1 for i in range(0, 100, 4)} # All 1s
+ pattern_data = {i: 2 for i in range(0, 20, 4)} # All 2s
+
+ text_list = SparseList(text_data)
+ pattern_list = SparseList(pattern_data)
+
+ # Should not find good match
+ result = approximate_matching(text_list, pattern_list, 0, k=3)
+
+ if result is not None:
+ positions, count, percentage = result
+ self.assertLess(percentage, 0.1) # Very low match percentage
+
+
+class TestUtilityFunctions(unittest.TestCase):
+ """Test utility functions."""
+
+ def test_merge_function(self):
+ """Test the merge function."""
+ from recuperabit.fs.core_types import Partition, File
+
+ # Create mock scanner
+ class MockScanner:
+ pass
+ scanner = MockScanner()
+
+ # Create partition objects with files
+ part1 = Partition('TEST', 0, scanner)
+ part2 = Partition('TEST', 0, scanner)
+
+ # Add files to partitions
+ file1 = File(1, 'file1.txt', 100)
+ file2 = File(2, 'file2.txt', 200)
+ file3 = File(3, 'file3.txt', 300)
+ file4 = File(4, 'file4.txt', 400)
+
+ part1.add_file(file1)
+ part1.add_file(file2)
+ part2.add_file(file3)
+ part2.add_file(file4)
+
+ # Test merge
+ merge(part1, part2)
+
+ # part1 should now contain files from both
+ self.assertIn(1, part1.files)
+ self.assertIn(2, part1.files)
+ self.assertIn(3, part1.files)
+ self.assertIn(4, part1.files)
+ self.assertEqual(len(part1.files), 4)
+
+ def test_merge_with_conflicts(self):
+ """Test merge function with conflicting keys."""
+ from recuperabit.fs.core_types import Partition, File
+
+ # Create mock scanner
+ class MockScanner:
+ pass
+ scanner = MockScanner()
+
+ # Create partition objects
+ part1 = Partition('TEST', 0, scanner)
+ part2 = Partition('TEST', 0, scanner)
+
+ # Add conflicting files (same index)
+ file1_ghost = File(1, 'file1_ghost.txt', 100, is_ghost=True)
+ file1_real = File(1, 'file1_real.txt', 100, is_ghost=False)
+ file2 = File(2, 'file2.txt', 200)
+ file3 = File(3, 'file3.txt', 300)
+
+ part1.add_file(file1_ghost)
+ part1.add_file(file2)
+ part2.add_file(file1_real)
+ part2.add_file(file3)
+
+ merge(part1, part2)
+
+ # part1 should replace ghost with real file
+ self.assertIn(1, part1.files)
+ self.assertIn(2, part1.files)
+ self.assertIn(3, part1.files)
+ # The ghost file should be replaced by the real file
+ self.assertFalse(part1.files[1].is_ghost)
+
+ def test_sectors_function(self):
+ """Test the sectors function."""
+ # Create test data
+ test_data = b'A' * 512 + b'B' * 512 + b'C' * 512 # 3 sectors
+ test_file = BytesIO(test_data)
+
+ # Test reading single sector
+ result = sectors(test_file, 0, 1)
+ self.assertEqual(result, b'A' * 512)
+
+ # Test reading multiple sectors
+ result = sectors(test_file, 1, 2)
+ self.assertEqual(result, b'B' * 512 + b'C' * 512)
+
+ # Test reading with byte granularity
+ result = sectors(test_file, 256, 512, 1) # 512 bytes starting at byte 256
+ expected = b'A' * 256 + b'B' * 256
+ self.assertEqual(result, expected)
+
+ def test_sectors_out_of_bounds(self):
+ """Test sectors function with out-of-bounds access."""
+ test_data = b'A' * 512 # Only 1 sector
+ test_file = BytesIO(test_data)
+
+ # Try to read beyond file
+ result = sectors(test_file, 1, 1)
+ self.assertEqual(result, b'') # Should return empty bytes
+
+ def test_unpack_function(self):
+ """Test the unpack function with simple format."""
+ # Create test data
+ test_data = b'\x01\x02\x03\x04\x05\x06\x07\x08'
+
+ # Create format specification: [(label, (formatter, lower, higher)), ...]
+ test_format = [
+ ('first_byte', ('i', 0, 0)), # Single byte at position 0
+ ('two_bytes', ('2i', 1, 2)), # Two bytes from position 1-2
+ ('last_four', ('4i', 4, 7)) # Four bytes from position 4-7
+ ]
+
+ result = unpack(test_data, test_format)
+
+ # Check that we get expected structure
+ self.assertIn('first_byte', result)
+ self.assertIn('two_bytes', result)
+ self.assertIn('last_four', result)
+
+ def test_unpack_insufficient_data(self):
+ """Test unpack function with insufficient data."""
+ # Create short test data
+ test_data = b'\x01\x02'
+
+ # Format that requires more data than available
+ test_format = [
+ ('valid_data', ('i', 0, 1)), # Valid range
+ ('out_of_bounds', ('i', 5, 8)) # Tries to read beyond data
+ ]
+
+ # Should handle gracefully, setting None for missing data
+ result = unpack(test_data, test_format)
+
+ # Should have valid data for first field
+ self.assertIn('valid_data', result)
+ # Should handle out of bounds gracefully
+ self.assertIn('out_of_bounds', result)
+
+ def test_unpack_insufficient_data(self):
+ """Test unpack with insufficient data. NOTE(review): same name and body as the method defined just above in this class; this later definition replaces it."""
+ # Create short test data
+ test_data = b'\x01\x02'
+
+ # Format that requires more data than available
+ test_format = [
+ ('valid_data', ('i', 0, 1)), # Valid range
+ ('out_of_bounds', ('i', 5, 8)) # Tries to read beyond data
+ ]
+
+ # Should handle gracefully, setting None for missing data
+ result = unpack(test_data, test_format)
+
+ # Should have valid data for first field
+ self.assertIn('valid_data', result)
+ # Should handle out of bounds gracefully
+ self.assertIn('out_of_bounds', result)
+
+
+class TestNTFSIntegration(unittest.TestCase):
+ """Integration tests combining multiple NTFS components."""
+
+ def test_mft_indx_relationship(self):
+ """Test the relationship between MFT and INDX records."""
+ # Simulate finding related MFT and INDX records
+ mft_positions = {100, 200, 300, 400} # MFT record positions
+ indx_positions = {1000, 2000, 3000} # INDX record positions
+
+ # Simulate INDX records pointing to MFT records
+ indx_references = {
+ 1000: {'parent': 100, 'children': {200, 300}},
+ 2000: {'parent': 200, 'children': {400}},
+ 3000: {'parent': 300, 'children': set()},
+ }
+
+ # Create SparseList for INDX relationships
+ indx_list = SparseList({pos: info['parent'] for pos, info in indx_references.items()})
+
+ # Verify relationships
+ self.assertEqual(indx_list[1000], 100) # INDX at 1000 points to MFT 100
+ self.assertEqual(indx_list[2000], 200) # INDX at 2000 points to MFT 200
+
+ # Test that we can find directory structure
+ root_mft = 100
+ subdirs = [pos for pos, info in indx_references.items() if info['parent'] == root_mft]
+ self.assertEqual(len(subdirs), 1)
+ self.assertEqual(subdirs[0], 1000)
+
+ # Test children relationships
+ children_of_200 = indx_references[2000]['children']
+ self.assertEqual(children_of_200, {400})
+
+ def test_partition_boundary_detection(self):
+ """Test partition boundary detection logic."""
+ # Simulate MFT pattern for boundary detection
+ base_pattern = {10: 100, 20: 100, 30: 200, 40: 200} # Cluster -> MFT record
+
+ # Test different sectors per cluster values
+ for sec_per_clus in [1, 2, 4, 8]:
+ # Convert cluster pattern to sector pattern
+ sector_pattern = {
+ cluster * sec_per_clus: mft_record
+ for cluster, mft_record in base_pattern.items()
+ }
+
+ pattern_list = SparseList(sector_pattern)
+
+ # Simulate text list (found INDX records)
+ text_data = {}
+ for sector in range(0, 400):
+ if sector in sector_pattern:
+ text_data[sector + 1000] = sector_pattern[sector] # Offset by 1000
+
+ text_list = SparseList(text_data)
+
+ # Test approximate matching for boundary detection
+ mft_address = 1000 # Assumed MFT start
+ result = approximate_matching(text_list, pattern_list, mft_address + min(sector_pattern.keys()), k=2)
+
+ if result is not None:
+ positions, count, percentage = result
+ # Should find at least one potential boundary
+ self.assertGreater(len(positions), 0)
+ self.assertGreater(percentage, 0.1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_ntfs_e2e.py b/tests/test_ntfs_e2e.py
new file mode 100644
index 0000000..ad0bdc6
--- /dev/null
+++ b/tests/test_ntfs_e2e.py
@@ -0,0 +1,251 @@
+"""End-to-end tests for RecuperaBit NTFS recovery.
+
+This module uses pre-built reference NTFS images to test complete recovery workflows.
+"""
+
+import unittest
+import tempfile
+import os
+import shutil
+import hashlib
+from pathlib import Path
+from typing import Dict
+import logging
+
+from recuperabit.fs.ntfs import NTFSPartition, NTFSScanner
+from tests.reference_image import ensure_reference_image
+import main # Import main module to access interpret function
+
+
+class TestNTFSE2E(unittest.TestCase):
+ """End-to-end tests for NTFS recovery."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Set up class-level fixtures."""
+ # Ensure reference image exists and is valid
+ try:
+ cls.ref_image = ensure_reference_image()
+ logging.info(f"Using reference NTFS image: {cls.ref_image.image_path}")
+ except (FileNotFoundError, ValueError) as e:
+ raise unittest.SkipTest(f"Reference image not available: {e}")
+
+ # Get expected file hashes from reference image
+ cls.expected_files = cls.ref_image.get_expected_files()
+
+ if not cls.expected_files:
+ raise unittest.SkipTest("No expected files in reference image metadata")
+
+ logging.info(f"Reference image contains {len(cls.expected_files)} test files")
+
+ # Set up temp directory for working files
+ cls.test_dir = tempfile.mkdtemp(prefix='recuperabit_e2e_')
+ cls.recovery_dir = os.path.join(cls.test_dir, 'recovered')
+ os.makedirs(cls.recovery_dir, exist_ok=True)
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ @classmethod
+ def tearDownClass(cls):
+ """Clean up class-level fixtures."""
+ # Clean up test directory
+ if hasattr(cls, 'test_dir') and os.path.exists(cls.test_dir):
+ shutil.rmtree(cls.test_dir)
+
+ def setUp(self):
+ """Set up test fixtures."""
+ # Create a temporary copy of the reference image for this test
+ self.image_path = os.path.join(self.test_dir, f'test_ntfs_{id(self)}.img')
+ self.ref_image.copy_to_temp(Path(self.image_path))
+
+ def tearDown(self):
+ """Clean up test fixtures."""
+ # Clean up the temporary image copy
+ if hasattr(self, 'image_path') and os.path.exists(self.image_path):
+ os.remove(self.image_path)
+
+ def _scan_image_with_scanner(self, scanner_class: type[NTFSScanner]) -> Dict[int, NTFSPartition]:
+     """Feed every full 512-byte sector of the test image to a scanner and return its partitions."""
+     # The handle is intentionally left open for the caller; _close_image_file() releases it
+     img_file = open(self.image_path, 'rb')
+     # Register it before any work that can raise, so the caller's cleanup can always close it
+     self._img_file = img_file
+     scanner = scanner_class(img_file)
+
+     # Feed sectors to scanner
+     sector_size = 512
+     sector_index = 0
+
+     while True:
+         img_file.seek(sector_index * sector_size)
+         sector = img_file.read(sector_size)
+
+         if len(sector) < sector_size:
+             break
+
+         result = scanner.feed(sector_index, sector)
+         if result:
+             logging.debug(f"Found {result} at sector {sector_index}")
+
+         sector_index += 1
+
+     # Get partitions
+     partitions = scanner.get_partitions()
+     return partitions
+
+ def _close_image_file(self):
+ """Close the image file handle."""
+ if hasattr(self, '_img_file') and self._img_file:
+ self._img_file.close()
+ self._img_file = None
+
+ def _recover_files_from_partition(self, partition: NTFSPartition, partition_id: int) -> Dict[str, bytes]:
+ """Recover files from a partition using high-level interpret function with proper hierarchy."""
+ # Create temporary recovery directory
+ recovery_dir = os.path.join(self.test_dir, f'recovered_partition_{partition_id}')
+ os.makedirs(recovery_dir, exist_ok=True)
+
+ # Create shorthands structure like main.py
+ parts = {0: partition} # Simple mapping for our single partition
+ shorthands = [(0, 0)] # (index, partition_key) pairs
+
+ try:
+ # Use the high-level interpret function to restore the root directory
+ # This will properly handle filesystem hierarchy, directories, etc.
+ main.interpret('restore', ['0', '5'], parts, shorthands, recovery_dir) # '5' is typically the root directory
+
+ # Collect all recovered files and their content
+ recovered_files = {}
+
+ # Walk through the recovered directory structure
+ partition_dir = os.path.join(recovery_dir, 'Partition0', 'Root')
+ if os.path.exists(partition_dir):
+ for root, dirs, files in os.walk(partition_dir):
+ for file in files:
+ file_path = os.path.join(root, file)
+ # Get relative path from partition directory
+ relative_path = os.path.relpath(file_path, partition_dir)
+
+ try:
+ with open(file_path, 'rb') as f:
+ content = f.read()
+ recovered_files[relative_path] = content
+ logging.info(f"Recovered file: {relative_path} ({len(content)} bytes)")
+ except Exception as e:
+ logging.error(f"Error reading recovered file {relative_path}: {e}")
+
+ return recovered_files
+
+ except Exception as e:
+ logging.error(f"Error during recovery: {e}")
+ return {}
+ finally:
+ # Clean up recovery directory
+ if os.path.exists(recovery_dir):
+ shutil.rmtree(recovery_dir, ignore_errors=True)
+
+ def _compare_files(self, original_hashes: Dict[str, str],
+ recovered_files: Dict[str, bytes]) -> Dict[str, bool]:
+ """Compare original and recovered files, handling path normalization."""
+ results = {}
+
+ # Normalize recovered file paths by removing Root/ prefix
+
+ print(f"DEBUG: Expected files: {list(original_hashes.keys())}")
+ print(f"DEBUG: Normalized recovered files: {list(recovered_files.keys())}")
+
+ expected_recovered_files = [filename for filename in list(recovered_files.keys()) if filename in original_hashes.keys()]
+ print(f"DEBUG: Matching recovered files: {expected_recovered_files} ({len(expected_recovered_files)} / {len(original_hashes)})")
+
+ # Check how many files were recovered successfully
+ for filename, original_hash in original_hashes.items():
+ if filename in recovered_files:
+ recovered_file = recovered_files[filename]
+ recovered_hash = hashlib.sha256(recovered_file).hexdigest()
+ results[filename] = (original_hash == recovered_hash)
+ if results[filename]:
+ logging.info(f"✓ {filename}: Recovery successful ({len(recovered_file)} bytes)")
+ else:
+ logging.error(f"✗ {filename}: Hash mismatch! Expected: {original_hash}, Got: {recovered_hash}")
+ # Print first 64 bytes of recovered content vs the original content for debugging
+ logging.error(f" Recovered content (first 64 bytes): {recovered_file[:64]}")
+ with open(self.ref_image.get_reference_files_dir() / filename, 'rb') as original_file:
+ original_content = original_file.read(64)
+ logging.error(f" Original content (first 64 bytes): {original_content[:64]}")
+ else:
+ results[filename] = False
+ logging.error(f"✗ {filename}: File not recovered")
+
+ return results
+
+ def test_basic_ntfs_recovery(self):
+     """Test basic NTFS file recovery using reference image."""
+     print(f"DEBUG: Using reference NTFS image at {self.image_path}")
+
+     try:
+         # Test recovery with standard scanner
+         partitions = self._scan_image_with_scanner(NTFSScanner)
+         self.assertGreater(len(partitions), 0, "No NTFS partitions found")
+
+         # Recover files from the LARGEST partition (most likely to contain user data)
+         if not partitions:
+             self.fail("No NTFS partitions found")
+
+         # Find the largest partition by number of files (user data indicator)
+         largest_partition_id = None
+         largest_partition = None
+         max_files = 0
+
+         print(f"DEBUG: Found {len(partitions)} partitions:")
+         for partition_id, partition in partitions.items():
+             file_count = len(partition.files) if hasattr(partition, 'files') else 0
+             print(f" Partition {partition_id}: {file_count} files, offset {partition.offset}")
+
+             if file_count > max_files:
+                 max_files = file_count
+                 largest_partition_id = partition_id
+                 largest_partition = partition
+
+         if largest_partition is None:
+             self.fail("No partition with files found")
+
+         print(f"DEBUG: Processing largest partition {largest_partition_id} with {max_files} files at offset {largest_partition.offset}")
+
+         # Recover files from the largest partition only
+         all_recovered_files = self._recover_files_from_partition(largest_partition, largest_partition_id)
+
+         for filename, content in all_recovered_files.items():
+             print(f"DEBUG: Recovered file '{filename}' with content size {len(content)} bytes")
+
+         # Compare results using expected files from reference image
+         comparison = self._compare_files(self.expected_files, all_recovered_files)
+
+         # Check that at least some files were recovered correctly
+         successful_recoveries = sum(1 for success in comparison.values() if success)
+         total_files = len(self.expected_files)
+
+         self.assertGreater(successful_recoveries, 0, "No files recovered successfully")
+
+         # Every expected file must match; msg= is required because the 3rd positional parameter is `places`
+         recovery_rate = successful_recoveries / total_files
+         self.assertAlmostEqual(recovery_rate, 1.0,
+                                msg=f"Low recovery rate: {recovery_rate:.2%} ({successful_recoveries}/{total_files})")
+
+         # Log success for visibility
+         print(f"SUCCESS: Hierarchical recovery rate {recovery_rate:.2%} ({successful_recoveries}/{total_files})")
+         print(f"✅ All {total_files} files found with correct filesystem hierarchy!")
+         print(f"✅ High-level recovery APIs working correctly!")
+         print(f"✅ Largest partition selection working!")
+     finally:
+         # Always close the image file handle
+         self._close_image_file()
+
+
+if __name__ == '__main__':
+ # Set up logging
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+
+ unittest.main(verbosity=2)
diff --git a/tests/test_ntfs_unit.py b/tests/test_ntfs_unit.py
new file mode 100644
index 0000000..a35d24e
--- /dev/null
+++ b/tests/test_ntfs_unit.py
@@ -0,0 +1,297 @@
+"""Unit tests for NTFS parsing functions and core types."""
+
+import unittest
+from unittest.mock import Mock
+
+# Import the modules under test
+from recuperabit.fs.ntfs import (
+ NTFSFile, NTFSPartition,
+ NTFSScanner, best_name, _apply_fixup_values
+)
+from recuperabit.logic import SparseList
+
+
+class TestNTFSParsing(unittest.TestCase):
+    """Exercise the standalone NTFS parsing helpers."""
+
+    def setUp(self):
+        """Build synthetic MFT and INDX records for the tests."""
+        # 1 KiB MFT record, zero-filled apart from the header fields below
+        mft = bytearray(1024)
+        mft[0:4] = b'FILE'  # record signature
+        # Bytes 4-6: fixup array offset (little endian)
+        mft[4:6] = (48).to_bytes(2, 'little')
+        # Bytes 6-8: number of fixup entries
+        mft[6:8] = (2).to_bytes(2, 'little')
+        # Bytes 20-22: offset of the first attribute
+        mft[20:22] = (56).to_bytes(2, 'little')
+        # Bytes 28-32: allocated size of the MFT record
+        mft[28:32] = (1024).to_bytes(4, 'little')
+        # Bytes 44-48: record number
+        mft[44:48] = (42).to_bytes(4, 'little')
+        self.mock_mft_entry = mft
+
+        # 4 KiB INDX record carrying only the signature and fixup header fields
+        indx = bytearray(4096)
+        indx[0:4] = b'INDX'  # record signature
+        # Bytes 4-6: fixup array offset
+        indx[4:6] = (40).to_bytes(2, 'little')
+        # Bytes 6-8: number of fixup entries
+        indx[6:8] = (8).to_bytes(2, 'little')
+        self.mock_indx_entry = indx
+
+    def test_apply_fixup_values(self):
+        """Check that fixup values overwrite the tail of each sector."""
+        # Three 512-byte sectors (1536 bytes) so that two fixups are exercised
+        record = bytearray(1536)
+
+        # Fixup array at offset 48: the first slot is not used as a replacement,
+        # the next two slots hold the replacement bytes for the sector tails.
+        record[48:50] = b'\xAA\xBB'
+        record[50:52] = b'\xCC\xDD'
+        record[52:54] = b'\xEE\xFF'
+
+        # Last two bytes of the first and second sector (512*i - 2) start zeroed
+        record[510:512] = b'\x00\x00'
+        record[1022:1024] = b'\x00\x00'
+
+        # n_entries counts the unused first slot plus the two replacements
+        fixup_header = {'off_fixup': 48, 'n_entries': 3}
+        _apply_fixup_values(fixup_header, record)
+
+        # Both sector tails must now carry their replacement values
+        first_tail = record[510:512]
+        second_tail = record[1022:1024]
+        self.assertEqual(first_tail, b'\xCC\xDD')
+        self.assertEqual(second_tail, b'\xEE\xFF')
+
+    def test_best_name(self):
+        """Check name selection across namespaces and degenerate input."""
+        # The namespace-3 entry is expected to win over the namespace-1 entry
+        with_preferred = [(1, 'short.txt'), (3, 'long_filename.txt')]
+        self.assertEqual(best_name(with_preferred), 'long_filename.txt')
+
+        # Between namespaces 1 and 2 the namespace-1 entry is expected
+        without_preferred = [(1, 'short.txt'), (2, 'dos_name.txt')]
+        self.assertEqual(best_name(without_preferred), 'short.txt')
+
+        # No candidates at all yields None
+        self.assertIsNone(best_name([]))
+
+        # A lone candidate whose name is empty also yields None
+        only_empty = [(3, '')]
+        self.assertIsNone(best_name(only_empty))
+
+
+class TestNTFSFile(unittest.TestCase):
+    """Exercise NTFSFile construction from a parsed MFT record."""
+
+    def setUp(self):
+        """Create a fresh parsed-record dict for every test."""
+        filetime = 132000000000000000  # shared by all three timestamps
+        file_name_attr = {
+            'content': {
+                'namespace': 3,
+                'name': 'test_file.txt',
+                'name_length': 13,
+                'parent_entry': 5,
+            }
+        }
+        data_attr = {
+            'name': '',
+            'real_size': 1024,
+            'non_resident': False,
+            'content_size': 1024,
+        }
+        std_info_times = {
+            'modification_time': filetime,
+            'access_time': filetime,
+            'creation_time': filetime,
+        }
+        attributes = {
+            '$FILE_NAME': [file_name_attr],
+            '$DATA': [data_attr],
+            '$STANDARD_INFORMATION': {'content': std_info_times},
+        }
+        # flags 0x01 is what the non-deleted, non-directory tests rely on
+        self.mock_parsed = {'record_n': 42, 'flags': 0x01, 'attributes': attributes}
+
+    def test_ntfs_file_creation(self):
+        """A plain record yields a regular, live, non-ghost file."""
+        entry = NTFSFile(self.mock_parsed, 12345)
+
+        self.assertEqual(entry.index, 42)
+        self.assertEqual(entry.name, 'test_file.txt')
+        self.assertEqual(entry.size, 1024)
+        self.assertFalse(entry.is_directory)
+        self.assertFalse(entry.is_deleted)
+        self.assertFalse(entry.is_ghost)
+        self.assertEqual(entry.parent, 5)
+        self.assertEqual(entry.ads, '')
+
+    def test_ntfs_file_with_ads(self):
+        """An ADS name is appended to both the index and the file name."""
+        stream = NTFSFile(self.mock_parsed, 12345, ads='stream1')
+
+        self.assertEqual(stream.index, '42:stream1')
+        self.assertEqual(stream.name, 'test_file.txt:stream1')
+        self.assertEqual(stream.ads, 'stream1')
+
+    def test_ntfs_file_directory(self):
+        """A record with flags 0x03 is reported as a directory."""
+        self.mock_parsed['flags'] = 0x03
+        entry = NTFSFile(self.mock_parsed, 12345)
+
+        self.assertTrue(entry.is_directory)
+
+    def test_ntfs_file_deleted(self):
+        """A record with flags 0x00 is reported as deleted."""
+        self.mock_parsed['flags'] = 0x00
+        entry = NTFSFile(self.mock_parsed, 12345)
+
+        self.assertTrue(entry.is_deleted)
+
+    def test_ntfs_file_ghost(self):
+        """The is_ghost constructor argument is reflected on the object."""
+        entry = NTFSFile(self.mock_parsed, 12345, is_ghost=True)
+
+        self.assertTrue(entry.is_ghost)
+
+    def test_ntfs_file_ignore(self):
+        """ignore() is true for the $Bad and $J stream cases below."""
+        # Record 8 with the $Bad stream, index pinned to '8:$Bad'
+        self.mock_parsed['record_n'] = 8
+        bad_stream = NTFSFile(self.mock_parsed, 12345, ads='$Bad')
+        bad_stream.index = '8:$Bad'
+        self.assertTrue(bad_stream.ignore())
+
+        # Record 100 with the $J stream and parent forced to 11 ($UsnJrnl case)
+        self.mock_parsed['record_n'] = 100
+        journal_stream = NTFSFile(self.mock_parsed, 12345, ads='$J')
+        journal_stream.parent = 11
+        self.assertTrue(journal_stream.ignore())
+
+
+class TestNTFSPartition(unittest.TestCase):
+    """Exercise NTFSPartition defaults and its extra repr fields."""
+
+    def setUp(self):
+        """Provide a scanner stand-in constrained to the NTFSScanner API."""
+        self.scanner = Mock(spec=NTFSScanner)
+
+    def test_ntfs_partition_creation(self):
+        """A new partition records the MFT position and the expected defaults."""
+        part = NTFSPartition(self.scanner, 12345)
+
+        self.assertEqual(part.fs_type, 'NTFS')
+        self.assertEqual(part.root_id, 5)
+        self.assertEqual(part.mft_pos, 12345)
+        self.assertIsNone(part.sec_per_clus)
+        self.assertIsNone(part.mftmirr_pos)
+
+    def test_ntfs_partition_additional_repr(self):
+        """additional_repr() lists sectors per cluster and both MFT offsets."""
+        part = NTFSPartition(self.scanner, 12345)
+        part.sec_per_clus = 8
+        part.mftmirr_pos = 67890
+
+        # Expected (label, value) pairs, in exactly this order
+        described = part.additional_repr()
+        self.assertEqual(described, [
+            ('Sec/Clus', 8),
+            ('MFT offset', 12345),
+            ('MFT mirror offset', 67890),
+        ])
+
+
+class TestNTFSScanner(unittest.TestCase):
+    """Exercise sector classification performed by NTFSScanner.feed()."""
+
+    def setUp(self):
+        """Create a scanner whose constructor argument is a plain Mock."""
+        self.scanner = NTFSScanner(Mock())
+
+    def test_feed_boot_sector(self):
+        """An 'NTFS'-prefixed sector with a 55 AA trailer is a boot sector."""
+        sector = b'NTFS'.ljust(510, b'\x00') + b'\x55\xAA'
+        label = self.scanner.feed(0, sector)
+
+        self.assertEqual(label, 'NTFS boot sector')
+        self.assertIn(0, self.scanner.found_boot)
+
+    def test_feed_file_record(self):
+        """A 'FILE'-prefixed sector is recorded as a file record."""
+        sector = b'FILE'.ljust(512, b'\x00')
+        label = self.scanner.feed(100, sector)
+
+        self.assertEqual(label, 'NTFS file record')
+        self.assertIn(100, self.scanner.found_file)
+
+    def test_feed_baad_record(self):
+        """A 'BAAD'-prefixed sector is also recorded as a file record."""
+        sector = b'BAAD'.ljust(512, b'\x00')
+        label = self.scanner.feed(200, sector)
+
+        self.assertEqual(label, 'NTFS file record')
+        self.assertIn(200, self.scanner.found_file)
+
+    def test_feed_indx_record(self):
+        """An 'INDX'-prefixed sector is recorded as an index record."""
+        sector = b'INDX'.ljust(512, b'\x00')
+        label = self.scanner.feed(300, sector)
+
+        self.assertEqual(label, 'NTFS index record')
+        self.assertIn(300, self.scanner.found_indx)
+
+    def test_feed_unknown_sector(self):
+        """An unrecognised signature is not recorded in any of the sets."""
+        sector = b'UNKN'.ljust(512, b'\x00')
+        label = self.scanner.feed(400, sector)
+
+        self.assertIsNone(label)
+        self.assertNotIn(400, self.scanner.found_boot)
+        self.assertNotIn(400, self.scanner.found_file)
+        self.assertNotIn(400, self.scanner.found_indx)
+
+    def test_most_likely_sec_per_clus(self):
+        """The most frequent sectors-per-cluster value is ranked first."""
+        self.scanner.found_spc = [8, 8, 8, 4, 4, 16]
+        ranking = self.scanner.most_likely_sec_per_clus()
+
+        # 8 occurs most often; 4 and 16 only need to be present somewhere
+        self.assertEqual(ranking[0], 8)
+        self.assertIn(4, ranking)
+        self.assertIn(16, ranking)
+
+class TestSparseList(unittest.TestCase):
+    """Exercise the SparseList container from recuperabit.logic."""
+
+    def test_sparse_list_creation(self):
+        """Length, keyed lookup and gap lookup on a populated list."""
+        sparse = SparseList({10: 'ten', 20: 'twenty', 30: 'thirty'})
+
+        # Expected length covers indices 0..30; index 15 was never set
+        self.assertEqual(len(sparse), 31)
+        self.assertEqual(sparse[10], 'ten')
+        self.assertEqual(sparse[20], 'twenty')
+        self.assertEqual(sparse[30], 'thirty')
+        self.assertIsNone(sparse[15])
+
+    def test_sparse_list_iteration(self):
+        """Iterating yields keys; itervalues() yields the stored values."""
+        contents = {1: 'one', 3: 'three', 5: 'five'}
+        sparse = SparseList(contents)
+
+        # Plain iteration is expected to visit only the populated keys,
+        # not every index between 0 and the highest key.
+        visited_keys = list(sparse)
+        self.assertEqual(visited_keys, [1, 3, 5])
+
+        # itervalues() is expected to produce the stored values,
+        # ordered like the keys above.
+        stored_values = list(sparse.itervalues())
+        self.assertEqual(stored_values, ['one', 'three', 'five'])
+
+
+if __name__ == '__main__':
+ unittest.main()  # load and run the TestCase classes defined in this module
diff --git a/tools/build_reference_ntfs.py b/tools/build_reference_ntfs.py
new file mode 100755
index 0000000..0537893
--- /dev/null
+++ b/tools/build_reference_ntfs.py
@@ -0,0 +1,289 @@
+#!/usr/bin/env python3
+"""Build reference NTFS filesystem image for E2E tests.
+
+This script creates a reference NTFS filesystem image by:
+1. Creating a loop-mounted NTFS filesystem
+2. Copying reference test files to it
+3. Unmounting and saving the image
+4. Computing checksums for both the image and source files
+5. Storing metadata for validation
+
+Usage:
+ python tools/build_reference_ntfs.py [--size SIZE_MB] [--output OUTPUT_PATH]
+"""
+
+import argparse
+import hashlib
+import json
+import logging
+import os
+import shutil
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Dict, List
+
+import gzip
+
+class NTFSImageBuilder:
+    """Builder for reference NTFS filesystem images and their JSON metadata."""
+
+    def __init__(self, size_mb: int = 100, output_path: str = None, compress: bool = True):
+        """Store build settings; ``output_path`` defaults to the tests data dir."""
+        self.size_mb = size_mb
+        self.output_path = output_path or "tests/data/reference_ntfs.img"
+        # Swap only the final extension: str.replace('.img', '.json') also hits
+        # directory names and leaves names without '.img' equal to the image path.
+        self.metadata_path = os.path.splitext(self.output_path)[0] + '.json'
+        self.reference_files_dir = Path("tests/data/reference_files")
+        self.compress = compress
+
+        # Set up logging
+        logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+        self.logger = logging.getLogger(__name__)
+
+    def _check_requirements(self) -> None:
+        """Verify that the external tools exist and that we run as root.
+
+        Raises:
+            RuntimeError: if a tool is missing from PATH or the euid is not 0.
+        """
+        required_tools = ['mkfs.ntfs', 'losetup', 'mount', 'umount', 'sync']
+        missing_tools = [tool for tool in required_tools if shutil.which(tool) is None]
+        if missing_tools:
+            raise RuntimeError(f"Missing required tools: {', '.join(missing_tools)}")
+
+        # Check if running as root (needed for loop devices)
+        if os.geteuid() != 0:
+            raise RuntimeError("This script must be run as root to create loop devices")
+
+    def _compute_file_hash(self, filepath: Path) -> str:
+        """Return the SHA256 hex digest of ``filepath``, read in 4 KiB chunks."""
+        sha256_hash = hashlib.sha256()
+        with open(filepath, "rb") as f:
+            for chunk in iter(lambda: f.read(4096), b""):
+                sha256_hash.update(chunk)
+        return sha256_hash.hexdigest()
+
+    def _compute_directory_hash(self, directory: Path) -> Dict[str, str]:
+        """Map every file under ``directory`` (relative path) to its SHA256."""
+        file_hashes = {}
+
+        for filepath in directory.rglob('*'):
+            if filepath.is_file():
+                relative_path = str(filepath.relative_to(directory))
+                file_hashes[relative_path] = self._compute_file_hash(filepath)
+                self.logger.info(f"Hashed {relative_path}: {file_hashes[relative_path][:16]}...")
+
+        return file_hashes
+
+    def _create_empty_image(self, image_path: Path) -> None:
+        """Create a ``size_mb`` MiB file by seeking and writing only its last byte."""
+        self.logger.info(f"Creating {self.size_mb}MB empty image at {image_path}")
+
+        with open(image_path, 'wb') as f:
+            f.seek(self.size_mb * 1024 * 1024 - 1)
+            f.write(b'\0')
+
+    def _format_ntfs(self, image_path: Path) -> None:
+        """Run ``mkfs.ntfs -F -f`` on the image; raise RuntimeError on failure."""
+        self.logger.info("Formatting image as NTFS...")
+
+        cmd = ['mkfs.ntfs', '-F', '-f', str(image_path)]
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to format NTFS: {result.stderr}")
+
+    def _setup_loop_device(self, image_path: Path) -> str:
+        """Attach the image via ``losetup --find --show`` and return the device."""
+        self.logger.info("Setting up loop device...")
+
+        cmd = ['losetup', '--find', '--show', str(image_path)]
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to set up loop device: {result.stderr}")
+
+        loop_device = result.stdout.strip()
+        self.logger.info(f"Created loop device: {loop_device}")
+        return loop_device
+
+    def _cleanup_loop_device(self, loop_device: str) -> None:
+        """Detach the loop device; a failure is logged instead of raised."""
+        self.logger.info(f"Cleaning up loop device: {loop_device}")
+
+        cmd = ['losetup', '-d', loop_device]
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        # Called from a finally block: warn rather than mask the original error
+        if result.returncode != 0:
+            self.logger.warning(f"Failed to detach {loop_device}: {result.stderr}")
+
+    def _mount_filesystem(self, loop_device: str, mount_point: Path) -> None:
+        """Mount the device with ntfs-3g (``-o sync``); raise RuntimeError on failure."""
+        self.logger.info(f"Mounting {loop_device} at {mount_point}")
+
+        cmd = ['mount', '-t', 'ntfs-3g', '-o', 'sync', loop_device, str(mount_point)]
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to mount filesystem: {result.stderr}")
+
+    def _unmount_filesystem(self, mount_point: Path) -> None:
+        """Unmount ``mount_point``; a failure is logged as a warning, not raised."""
+        self.logger.info(f"Unmounting {mount_point}")
+
+        cmd = ['umount', str(mount_point)]
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            self.logger.warning(f"Failed to unmount cleanly: {result.stderr}")
+
+    def _copy_files(self, mount_point: Path) -> None:
+        """Copy the reference tree into the mount, then try to add one ADS."""
+        self.logger.info("Copying reference files to mounted filesystem...")
+
+        if not self.reference_files_dir.exists():
+            raise RuntimeError(f"Reference files directory not found: {self.reference_files_dir}")
+
+        # Top-level entries only; copytree handles the nested directories
+        for item in self.reference_files_dir.iterdir():
+            dest = mount_point / item.name
+
+            if item.is_file():
+                shutil.copy2(item, dest)
+                self.logger.info(f"Copied file: {item.name}")
+            elif item.is_dir():
+                shutil.copytree(item, dest)
+                self.logger.info(f"Copied directory: {item.name}")
+
+        # Best effort: every ADS problem below is logged and never raised
+        try:
+            ads_file = mount_point / "file_with_ads.txt"
+            if ads_file.exists():
+                # Try to create ADS using attr command if available
+                if shutil.which('attr'):
+                    cmd = ['attr', '-s', 'stream1', '-V', 'ADS content', str(ads_file)]
+                    result = subprocess.run(cmd, capture_output=True, text=True)
+                    if result.returncode == 0:
+                        self.logger.info("Created alternate data stream")
+                    else:
+                        self.logger.warning("Failed to create ADS, not supported")
+                else:
+                    self.logger.warning("attr tool not available, skipping ADS creation")
+        except Exception as e:
+            self.logger.warning(f"Could not create alternate data stream: {e}")
+
+    def _save_metadata(self, image_path: Path, file_hashes: Dict[str, str]) -> None:
+        """Write image size/hash and the source-file hashes to ``metadata_path``."""
+        self.logger.info("Computing image hash and saving metadata...")
+
+        # Hash and size describe the image file as it exists at this point
+        image_hash = self._compute_file_hash(image_path)
+        metadata = {
+            "version": "1.0",
+            "created_by": "build_reference_ntfs.py",
+            "size_mb": self.size_mb,
+            "image_size_bytes": image_path.stat().st_size,
+            "image_hash": image_hash,
+            "reference_files_hashes": file_hashes,
+            "file_count": len(file_hashes),
+            "notes": "Reference NTFS filesystem for RecuperaBit E2E tests"
+        }
+
+        with open(self.metadata_path, 'w') as f:
+            json.dump(metadata, f, indent=2, sort_keys=True)
+
+        self.logger.info(f"Saved metadata to {self.metadata_path}")
+        self.logger.info(f"Image hash: {image_hash}")
+
+    def build(self) -> None:
+        """Create, format, populate and optionally gzip the reference image.
+
+        Raises RuntimeError when a prerequisite or build step fails; a failing
+        ``sync`` call raises subprocess.CalledProcessError instead.
+        """
+        self.logger.info("Starting NTFS reference image build...")
+        self._check_requirements()
+
+        # Fail before touching the disk if there is nothing to copy
+        if not self.reference_files_dir.is_dir():
+            raise RuntimeError(f"Reference files directory not found: {self.reference_files_dir}")
+
+        image_path = Path(self.output_path)
+        image_path.parent.mkdir(parents=True, exist_ok=True)
+
+        self.logger.info("Computing hashes of reference files...")
+        file_hashes = self._compute_directory_hash(self.reference_files_dir)
+
+        loop_device = None
+        temp_mount = None
+
+        try:
+            self._create_empty_image(image_path)
+            self._format_ntfs(image_path)
+            loop_device = self._setup_loop_device(image_path)
+            temp_mount = Path(tempfile.mkdtemp(prefix="ntfs_build_"))
+
+            self._mount_filesystem(loop_device, temp_mount)
+            self._copy_files(temp_mount)
+
+            # Sync to ensure all data is written
+            subprocess.run(['sync', str(temp_mount)], check=True)
+            self._unmount_filesystem(temp_mount)
+
+            self._save_metadata(image_path, file_hashes)
+            self.logger.info(f"Successfully created reference NTFS image: {image_path}")
+            self.logger.info(f"Image size: {image_path.stat().st_size / (1024*1024):.1f} MB")
+
+            if self.compress:
+                self.logger.info("Compressing image with gzip...")
+                with open(image_path, 'rb') as f_in, gzip.open(f"{image_path}.gz", 'wb') as f_out:
+                    shutil.copyfileobj(f_in, f_out)
+                image_path.unlink()  # Remove uncompressed image
+        finally:
+            # Unmount first: an earlier failure may have left the filesystem
+            # mounted, and rmtree on a live mount would delete the image's files.
+            if temp_mount and temp_mount.is_mount():
+                self._unmount_filesystem(temp_mount)
+            if loop_device:
+                self._cleanup_loop_device(loop_device)
+            if temp_mount and temp_mount.exists():
+                if temp_mount.is_mount():
+                    self.logger.warning(f"{temp_mount} is still mounted; leaving it in place")
+                else:
+                    shutil.rmtree(temp_mount)
+
+
+def main():
+    """Parse the CLI options, build the image and return a process exit code."""
+    parser = argparse.ArgumentParser(description="Build reference NTFS image for E2E tests")
+    parser.add_argument('--size', type=int, default=100,
+                        help='Image size in MB (default: 100)')
+    parser.add_argument('--output', type=str,
+                        default='tests/data/reference_ntfs.img',
+                        help='Output image path (default: tests/data/reference_ntfs.img)')
+    args = parser.parse_args()
+
+    builder = NTFSImageBuilder(size_mb=args.size, output_path=args.output)
+    try:
+        builder.build()
+    except Exception as e:
+        print(f"\n✗ Error: {e}")
+        return 1
+
+    # build() gzips the image and removes the .img when compress is set, so
+    # report the file that actually exists and the builder's own metadata path.
+    suffix = '.gz' if builder.compress else ''
+    print(f"\n✓ Success! Reference NTFS image created at: {args.output}{suffix}")
+    print(f"✓ Metadata saved at: {builder.metadata_path}")
+    print("\nNext steps:")
+    print(f"1. Add the .img{suffix} file to Git LFS: git lfs track '*.img{suffix}'")
+    print("2. Commit both the image and metadata files")
+    print("3. The E2E tests will now use this reference image")
+    return 0
+
+
+if __name__ == '__main__':
+    raise SystemExit(main())  # exit() is a site-module helper for interactive use