diff --git a/dagshub/auth/token_auth.py b/dagshub/auth/token_auth.py index 31ec32ac..7ba3a70a 100644 --- a/dagshub/auth/token_auth.py +++ b/dagshub/auth/token_auth.py @@ -37,7 +37,7 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: def can_renegotiate(self): # Env var tokens cannot renegotiate, every other token type can - return not type(self._token) is EnvVarDagshubToken + return type(self._token) is not EnvVarDagshubToken def renegotiate_token(self): if not self._token_storage.is_valid_token(self._token, self._host): diff --git a/dagshub/data_engine/annotation/importer.py b/dagshub/data_engine/annotation/importer.py index c19212de..bf6009d6 100644 --- a/dagshub/data_engine/annotation/importer.py +++ b/dagshub/data_engine/annotation/importer.py @@ -1,22 +1,34 @@ from difflib import SequenceMatcher from pathlib import Path, PurePosixPath, PurePath from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable, List - -from dagshub_annotation_converter.converters.cvat import load_cvat_from_zip +from typing import TYPE_CHECKING, Dict, Iterable, Literal, Optional, Union, Sequence, Mapping, Callable, List, Tuple + +from dagshub_annotation_converter.converters.coco import load_coco_from_file +from dagshub_annotation_converter.converters.cvat import ( + CVATAnnotations, + load_cvat_from_fs, + load_cvat_from_zip, + load_cvat_from_xml_file, +) +from dagshub_annotation_converter.converters.mot import load_mot_from_dir, load_mot_from_fs, load_mot_from_zip +from dagshub_annotation_converter.formats.mot.context import MOTContext from dagshub_annotation_converter.converters.yolo import load_yolo_from_fs +from dagshub_annotation_converter.converters.label_studio_video import video_ir_to_ls_video_task from dagshub_annotation_converter.formats.label_studio.task import LabelStudioTask from dagshub_annotation_converter.formats.yolo import YoloContext +from 
dagshub_annotation_converter.ir.base import IRTaskAnnotation from dagshub_annotation_converter.ir.image.annotations.base import IRAnnotationBase +from dagshub_annotation_converter.ir.video import IRVideoAnnotationTrack, IRVideoBBoxFrameAnnotation, IRVideoSequence from dagshub.common.api import UserAPI from dagshub.common.api.repo import PathNotFoundError from dagshub.common.helpers import log_message +from dagshub.data_engine.annotation.video import build_video_sequence_from_annotations if TYPE_CHECKING: from dagshub.data_engine.model.datasource import Datasource -AnnotationType = Literal["yolo", "cvat"] +AnnotationType = Literal["yolo", "cvat", "coco", "mot", "cvat_video"] AnnotationLocation = Literal["repo", "disk"] @@ -57,7 +69,11 @@ def __init__( 'Add `yolo_type="bbox"|"segmentation"|pose"` to the arguments.' ) - def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]: + @property + def is_video_format(self) -> bool: + return self.annotations_type in ("mot", "cvat_video") + + def import_annotations(self) -> Mapping[str, Sequence[IRTaskAnnotation]]: # Double check that the annotation file exists if self.load_from == "disk": if not self.annotations_file.exists(): @@ -78,20 +94,171 @@ def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]: # Convert annotations log_message("Loading annotations...") - annotation_dict: Mapping[str, Sequence[IRAnnotationBase]] + annotation_dict: Mapping[str, Sequence[IRTaskAnnotation]] if self.annotations_type == "yolo": annotation_dict, _ = load_yolo_from_fs( annotation_type=self.additional_args["yolo_type"], meta_file=annotations_file ) elif self.annotations_type == "cvat": - annotation_dict = load_cvat_from_zip(annotations_file) + if annotations_file.is_dir(): + annotation_dict = self._key_cvat_fs_annotations_by_filename(load_cvat_from_fs(annotations_file)) + else: + result = load_cvat_from_zip(annotations_file) + if self._determine_cvat_annotation(result) == "video": + annotation_dict = 
self._key_video_annotations_by_filename(result) + else: + annotation_dict = result + elif self.annotations_type == "coco": + annotation_dict, _ = load_coco_from_file(annotations_file) + elif self.annotations_type == "mot": + mot_kwargs = {} + if "image_width" in self.additional_args: + mot_kwargs["image_width"] = self.additional_args["image_width"] + if "image_height" in self.additional_args: + mot_kwargs["image_height"] = self.additional_args["image_height"] + if "video_name" in self.additional_args: + mot_kwargs["video_file"] = self.additional_args["video_name"] + if annotations_file.is_dir(): + # Detect whether this is an fs layout (videos/ + labels/) or a single MOT dir + video_dir_name = self.additional_args.get("video_dir_name", "videos") + label_dir_name = self.additional_args.get("label_dir_name", "labels") + if (annotations_file / label_dir_name).is_dir(): + mot_results = load_mot_from_fs( + annotations_file, + image_width=mot_kwargs.get("image_width"), + image_height=mot_kwargs.get("image_height"), + video_dir_name=video_dir_name, + label_dir_name=label_dir_name, + ) + annotation_dict = self._key_mot_fs_annotations_by_filename(mot_results) + else: + video_anns, _ = load_mot_from_dir(annotations_file, **mot_kwargs) + annotation_dict = self._key_video_annotations_by_filename(video_anns) + elif annotations_file.suffix == ".zip": + video_anns, _ = load_mot_from_zip(annotations_file, **mot_kwargs) + annotation_dict = self._key_video_annotations_by_filename(video_anns) + else: + video_anns, _ = load_mot_from_dir(annotations_file, **mot_kwargs) + annotation_dict = self._key_video_annotations_by_filename(video_anns) + elif self.annotations_type == "cvat_video": + cvat_kwargs = {} + if "image_width" in self.additional_args: + cvat_kwargs["image_width"] = self.additional_args["image_width"] + if "image_height" in self.additional_args: + cvat_kwargs["image_height"] = self.additional_args["image_height"] + if annotations_file.is_dir(): + raw = 
load_cvat_from_fs(annotations_file, **cvat_kwargs) + annotation_dict = self._key_cvat_fs_annotations_by_filename(raw) + elif annotations_file.suffix == ".zip": + result = load_cvat_from_zip(annotations_file, **cvat_kwargs) + if self._determine_cvat_annotation(result) == "video": + annotation_dict = self._key_video_annotations_by_filename(result) + else: + annotation_dict = result + else: + result = load_cvat_from_xml_file(annotations_file, **cvat_kwargs) + if self._determine_cvat_annotation(result) == "video": + annotation_dict = self._key_video_annotations_by_filename(result) + else: + annotation_dict = result + else: + raise ValueError(f"Unsupported annotation type: {self.annotations_type}") return annotation_dict + @staticmethod + def _determine_cvat_annotation(result: CVATAnnotations) -> Literal["video", "image"]: + """Determine whether a CVAT loader result contains video or image annotations.""" + if isinstance(result, IRVideoSequence): + return "video" + return "image" + + def _key_video_annotations_by_filename( + self, + video_data: CVATAnnotations, + ) -> Dict[str, Sequence[IRTaskAnnotation]]: + """Flatten video annotations into a single entry keyed by the source video path.""" + video_name = self.additional_args.get("video_name") + if isinstance(video_data, IRVideoSequence): + sequence_name = self._resolve_video_annotation_key(video_data.filename, fallback=video_name) + return {sequence_name: video_data.to_annotations()} + + if video_name is None: + video_name = self._first_video_annotation_filename(video_data.values()) + if video_name is None: + video_name = self.annotations_file.stem + + all_anns: List[IRTaskAnnotation] = [] + for frame_anns in video_data.values(): + all_anns.extend(frame_anns) + return {video_name: all_anns} + + def _key_cvat_fs_annotations_by_filename( + self, fs_annotations: Mapping[str, CVATAnnotations] + ) -> Dict[str, Sequence[IRTaskAnnotation]]: + flattened: Dict[str, List[IRTaskAnnotation]] = {} + for rel_path, result in 
fs_annotations.items(): + if isinstance(result, IRVideoSequence): + video_key = self._resolve_video_annotation_key(result.filename, fallback=str(rel_path)) + flattened.setdefault(video_key, []) + flattened[video_key].extend(result.to_annotations()) + elif isinstance(result, dict): # CVATImageAnnotations: Dict[str, Sequence[IRImageAnnotationBase]] + for filename, anns in result.items(): + flattened.setdefault(filename, []) + flattened[filename].extend(anns) + return flattened + + def _key_mot_fs_annotations_by_filename( + self, + fs_annotations: Mapping[Path, Tuple[IRVideoSequence, MOTContext]], + ) -> Dict[str, Sequence[IRTaskAnnotation]]: + flattened: Dict[str, List[IRTaskAnnotation]] = {} + for rel_path, (sequence, _) in fs_annotations.items(): + rel_path_str = self._stringify_video_path(rel_path) + sequence_name = self.annotations_file.stem if rel_path_str in (None, "", ".") else rel_path_str + sequence_name = self._resolve_video_annotation_key(sequence.filename, fallback=sequence_name) + flattened.setdefault(sequence_name, []) + flattened[sequence_name].extend(sequence.to_annotations()) + return flattened + + @staticmethod + def _stringify_video_path(path: Optional[Union[str, Path, PurePath]]) -> Optional[str]: + if path is None: + return None + if isinstance(path, (Path, PurePath)): + return path.as_posix() + return str(path).replace("\\", "/") + + def _resolve_video_annotation_key( + self, + filename: Optional[Union[str, Path, PurePath]], + fallback: Optional[str] = None, + ) -> str: + resolved = self._stringify_video_path(filename) + if resolved not in (None, "", "."): + return resolved + + resolved_fallback = self._stringify_video_path(fallback) + if resolved_fallback not in (None, "", "."): + return resolved_fallback + return self.annotations_file.stem + + @classmethod + def _first_video_annotation_filename( + cls, + frame_groups: Iterable[Sequence[IRAnnotationBase]], + ) -> Optional[str]: + for frame_anns in frame_groups: + for ann in frame_anns: + 
ann_filename = cls._stringify_video_path(ann.filename) + if ann_filename not in (None, "", "."): + return ann_filename + return None + def download_annotations(self, dest_dir: Path): log_message("Downloading annotations from repository") repoApi = self.ds.source.repoApi - if self.annotations_type == "cvat": + if self.annotations_type in ("cvat", "cvat_video"): # Download just the annotation file repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) elif self.annotations_type == "yolo": @@ -104,6 +271,11 @@ def download_annotations(self, dest_dir: Path): # Download the annotation data assert context.path is not None repoApi.download(self.annotations_file.parent / context.path, dest_dir, keep_source_prefix=True) + elif self.annotations_type == "mot": + repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) + elif self.annotations_type == "coco": + # Download just the annotation file + repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) @staticmethod def determine_load_location(ds: "Datasource", annotations_path: Union[str, Path]) -> AnnotationLocation: @@ -125,9 +297,9 @@ def determine_load_location(ds: "Datasource", annotations_path: Union[str, Path] def remap_annotations( self, - annotations: Mapping[str, Sequence[IRAnnotationBase]], + annotations: Mapping[str, Sequence[IRTaskAnnotation]], remap_func: Optional[Callable[[str], Optional[str]]] = None, - ) -> Mapping[str, Sequence[IRAnnotationBase]]: + ) -> Mapping[str, Sequence[IRTaskAnnotation]]: """ Remaps the filenames in the annotations to the datasource's data points. @@ -136,6 +308,9 @@ def remap_annotations( remap_func: Function that maps from an annotation path to a datapoint path. 
\ If None, we try to guess it by getting a datapoint and remapping that path """ + if not annotations: + return {} + if remap_func is None: first_ann = list(annotations.keys())[0] first_ann_filename = Path(first_ann).name @@ -153,8 +328,17 @@ def remap_annotations( ) continue for ann in anns: - assert ann.filename is not None - ann.filename = remap_func(ann.filename) + if isinstance(ann, IRVideoAnnotationTrack): + for track_ann in ann.annotations: + track_ann.filename = new_filename + continue + + if ann.filename is not None: + ann.filename = remap_func(ann.filename) + else: + if not self.is_video_format: + raise ValueError(f"Non-video annotation has no filename: {ann}") + ann.filename = new_filename remapped[new_filename] = anns return remapped @@ -284,10 +468,12 @@ def get_best_fit_datapoint_path(ann_path: str, datapoint_paths: List[str]) -> st raise ValueError(f"No good match found for annotation path {ann_path} in the datasource.") return best_match - def convert_to_ls_tasks(self, annotations: Mapping[str, Sequence[IRAnnotationBase]]) -> Mapping[str, bytes]: + def convert_to_ls_tasks(self, annotations: Mapping[str, Sequence[IRTaskAnnotation]]) -> Mapping[str, bytes]: """ Converts the annotations to Label Studio tasks. """ + if self.is_video_format: + return self._convert_to_ls_video_tasks(annotations) current_user_id = UserAPI.get_current_user(self.ds.source.repoApi.host).user_id tasks = {} for filename, anns in annotations.items(): @@ -296,3 +482,38 @@ def convert_to_ls_tasks(self, annotations: Mapping[str, Sequence[IRAnnotationBas t.add_ir_annotations(anns) tasks[filename] = t.model_dump_json().encode("utf-8") return tasks + + def _convert_to_ls_video_tasks( + self, annotations: Mapping[str, Sequence[IRTaskAnnotation]] + ) -> Mapping[str, bytes]: + """ + Converts video annotations to Label Studio video tasks. 
+ """ + tasks = {} + for filename, anns in annotations.items(): + sequence = self._build_video_sequence(anns, filename) + if sequence is None: + continue + video_path = self.ds.source.raw_path(filename) + ls_task = video_ir_to_ls_video_task(sequence, video_path=video_path) + if ls_task is not None: + tasks[filename] = ls_task.model_dump_json().encode("utf-8") + return tasks + + @staticmethod + def _build_video_sequence( + annotations: Sequence[IRTaskAnnotation], + filename: str, + ) -> Optional[IRVideoSequence]: + tracks = [ann for ann in annotations if isinstance(ann, IRVideoAnnotationTrack)] + frame_annotations = [ann for ann in annotations if isinstance(ann, IRVideoBBoxFrameAnnotation)] + if frame_annotations: + tracks.extend(build_video_sequence_from_annotations(frame_annotations).tracks) + if not tracks: + return None + + sequence = IRVideoSequence.from_annotations(tracks=tracks, filename=filename) + sequence.resolved_video_width() + sequence.resolved_video_height() + sequence.resolved_sequence_length() + return sequence diff --git a/dagshub/data_engine/annotation/metadata.py b/dagshub/data_engine/annotation/metadata.py index 06f7bc28..a3e9638c 100644 --- a/dagshub/data_engine/annotation/metadata.py +++ b/dagshub/data_engine/annotation/metadata.py @@ -3,6 +3,7 @@ from dagshub_annotation_converter.formats.label_studio.task import LabelStudioTask, parse_ls_task from dagshub_annotation_converter.formats.yolo import YoloContext, import_lookup, import_yolo_result from dagshub_annotation_converter.formats.yolo.categories import Categories +from dagshub_annotation_converter.ir.base import IRAnnotationBase, IRTaskAnnotation from dagshub_annotation_converter.ir.image import ( CoordinateStyle, IRBBoxImageAnnotation, @@ -11,7 +12,8 @@ IRSegmentationImageAnnotation, IRSegmentationPoint, ) -from dagshub_annotation_converter.ir.image.annotations.base import IRAnnotationBase, IRImageAnnotationBase +from dagshub_annotation_converter.ir.image.annotations.base import 
IRImageAnnotationBase +from dagshub_annotation_converter.ir.video import IRVideoAnnotationTrack from dagshub.common.api import UserAPI from dagshub.common.helpers import log_message @@ -22,6 +24,8 @@ from dagshub.data_engine.model.datapoint import Datapoint +from dagshub_annotation_converter.formats.label_studio.videorectangle import VideoRectangleAnnotation + class AnnotationMetaDict(dict): def __init__(self, annotation: "MetadataAnnotations", *args, **kwargs): @@ -63,13 +67,13 @@ def __init__( self, datapoint: "Datapoint", field: str, - annotations: Optional[Sequence["IRAnnotationBase"]] = None, + annotations: Optional[Sequence["IRTaskAnnotation"]] = None, meta: Optional[Dict] = None, original_value: Optional[bytes] = None, ): self.datapoint = datapoint self.field = field - self.annotations: list["IRAnnotationBase"] + self.annotations: list["IRTaskAnnotation"] if annotations is None: annotations = [] self.annotations = list(annotations) @@ -94,12 +98,34 @@ def to_ls_task(self) -> Optional[bytes]: task = LabelStudioTask( user_id=UserAPI.get_current_user(self.datapoint.datasource.source.repoApi.host).user_id, ) - task.data["image"] = self.datapoint.download_url - # TODO: need to filter out non-image annotations here maybe? 
- task.add_ir_annotations(self.annotations) + if any(isinstance(ann, IRVideoAnnotationTrack) for ann in self.annotations): + task.data["video"] = self.datapoint.download_url + frames_count = self._get_video_frames_count() + for ann in self.annotations: + if isinstance(ann, IRVideoAnnotationTrack): + ls_ann = VideoRectangleAnnotation.from_ir_track(ann, frames_count=frames_count) + if ann.__pydantic_extra__ is not None: + ls_ann.__pydantic_extra__ = ann.__pydantic_extra__.copy() + task.add_annotation(ls_ann) + else: + task.add_ir_annotation(ann) + else: + task.data["image"] = self.datapoint.download_url + task.add_ir_annotations(self.annotations) task.meta.update(self.meta) return task.model_dump_json().encode("utf-8") + def _get_video_frames_count(self) -> Optional[int]: + max_frame: Optional[int] = None + for ann in self.annotations: + if not isinstance(ann, IRVideoAnnotationTrack): + continue + for track_ann in ann.annotations: + max_frame = track_ann.frame_number if max_frame is None else max(max_frame, track_ann.frame_number) + if max_frame is None: + return None + return max_frame + 1 + @property def value(self) -> Optional[bytes]: """ diff --git a/dagshub/data_engine/annotation/video.py b/dagshub/data_engine/annotation/video.py new file mode 100644 index 00000000..72587b6d --- /dev/null +++ b/dagshub/data_engine/annotation/video.py @@ -0,0 +1,41 @@ +from collections import defaultdict +from typing import Optional, Sequence + +from dagshub_annotation_converter.ir.video import ( + IRVideoFrameAnnotationBase, + IRVideoAnnotationTrack, + IRVideoSequence, +) + + +def build_video_sequence_from_annotations( + annotations: Sequence[IRVideoFrameAnnotationBase], + filename: Optional[str] = None, +) -> IRVideoSequence: + # Pre-group annotations into tracks (required by new from_annotations API) + by_track: dict[str, list[IRVideoFrameAnnotationBase]] = defaultdict(list) + for ann in annotations: + object_id = ann.imported_id + if object_id is None: + raise 
ValueError("Video annotation is missing an object identifier") + by_track[object_id].append(ann) + + tracks = [ + IRVideoAnnotationTrack.from_annotations(anns, object_id=str(tid)) + for tid, anns in by_track.items() + ] + + sequence = IRVideoSequence.from_annotations(tracks=tracks, filename=filename) + + if filename is not None: + for track in sequence.tracks: + for ann in track.annotations: + if ann.filename is None: + ann.filename = filename + + # resolved_* methods now cache results automatically + sequence.resolved_video_width() + sequence.resolved_video_height() + sequence.resolved_sequence_length() + + return sequence diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 94f78522..bbeab214 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -1668,6 +1668,16 @@ def import_annotations_from_files( Keyword Args: yolo_type: Type of YOLO annotations to import. Either ``bbox``, ``segmentation`` or ``pose``. + image_width: (MOT, CVAT video) Width of the video frames in pixels. \ + Used when the annotation file does not contain dimension metadata. + image_height: (MOT, CVAT video) Height of the video frames in pixels. \ + Used when the annotation file does not contain dimension metadata. + video_name: (MOT) Name/path of the video file these annotations belong to. \ + Used to key the resulting annotations when it cannot be inferred from the annotation file. + video_dir_name: (MOT filesystem layout) Name of the subdirectory containing video files. \ + Defaults to ``"videos"``. + label_dir_name: (MOT filesystem layout) Name of the subdirectory containing label files. \ + Defaults to ``"labels"``. 
Example to import segmentation annotations into an ``imported_annotations`` field, using YOLO information from an ``annotations.yaml`` file (can be local, or in the repo):: diff --git a/dagshub/data_engine/model/query_result.py b/dagshub/data_engine/model/query_result.py index b986b5c3..8a5f34ed 100644 --- a/dagshub/data_engine/model/query_result.py +++ b/dagshub/data_engine/model/query_result.py @@ -15,10 +15,17 @@ import dacite import dagshub_annotation_converter.converters.yolo import rich.progress +from dagshub_annotation_converter.converters.cvat import export_cvat_video_to_zip, export_cvat_videos_to_zips +from dagshub_annotation_converter.converters.mot import export_mot_sequences_to_dirs, export_mot_to_dir +from dagshub_annotation_converter.formats.mot import MOTContext +from dagshub_annotation_converter.converters.coco import export_to_coco_file +from dagshub_annotation_converter.formats.coco import CocoContext from dagshub_annotation_converter.formats.yolo import YoloContext from dagshub_annotation_converter.formats.yolo.categories import Categories from dagshub_annotation_converter.formats.yolo.common import ir_mapping +from dagshub_annotation_converter.ir.base import IRTaskAnnotation from dagshub_annotation_converter.ir.image import IRImageAnnotationBase +from dagshub_annotation_converter.ir.video import IRVideoAnnotationTrack, IRVideoFrameAnnotationBase, IRVideoSequence from pydantic import ValidationError from dagshub.auth import get_token @@ -31,6 +38,7 @@ from dagshub.common.util import lazy_load, multi_urljoin from dagshub.data_engine.annotation import MetadataAnnotations from dagshub.data_engine.annotation.metadata import ErrorMetadataAnnotations, UnsupportedMetadataAnnotations +from dagshub.data_engine.annotation.video import build_video_sequence_from_annotations from dagshub.data_engine.annotation.voxel_conversion import ( add_ls_annotations, add_voxel_annotations, @@ -768,8 +776,8 @@ def dp_url(dp: Datapoint): download_files(download_args, 
skip_if_exists=not redownload) return target_path - def _get_all_annotations(self, annotation_field: str) -> List[IRImageAnnotationBase]: - annotations = [] + def _get_all_annotations(self, annotation_field: str) -> List[IRTaskAnnotation]: + annotations: List[IRTaskAnnotation] = [] for dp in self.entries: if annotation_field in dp.metadata: if not hasattr(dp.metadata[annotation_field], "annotations"): @@ -778,6 +786,87 @@ def _get_all_annotations(self, annotation_field: str) -> List[IRImageAnnotationB annotations.extend(dp.metadata[annotation_field].annotations) return annotations + def _get_all_image_annotations(self, annotation_field: str) -> List[IRImageAnnotationBase]: + return [ann for ann in self._get_all_annotations(annotation_field) if isinstance(ann, IRImageAnnotationBase)] + + def _get_all_video_annotations(self, annotation_field: str) -> List[IRVideoFrameAnnotationBase]: + video_annotations: List[IRVideoFrameAnnotationBase] = [] + for ann in self._get_all_annotations(annotation_field): + if isinstance(ann, IRVideoFrameAnnotationBase): + video_annotations.append(ann) + elif isinstance(ann, IRVideoAnnotationTrack): + video_annotations.extend(ann.to_annotations()) + return video_annotations + + @staticmethod + def _annotations_to_sequences( + video_annotations: List[IRVideoFrameAnnotationBase], + ) -> List["IRVideoSequence"]: + """Group frame annotations into per-source video sequences.""" + by_source: Dict[str, List[IRVideoFrameAnnotationBase]] = {} + for ann in video_annotations: + filename = ann.filename or "" + by_source.setdefault(filename, []).append(ann) + + return [ + build_video_sequence_from_annotations(anns, filename=source_filename or None) + for source_filename, anns in by_source.items() + ] + + def _resolve_local_path(self, local_root: Path, repo_relative_filename: str) -> Optional[Path]: + """ + Resolves the local path of a downloaded file given its repo-relative filename. 
+ + Tries the path directly under ``local_root`` first, then falls back to prepending + the datasource's source prefix (e.g. when files were downloaded with the prefix intact). + Returns ``None`` if the file is not found at either location. + """ + ann_path = Path(repo_relative_filename) + primary = local_root / ann_path + if primary.exists(): + return primary + source_prefix = Path(self.datasource.source.source_prefix) + with_prefix = local_root / source_prefix / ann_path + if with_prefix.exists(): + return with_prefix + return None + + def _resolve_export_dirs(self, download_dir: Path, media_dir_name: str) -> Tuple[Path, Path, Path]: + """ + Resolves the three directory paths for an export given a download root and a media subdirectory name. + + Strips any leading ``data/`` segment from the datasource's source prefix to avoid duplication, + then nests the media directory under ``<download_dir>/data/<source_prefix>/<media_dir_name>`` + (skipping the final segment if the prefix already ends with ``media_dir_name``). + + Returns ``(media_dir, labels_dir, dataset_root)`` where ``labels_dir`` is a sibling of ``media_dir``. 
+ """ + data_root = download_dir / "data" + source_prefix = self.datasource.source.source_prefix + prefix_parts = source_prefix.parts + if prefix_parts and prefix_parts[0] == "data": + prefix_parts = prefix_parts[1:] + + media_dir = data_root + if prefix_parts: + media_dir = media_dir.joinpath(*prefix_parts) + if not prefix_parts or prefix_parts[-1] != media_dir_name: + media_dir = media_dir / media_dir_name + + dataset_root = media_dir.parent + labels_dir = dataset_root / "labels" + return media_dir, labels_dir, dataset_root + + def _resolve_annotation_field(self, annotation_field: Optional[str]) -> str: + if annotation_field is not None: + return annotation_field + annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()]) + if len(annotation_fields) == 0: + raise ValueError("No annotation fields found in the datasource") + annotation_field = annotation_fields[0] + log_message(f"Using annotations from field {annotation_field}") + return annotation_field + def export_as_yolo( self, download_dir: Optional[Union[str, Path]] = None, @@ -803,18 +892,13 @@ def export_as_yolo( Returns: The path to the YAML file with the metadata. Pass this path to ``YOLO.train()`` to train a model. """ - if annotation_field is None: - annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()]) - if len(annotation_fields) == 0: - raise ValueError("No annotation fields found in the datasource") - annotation_field = annotation_fields[0] - log_message(f"Using annotations from field {annotation_field}") + annotation_field = self._resolve_annotation_field(annotation_field) if download_dir is None: download_dir = Path("dagshub_export") download_dir = Path(download_dir) / "data" - annotations = self._get_all_annotations(annotation_field) + annotations = self._get_all_image_annotations(annotation_field) categories = Categories() if classes is not None: @@ -861,6 +945,236 @@ def export_as_yolo( log_message(f"Done! 
Saved YOLO Dataset, YAML file is at {yaml_path.absolute()}") return yaml_path + def export_as_mot( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + image_width: Optional[int] = None, + image_height: Optional[int] = None, + ) -> Path: + """ + Exports video annotations in MOT (Multiple Object Tracking) format. + + Single-video exports write a MOT sequence directory under ``output_dir/labels/``. + Multi-video exports write a dataset root compatible with + ``load_mot_from_fs()``:: + + output_dir/ + videos/ + labels/ + + Args: + download_dir: Where to export. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + image_width: Frame width. If None, inferred from annotations. + image_height: Frame height. If None, inferred from annotations. + + Returns: + Path to the exported MOT directory. + """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + video_dir, labels_dir, dataset_root = self._resolve_export_dirs(download_dir, "videos") + + video_annotations = self._get_all_video_annotations(annotation_field) + if not video_annotations: + raise RuntimeError("No video annotations found to export") + + source_names = sorted({Path(ann.filename).name for ann in video_annotations if ann.filename}) + has_multiple_sources = len(source_names) > 1 + + log_message(f"Downloading videos into {video_dir}...") + annotated = QueryResult( + [dp for dp in self.entries if annotation_field in dp.metadata], + self.datasource, + self.fields, + ) + local_download_root = annotated.download_files(video_dir, keep_source_prefix=False) + + log_message("Exporting MOT annotations...") + sequences = self._annotations_to_sequences(video_annotations) + + if has_multiple_sources: + context = MOTContext() + context.video_width = image_width + 
context.video_height = image_height + export_mot_sequences_to_dirs(sequences, context, dataset_root) + result_path = dataset_root + else: + ref_filename = video_annotations[0].filename + video_file = self._resolve_local_path(local_download_root, ref_filename) + if video_file is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ref_filename}' " + f"under '{local_download_root}'." + ) + + context = MOTContext() + context.video_width = image_width + context.video_height = image_height + labels_dir.mkdir(parents=True, exist_ok=True) + single_name = Path(source_names[0]).stem if source_names else "sequence" + output_dir = labels_dir / single_name + result_path = export_mot_to_dir(sequences[0], context, output_dir, video_file=video_file) + + log_message(f"Done! Saved MOT annotations to {result_path.absolute()}") + return result_path + + def export_as_cvat_video( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + video_name: str = "video.mp4", + image_width: Optional[int] = None, + image_height: Optional[int] = None, + ) -> Path: + """ + Exports video annotations in CVAT video ZIP format. + + Args: + download_dir: Where to export. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + video_name: Name of the source video to embed in the XML metadata. + image_width: Frame width. If None, inferred from annotations. + image_height: Frame height. If None, inferred from annotations. + + Returns: + Path to the exported CVAT video ZIP file for single-video exports, + or output directory for multi-video exports. 
+ """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + video_dir, labels_dir, _ = self._resolve_export_dirs(download_dir, "videos") + + video_annotations = self._get_all_video_annotations(annotation_field) + if not video_annotations: + raise RuntimeError("No video annotations found to export") + + source_names = sorted({Path(ann.filename).name for ann in video_annotations if ann.filename}) + has_multiple_sources = len(source_names) > 1 + + log_message("Exporting CVAT video annotations...") + sequences = self._annotations_to_sequences(video_annotations) + + log_message(f"Downloading videos into {video_dir}...") + annotated = QueryResult( + [dp for dp in self.entries if annotation_field in dp.metadata], + self.datasource, + self.fields, + ) + local_download_root = annotated.download_files(video_dir, keep_source_prefix=False) + + if has_multiple_sources: + video_files: Optional[Dict[str, Union[str, Path]]] = None + if image_width is None or image_height is None: + video_files = {} + for ann_filename in {ann.filename for ann in video_annotations if ann.filename}: + local_video = self._resolve_local_path(local_download_root, ann_filename) + if local_video is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ann_filename}' " + f"under '{local_download_root}'." 
+ ) + ann_path = Path(ann_filename) + video_files[ann_filename] = local_video + video_files[ann_path.name] = local_video + video_files[ann_path.stem] = local_video + + output_dir = labels_dir + output_dir.mkdir(parents=True, exist_ok=True) + export_cvat_videos_to_zips( + sequences, + output_dir, + image_width=image_width, + image_height=image_height, + video_files=video_files if video_files else None, + ) + result_path = output_dir + else: + ref_filename = next((a.filename for a in video_annotations), None) + if ref_filename is None: + raise FileNotFoundError("Missing annotation filename for single-video CVAT export.") + single_video_file = self._resolve_local_path(local_download_root, ref_filename) + if single_video_file is None: + raise FileNotFoundError( + f"Could not find local downloaded video file for '{ref_filename}' " + f"under '{local_download_root}'." + ) + + labels_dir.mkdir(parents=True, exist_ok=True) + if source_names: + output_name = f"{Path(source_names[0]).name}.zip" + else: + output_name = "annotations.zip" + output_path = labels_dir / output_name + result_path = export_cvat_video_to_zip( + sequences[0], + output_path, + video_name=video_name, + image_width=image_width, + image_height=image_height, + video_file=single_video_file, + ) + log_message(f"Done! Saved CVAT video annotations to {result_path.absolute()}") + return result_path + + def export_as_coco( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + output_filename: str = "annotations.json", + classes: Optional[Dict[int, str]] = None, + ) -> Path: + """ + Downloads the files and exports annotations in COCO format. + + Args: + download_dir: Where to download the files. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + output_filename: Name of the output COCO JSON file. Default is ``annotations.json``. 
+ classes: Category mapping for the COCO dataset as ``{id: name}``. + If ``None``, categories will be inferred from the annotations. + + Returns: + Path to the exported COCO JSON file. + """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + + annotations = self._get_all_annotations(annotation_field) + if not annotations: + raise RuntimeError("No annotations found to export") + + context = CocoContext() + if classes is not None: + categories = Categories() + for category_id, category_name in classes.items(): + categories.add(category_name, category_id) + context.categories = categories + + # Add the source prefix to all annotations + for ann in annotations: + ann.filename = os.path.join(self.datasource.source.source_prefix, ann.filename) + + image_download_path = download_dir / "data" + log_message("Downloading image files...") + self.download_files(image_download_path) + + output_path = download_dir / output_filename + log_message("Exporting COCO annotations...") + result_path = export_to_coco_file(annotations, output_path, context=context) + log_message(f"Done! 
Saved COCO annotations to {result_path.absolute()}") + return result_path + def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset": """ Creates a voxel51 dataset that can be used with\ diff --git a/setup.py b/setup.py index 6cdef855..eec96e69 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def get_version(rel_path: str) -> str: "python-dateutil", "boto3", "semver", - "dagshub-annotation-converter>=0.1.12", + "dagshub-annotation-converter>=0.2.0", ] extras_require = { diff --git a/tests/data_engine/annotation_import/test_annotation_parsing.py b/tests/data_engine/annotation_import/test_annotation_parsing.py index c04b0d51..133311c0 100644 --- a/tests/data_engine/annotation_import/test_annotation_parsing.py +++ b/tests/data_engine/annotation_import/test_annotation_parsing.py @@ -6,13 +6,15 @@ import pytest from dagshub_annotation_converter.ir.image import IRSegmentationImageAnnotation +from dagshub_annotation_converter.ir.video import CoordinateStyle, IRVideoBBoxFrameAnnotation from pytest import MonkeyPatch from dagshub.data_engine.annotation import MetadataAnnotations +from dagshub.data_engine.annotation.video import build_video_sequence_from_annotations from dagshub.data_engine.annotation.metadata import ErrorMetadataAnnotations, UnsupportedMetadataAnnotations from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags from dagshub.data_engine.model import datapoint, query_result -from dagshub.data_engine.model.datapoint import BlobDownloadError, BlobHashMetadata +from dagshub.data_engine.model.datapoint import BlobDownloadError, BlobHashMetadata, Datapoint from dagshub.data_engine.model.datasource import Datasource from dagshub.data_engine.model.query_result import QueryResult from tests.data_engine.util import add_metadata_field @@ -168,3 +170,44 @@ def test_nonexistent_annotation(ds_with_nonexistent_annotation): def test_blob_metadata_is_wrapped_from_backend(ds_with_document_annotation): qr = 
ds_with_document_annotation.all(load_documents=False, load_annotations=False) assert isinstance(qr[0].metadata[_annotation_field_name], BlobHashMetadata) + + +def test_video_tracks_to_ls_task_use_video_data_and_sequence_length(ds): + dp = Datapoint(datasource=ds, path="nested/video.mp4", datapoint_id=1, metadata={}) + frame_annotations = [ + IRVideoBBoxFrameAnnotation( + imported_id="1", + frame_number=0, + left=100.0, + top=150.0, + width=50.0, + height=80.0, + video_width=1920, + video_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ), + IRVideoBBoxFrameAnnotation( + imported_id="1", + frame_number=5, + left=110.0, + top=155.0, + width=50.0, + height=80.0, + video_width=1920, + video_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ), + ] + for ann in frame_annotations: + ann.filename = dp.path + + sequence = build_video_sequence_from_annotations(frame_annotations, filename=dp.path) + annotations = MetadataAnnotations(dp, _annotation_field_name, annotations=sequence.tracks) + + task = json.loads(annotations.to_ls_task()) + + assert task["data"]["video"] == dp.download_url + assert task["annotations"][0]["result"][0]["type"] == "videorectangle" + assert task["annotations"][0]["result"][0]["value"]["framesCount"] == sequence.sequence_length diff --git a/tests/data_engine/annotation_import/test_coco.py b/tests/data_engine/annotation_import/test_coco.py new file mode 100644 index 00000000..0db9cf8f --- /dev/null +++ b/tests/data_engine/annotation_import/test_coco.py @@ -0,0 +1,198 @@ +import datetime +import json +from unittest.mock import patch + +import pytest +from dagshub_annotation_converter.ir.image import ( + IRBBoxImageAnnotation, + CoordinateStyle, +) + +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.client.models import 
MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +# --- import --- + + +def test_import_coco_from_file(ds, tmp_path): + _write_coco(tmp_path, _make_coco_json()) + importer = AnnotationImporter(ds, "coco", tmp_path / "annotations.json", load_from="disk") + result = importer.import_annotations() + + assert "image1.jpg" in result + assert len(result["image1.jpg"]) == 1 + assert isinstance(result["image1.jpg"][0], IRBBoxImageAnnotation) + + +def test_import_coco_nonexistent_raises(ds, tmp_path): + importer = AnnotationImporter(ds, "coco", tmp_path / "nope.json", load_from="disk") + with pytest.raises(AnnotationsNotFoundError): + importer.import_annotations() + + +def test_coco_convert_to_ls_tasks(ds, tmp_path, mock_dagshub_auth): + importer = AnnotationImporter(ds, "coco", tmp_path / "ann.json", load_from="disk") + bbox = IRBBoxImageAnnotation( + filename="test.jpg", categories={"cat": 1.0}, + top=0.1, left=0.1, width=0.2, height=0.2, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.NORMALIZED, + ) + tasks = importer.convert_to_ls_tasks({"test.jpg": [bbox]}) + + assert "test.jpg" in tasks + task_json = json.loads(tasks["test.jpg"]) + assert "annotations" in task_json + assert len(task_json["annotations"]) > 0 + + +# --- _resolve_annotation_field --- + + +def test_resolve_explicit_field(ds): + qr = _make_qr(ds, [], ann_field="my_ann") + assert qr._resolve_annotation_field("explicit") == "explicit" + + +def test_resolve_auto_field(ds): + qr = _make_qr(ds, [], ann_field="my_ann") + assert qr._resolve_annotation_field(None) == "my_ann" + + +def test_resolve_no_fields_raises(ds): + qr = _make_qr(ds, [], ann_field=None) + with pytest.raises(ValueError, match="No annotation fields"): + qr._resolve_annotation_field(None) + + +def test_resolve_picks_alphabetically_first(ds): + 
fields = [] + for name in ["zebra_ann", "alpha_ann"]: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=name, + multiple=False, valueType=MetadataFieldType.BLOB, + name=name, tags={ReservedTags.ANNOTATION.value}, + )) + qr = QueryResult(datasource=ds, _entries=[], fields=fields) + assert qr._resolve_annotation_field(None) == "alpha_ann" + + +# --- export_as_coco --- + + +def test_export_coco_bbox_coordinates(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + ann = IRBBoxImageAnnotation( + filename="images/test.jpg", categories={"cat": 1.0}, + top=20.0, left=10.0, width=30.0, height=40.0, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + coco = json.loads(result.read_text()) + assert coco["annotations"][0]["bbox"] == [10.0, 20.0, 30.0, 40.0] + + +def test_export_coco_no_annotations_raises(ds, tmp_path): + dp = Datapoint(datasource=ds, path="test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with pytest.raises(RuntimeError, match="No annotations found"): + qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + +def test_export_coco_explicit_classes(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco( + download_dir=tmp_path, 
annotation_field="ann", classes={1: "cat", 2: "dog"} + ) + + coco = json.loads(result.read_text()) + assert "cat" in {c["name"] for c in coco["categories"]} + + +def test_export_coco_custom_filename(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco( + download_dir=tmp_path, annotation_field="ann", output_filename="custom.json" + ) + + assert result.name == "custom.json" + + +def test_export_coco_multiple_datapoints(ds, tmp_path): + dps = [] + for i, name in enumerate(["a.jpg", "b.jpg"]): + dp = Datapoint(datasource=ds, path=name, datapoint_id=i, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox(name)] + ) + dps.append(dp) + + qr = _make_qr(ds, dps, ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + coco = json.loads(result.read_text()) + assert len(coco["annotations"]) == 2 + assert len(coco["images"]) == 2 + + +# --- helpers --- + + +def _make_coco_json(): + return { + "categories": [{"id": 1, "name": "cat"}], + "images": [{"id": 1, "width": 640, "height": 480, "file_name": "image1.jpg"}], + "annotations": [{"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 30, 40]}], + } + + +def _write_coco(tmp_path, coco): + (tmp_path / "annotations.json").write_text(json.dumps(coco)) + + +def _make_image_bbox(filename="test.jpg") -> IRBBoxImageAnnotation: + return IRBBoxImageAnnotation( + filename=filename, categories={"cat": 1.0}, + top=20.0, left=10.0, width=30.0, height=40.0, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _make_qr(ds, datapoints, ann_field=None): + fields = 
[] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/annotation_import/test_cvat_video.py b/tests/data_engine/annotation_import/test_cvat_video.py new file mode 100644 index 00000000..4bdef534 --- /dev/null +++ b/tests/data_engine/annotation_import/test_cvat_video.py @@ -0,0 +1,318 @@ +import datetime +import zipfile +from pathlib import PurePosixPath +from unittest.mock import patch, PropertyMock + +import pytest +from dagshub_annotation_converter.converters.cvat import export_cvat_video_to_xml_bytes +from dagshub_annotation_converter.ir.image import IRBBoxImageAnnotation, CoordinateStyle +from dagshub_annotation_converter.ir.video import IRVideoBBoxFrameAnnotation + +from dagshub.data_engine.annotation.importer import AnnotationImporter +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.annotation.video import build_video_sequence_from_annotations +from dagshub.data_engine.client.models import MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +@pytest.fixture(autouse=True) +def mock_source_prefix(ds): + with patch.object(type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()): + yield + + +# --- import --- + + +def test_import_cvat_video(ds, tmp_path): + xml_file = tmp_path / "annotations.xml" + xml_file.write_bytes(_make_cvat_video_xml()) + + importer = AnnotationImporter(ds, "cvat_video", xml_file, load_from="disk") + result = importer.import_annotations() + + assert len(result) == 1 + anns = 
list(result.values())[0] + assert len(anns) == 2 + assert all(isinstance(a, IRVideoBBoxFrameAnnotation) for a in anns) + + +def test_flatten_cvat_fs_preserves_sequence_filename(ds, tmp_path): + importer = AnnotationImporter(ds, "cvat_video", tmp_path / "dataset", load_from="disk") + sequence = build_video_sequence_from_annotations( + [_make_video_bbox(frame=0), _make_video_bbox(frame=5)], + filename="nested/folder/video.mp4", + ) + + result = importer._key_cvat_fs_annotations_by_filename({"nested/annotations.xml": sequence}) + + assert "nested/folder/video.mp4" in result + + +# --- _get_all_video_annotations --- + + +def test_get_all_video_filters(ds): + image_ann = IRBBoxImageAnnotation( + filename="test.jpg", categories={"cat": 1.0}, + top=0.1, left=0.1, width=0.2, height=0.2, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.NORMALIZED, + ) + video_ann = _make_video_bbox() + + dp = Datapoint(datasource=ds, path="dp_0", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[image_ann, video_ann] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + result = qr._get_all_video_annotations("ann") + assert len(result) == 1 + assert isinstance(result[0], IRVideoBBoxFrameAnnotation) + + +def test_get_all_video_empty(ds): + dp = Datapoint(datasource=ds, path="dp_0", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + assert qr._get_all_video_annotations("ann") == [] + + +def test_get_all_video_aggregates_across_datapoints(ds): + dps = [] + for i in range(3): + dp = Datapoint(datasource=ds, path=f"dp_{i}", datapoint_id=i, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_video_bbox(frame=i)] + ) + dps.append(dp) + + qr = _make_qr(ds, dps, ann_field="ann") + assert len(qr._get_all_video_annotations("ann")) == 3 + + +def 
test_get_all_video_expands_tracks(ds): + dp = Datapoint(datasource=ds, path="nested/video.mp4", datapoint_id=0, metadata={}) + sequence = build_video_sequence_from_annotations( + [_make_video_bbox(frame=0), _make_video_bbox(frame=5)], + filename=dp.path, + ) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=sequence.tracks) + + qr = _make_qr(ds, [dp], ann_field="ann") + result = qr._get_all_video_annotations("ann") + + assert len(result) == 2 + assert all(ann.filename == dp.path for ann in result) + + +# --- export_as_cvat_video --- + + +def test_export_cvat_video_xml(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + captured["download_dir"] = target_dir + captured["keep_source_prefix"] = kwargs.get("keep_source_prefix", True) + (target_dir / "video.mp4").parent.mkdir(parents=True, exist_ok=True) + (target_dir / "video.mp4").write_bytes(b"fake") + return target_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + result = qr.export_as_cvat_video(download_dir=tmp_path, annotation_field="ann") + + assert result.exists() + assert result == tmp_path / "data" / "labels" / "video.mp4.zip" + assert captured["download_dir"] == tmp_path / "data" / "videos" + assert captured["keep_source_prefix"] is False + with zipfile.ZipFile(result, "r") as z: + content = z.read("annotations.xml").decode("utf-8") + assert "<track" in content + + +def test_export_cvat_video_passes_video_file(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + (target_dir / "video.mp4").parent.mkdir(parents=True, exist_ok=True) + (target_dir / "video.mp4").write_bytes(b"fake") + return target_dir + + def _mock_export_cvat_video_to_zip(sequence, output_path, video_name=None, image_width=None, image_height=None, video_file=None): + captured["video_file"] = str(video_file) if video_file is not None else None + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("<annotations/>") + return output_path + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_cvat_video_to_zip", + _mock_export_cvat_video_to_zip, + ) + + qr.export_as_cvat_video(download_dir=tmp_path, annotation_field="ann") + + assert captured["video_file"] is not None + assert "data/videos" in captured["video_file"] + assert captured["video_file"].endswith("video.mp4") + + +def test_export_cvat_video_missing_local_file_raises(ds, tmp_path, 
monkeypatch): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + ann = _make_video_bbox(frame=0, object_id=0) + ann.video_width = 0 + ann.video_height = 0 + ann.filename = "missing.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + qr = _make_qr(ds, [dp], ann_field="ann") + + def _mock_download_files(self, target_dir, *args, **kwargs): + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + + with pytest.raises(FileNotFoundError, match=r"missing\.mp4"): + qr.export_as_cvat_video(download_dir=tmp_path, annotation_field="ann") + + +# --- helpers --- + + +def _make_video_bbox(frame=0, object_id=0) -> IRVideoBBoxFrameAnnotation: + return IRVideoBBoxFrameAnnotation( + imported_id=str(object_id), frame_number=frame, + left=100.0, top=150.0, width=50.0, height=80.0, + video_width=1920, video_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _make_cvat_video_xml() -> bytes: + anns = [_make_video_bbox(frame=0, object_id=0), _make_video_bbox(frame=5, object_id=0)] + sequence = build_video_sequence_from_annotations(anns, filename="video.mp4") + return export_cvat_video_to_xml_bytes(sequence, video_name="video.mp4") + + +def _make_video_qr(ds): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + anns = [_make_video_bbox(frame=0, object_id=0), _make_video_bbox(frame=5, object_id=0)] + for ann in anns: + ann.filename = "video.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=anns) + qr = _make_qr(ds, [dp], ann_field="ann") + return qr, dp + + +def _make_qr(ds, datapoints, ann_field=None): + fields = [] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, 
valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/annotation_import/test_mot.py b/tests/data_engine/annotation_import/test_mot.py new file mode 100644 index 00000000..8de30655 --- /dev/null +++ b/tests/data_engine/annotation_import/test_mot.py @@ -0,0 +1,471 @@ +import configparser +import datetime +import json +import zipfile +from pathlib import Path, PurePosixPath +from unittest.mock import patch, PropertyMock + +import pytest +from dagshub_annotation_converter.ir.image import CoordinateStyle +from dagshub_annotation_converter.ir.video import IRVideoBBoxFrameAnnotation + +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.annotation.video import build_video_sequence_from_annotations +from dagshub.data_engine.client.models import MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +@pytest.fixture(autouse=True) +def mock_source_prefix(ds): + with patch.object(type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()): + yield + + +# --- is_video_format --- + + +@pytest.mark.parametrize( + "ann_type, expected", + [ + ("yolo", False), + ("cvat", False), + ("coco", False), + ("mot", True), + ("cvat_video", True), + ], +) +def test_is_video_format(ds, ann_type, expected, tmp_path): + kwargs = {} + if ann_type == "yolo": + kwargs["yolo_type"] = "bbox" + importer = AnnotationImporter(ds, ann_type, tmp_path / "dummy", load_from="disk", **kwargs) + assert importer.is_video_format is expected + + +# --- _key_video_annotations_by_filename --- + + +def test_flatten_merges_frames(ds, 
tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "test_video", load_from="disk") + result = importer._key_video_annotations_by_filename({ + 0: [_make_video_bbox(frame=0)], + 5: [_make_video_bbox(frame=5)], + }) + assert "test_video" in result + assert len(result["test_video"]) == 2 + + +def test_flatten_defaults_to_file_stem(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "my_sequence", load_from="disk") + result = importer._key_video_annotations_by_filename({0: [_make_video_bbox()]}) + assert "my_sequence" in result + + +def test_flatten_video_name_override(ds, tmp_path): + importer = AnnotationImporter( + ds, "mot", tmp_path / "test_video", load_from="disk", video_name="custom.mp4" + ) + result = importer._key_video_annotations_by_filename({0: [_make_video_bbox()]}) + assert "custom.mp4" in result + + +def test_flatten_sequence(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "test_video", load_from="disk") + sequence = build_video_sequence_from_annotations([_make_video_bbox(frame=0), _make_video_bbox(frame=5)]) + result = importer._key_video_annotations_by_filename(sequence) + + assert "test_video" in result + assert len(result["test_video"]) == 2 + + +def test_flatten_sequence_preserves_sequence_filename(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "dataset", load_from="disk") + sequence = build_video_sequence_from_annotations( + [_make_video_bbox(frame=0), _make_video_bbox(frame=5)], + filename="nested/videos/video.mp4", + ) + + result = importer._key_video_annotations_by_filename(sequence) + + assert "nested/videos/video.mp4" in result + + +def test_flatten_mot_fs_preserves_relative_video_path(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "dataset", load_from="disk") + sequence = build_video_sequence_from_annotations( + [_make_video_bbox(frame=0), _make_video_bbox(frame=5)], + filename="nested/video.mp4", + ) + + result = 
importer._key_mot_fs_annotations_by_filename({Path("nested/video.mp4"): (sequence, object())}) + + assert "nested/video.mp4" in result + + +def test_build_video_sequence_sets_top_level_dimensions(): + anns = [ + IRVideoBBoxFrameAnnotation( + imported_id="0", + frame_number=0, + left=100.0, + top=150.0, + width=50.0, + height=80.0, + video_width=1920, + video_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + ] + + sequence = build_video_sequence_from_annotations(anns, filename="video.mp4") + + assert sequence.video_width == 1920 + assert sequence.video_height == 1080 + + +def test_video_export_layout_uses_datasource_prefix(ds): + qr, _ = _make_video_qr(ds) + with patch.object( + type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath("my_ds_path") + ): + video_dir, labels_dir, dataset_root = qr._resolve_export_dirs(Path("export"), "videos") + + assert video_dir == Path("export") / "data" / "my_ds_path" / "videos" + assert labels_dir == Path("export") / "data" / "my_ds_path" / "labels" + assert dataset_root == Path("export") / "data" / "my_ds_path" + + +def test_video_export_layout_reuses_existing_videos_suffix(ds): + qr, _ = _make_video_qr(ds) + with patch.object( + type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath("my_ds_path/videos") + ): + video_dir, labels_dir, dataset_root = qr._resolve_export_dirs(Path("export"), "videos") + + assert video_dir == Path("export") / "data" / "my_ds_path" / "videos" + assert labels_dir == Path("export") / "data" / "my_ds_path" / "labels" + assert dataset_root == Path("export") / "data" / "my_ds_path" + + +def test_video_export_layout_strips_leading_data_prefix(ds): + qr, _ = _make_video_qr(ds) + with patch.object( + type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath("data/videos") + ): + video_dir, labels_dir, dataset_root = qr._resolve_export_dirs(Path("export"), "videos") + + 
assert video_dir == Path("export") / "data" / "videos" + assert labels_dir == Path("export") / "data" / "labels" + assert dataset_root == Path("export") / "data" + + +# --- import --- + + +def test_import_mot_from_dir(ds, tmp_path): + mot_dir = tmp_path / "mot_seq" + _create_mot_dir(mot_dir) + + importer = AnnotationImporter(ds, "mot", mot_dir, load_from="disk") + result = importer.import_annotations() + + assert len(result) == 1 + anns = list(result.values())[0] + assert len(anns) == 2 + assert all(isinstance(a, IRVideoBBoxFrameAnnotation) for a in anns) + + +def test_import_mot_from_zip(ds, tmp_path): + mot_dir = tmp_path / "mot_seq" + _create_mot_dir(mot_dir) + zip_path = _zip_mot_dir(tmp_path, mot_dir) + + importer = AnnotationImporter(ds, "mot", zip_path, load_from="disk") + result = importer.import_annotations() + + assert len(result) == 1 + assert len(list(result.values())[0]) == 2 + + +def test_import_mot_from_fs_passes_dimensions(ds, tmp_path, monkeypatch): + # Create the labels/ subdir so the importer takes the load_mot_from_fs path + (tmp_path / "labels").mkdir() + captured = {} + + def _mock_load_mot_from_fs(import_dir, image_width=None, image_height=None, **kwargs): + captured["import_dir"] = import_dir + captured["image_width"] = image_width + captured["image_height"] = image_height + seq = build_video_sequence_from_annotations([_make_video_bbox(frame=0)], filename="seq_a") + return {Path("seq_a"): (seq, object())} + + monkeypatch.setattr("dagshub.data_engine.annotation.importer.load_mot_from_fs", _mock_load_mot_from_fs) + + with patch.object( + type(ds.source), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath("data/videos") + ): + importer = AnnotationImporter( + ds, + "mot", + tmp_path, + load_from="disk", + image_width=1280, + image_height=720, + ) + result = importer.import_annotations() + + assert captured["image_width"] == 1280 + assert captured["image_height"] == 720 + assert "seq_a" in result + + +def 
test_import_mot_nonexistent_raises(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "missing", load_from="disk") + with pytest.raises(AnnotationsNotFoundError): + importer.import_annotations() + + +# --- convert_to_ls_tasks --- + + +def test_convert_video_to_ls_tasks(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "video", load_from="disk") + video_anns = {"video.mp4": [_make_video_bbox(frame=0), _make_video_bbox(frame=1)]} + tasks = importer.convert_to_ls_tasks(video_anns) + + assert "video.mp4" in tasks + task_json = json.loads(tasks["video.mp4"]) + assert "annotations" in task_json + + +def test_convert_video_empty_skipped(ds, tmp_path): + importer = AnnotationImporter(ds, "mot", tmp_path / "video", load_from="disk") + tasks = importer.convert_to_ls_tasks({"video.mp4": []}) + assert "video.mp4" not in tasks + + +# --- export_as_mot --- + + +def test_export_mot_directory_structure(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + captured["download_dir"] = target_dir + captured["keep_source_prefix"] = kwargs.get("keep_source_prefix", True) + (target_dir / "video.mp4").parent.mkdir(parents=True, exist_ok=True) + (target_dir / "video.mp4").write_bytes(b"fake") + return target_dir + + def _mock_export_mot_to_dir(video_annotations, context, output_dir, video_file=None): + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "gt").mkdir(parents=True, exist_ok=True) + (output_dir / "gt" / "gt.txt").write_text("") + (output_dir / "gt" / "labels.txt").write_text("person\n") + config = configparser.ConfigParser() + config["Sequence"] = {"imWidth": "1920", "imHeight": "1080"} + with open(output_dir / "seqinfo.ini", "w") as f: + config.write(f) + return output_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_mot_to_dir", + 
_mock_export_mot_to_dir, + ) + result = qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + assert result.exists() + assert result == tmp_path / "data" / "labels" / "video" + assert captured["download_dir"] == tmp_path / "data" / "videos" + assert captured["keep_source_prefix"] is False + assert (result / "gt" / "gt.txt").exists() + assert (result / "gt" / "labels.txt").exists() + assert (result / "seqinfo.ini").exists() + + +def test_export_mot_explicit_dimensions(ds, tmp_path, monkeypatch): + qr, _ = _make_video_qr(ds) + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + captured["download_dir"] = target_dir + (target_dir / "video.mp4").parent.mkdir(parents=True, exist_ok=True) + (target_dir / "video.mp4").write_bytes(b"fake") + return target_dir + + def _mock_export_mot_to_dir(video_annotations, context, output_dir, video_file=None): + output_dir.mkdir(parents=True, exist_ok=True) + config = configparser.ConfigParser() + config["Sequence"] = { + "imWidth": str(context.video_width), + "imHeight": str(context.video_height), + } + with open(output_dir / "seqinfo.ini", "w") as f: + config.write(f) + (output_dir / "gt").mkdir(parents=True, exist_ok=True) + (output_dir / "gt" / "gt.txt").write_text("") + (output_dir / "gt" / "labels.txt").write_text("person\n") + return output_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_mot_to_dir", + _mock_export_mot_to_dir, + ) + result = qr.export_as_mot( + download_dir=tmp_path, annotation_field="ann", image_width=1280, image_height=720 + ) + + seqinfo = (result / "seqinfo.ini").read_text() + assert captured["download_dir"] == tmp_path / "data" / "videos" + assert "1280" in seqinfo + assert "720" in seqinfo + + +def test_export_mot_no_annotations_raises(ds, tmp_path): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + dp.metadata["ann"] = 
MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with pytest.raises(RuntimeError, match="No video annotations"): + qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + +def test_export_mot_multiple_videos(ds, tmp_path, monkeypatch): + dps = [] + for i in range(2): + dp = Datapoint(datasource=ds, path=f"video_{i}.mp4", datapoint_id=i, metadata={}) + ann = _make_video_bbox(frame=i, object_id=i) + ann.filename = dp.path + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + dps.append(dp) + + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + captured["download_dir"] = target_dir + target_dir.mkdir(parents=True, exist_ok=True) + for i in range(2): + (target_dir / f"video_{i}.mp4").write_bytes(b"fake") + return target_dir + + def _mock_export_mot_sequences_to_dirs(video_annotations, context, output_dir): + captured["output_dir"] = output_dir + for i in range(2): + seq_dir = output_dir / "labels" / f"video_{i}" + seq_dir.mkdir(parents=True, exist_ok=True) + (seq_dir / "gt").mkdir(parents=True, exist_ok=True) + (seq_dir / "gt" / "gt.txt").write_text("") + (seq_dir / "gt" / "labels.txt").write_text("person\n") + return output_dir / "labels" + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr( + "dagshub.data_engine.model.query_result.export_mot_sequences_to_dirs", + _mock_export_mot_sequences_to_dirs, + ) + qr = _make_qr(ds, dps, ann_field="ann") + result = qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + assert result == tmp_path / "data" + assert captured["download_dir"] == tmp_path / "data" / "videos" + assert captured["output_dir"] == tmp_path / "data" + assert (result / "labels" / "video_0" / "gt" / "gt.txt").exists() + assert (result / "labels" / "video_1" / "gt" / "gt.txt").exists() + + +def test_export_mot_passes_video_file_when_dimensions_missing(ds, 
tmp_path, monkeypatch): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + anns = [_make_video_bbox(frame=0, object_id=1), _make_video_bbox(frame=1, object_id=1)] + for ann in anns: + ann.video_width = 0 + ann.video_height = 0 + ann.filename = "video.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=anns) + qr = _make_qr(ds, [dp], ann_field="ann") + + captured = {} + + def _mock_download_files(self, target_dir, *args, **kwargs): + video_path = target_dir / "video.mp4" + video_path.parent.mkdir(parents=True, exist_ok=True) + video_path.write_bytes(b"video") + return target_dir + + def _mock_export_mot_to_dir(video_annotations, context, output_dir, video_file=None): + captured["video_file"] = str(video_file) if video_file is not None else None + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + monkeypatch.setattr(QueryResult, "download_files", _mock_download_files) + monkeypatch.setattr("dagshub.data_engine.model.query_result.export_mot_to_dir", _mock_export_mot_to_dir) + + qr.export_as_mot(download_dir=tmp_path, annotation_field="ann") + + assert captured["video_file"] is not None + assert "data/videos" in captured["video_file"] + assert captured["video_file"].endswith("video.mp4") + + +# --- helpers --- + + +def _make_video_bbox(frame=0, object_id=0) -> IRVideoBBoxFrameAnnotation: + return IRVideoBBoxFrameAnnotation( + imported_id=str(object_id), frame_number=frame, + left=100.0, top=150.0, width=50.0, height=80.0, + video_width=1920, video_height=1080, + categories={"person": 1.0}, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _create_mot_dir(mot_dir: Path): + gt_dir = mot_dir / "gt" + gt_dir.mkdir(parents=True) + (gt_dir / "gt.txt").write_text("1,1,100,150,50,80,1,1,1.0\n2,1,110,160,50,80,1,1,0.9\n") + (gt_dir / "labels.txt").write_text("person\n") + config = configparser.ConfigParser() + config["Sequence"] = { + "name": "test", "frameRate": "30", "seqLength": 
"100", + "imWidth": "1920", "imHeight": "1080", + } + with open(mot_dir / "seqinfo.ini", "w") as f: + config.write(f) + + +def _zip_mot_dir(tmp_path: Path, mot_dir: Path) -> Path: + zip_path = tmp_path / "mot.zip" + with zipfile.ZipFile(zip_path, "w") as z: + z.write(mot_dir / "gt" / "gt.txt", "gt/gt.txt") + z.write(mot_dir / "gt" / "labels.txt", "gt/labels.txt") + z.write(mot_dir / "seqinfo.ini", "seqinfo.ini") + return zip_path + + +def _make_video_qr(ds): + dp = Datapoint(datasource=ds, path="video.mp4", datapoint_id=0, metadata={}) + anns = [_make_video_bbox(frame=0, object_id=1), _make_video_bbox(frame=1, object_id=1)] + for ann in anns: + ann.filename = "video.mp4" + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=anns) + qr = _make_qr(ds, [dp], ann_field="ann") + return qr, dp + + +def _make_qr(ds, datapoints, ann_field=None): + fields = [] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/conftest.py b/tests/data_engine/conftest.py index e8f0c70a..02ee8331 100644 --- a/tests/data_engine/conftest.py +++ b/tests/data_engine/conftest.py @@ -1,11 +1,13 @@ import datetime +from pathlib import PurePosixPath +from unittest.mock import PropertyMock import pytest from dagshub.common.api import UserAPI from dagshub.common.api.responses import UserAPIResponse from dagshub.data_engine import datasources -from dagshub.data_engine.client.models import MetadataSelectFieldSchema, PreprocessingStatus +from dagshub.data_engine.client.models import DatasourceType, MetadataSelectFieldSchema, PreprocessingStatus from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource import DatasetState, 
Datasource from dagshub.data_engine.model.query_result import QueryResult @@ -26,6 +28,7 @@ def other_ds(mocker, mock_dagshub_auth) -> Datasource: def _create_mock_datasource(mocker, id, name) -> Datasource: ds_state = datasources.DatasourceState(id=id, name=name, repo="kirill/repo") + ds_state.source_type = DatasourceType.REPOSITORY ds_state.path = "repo://kirill/repo/data/" ds_state.preprocessing_status = PreprocessingStatus.READY mocker.patch.object(ds_state, "client") @@ -33,6 +36,7 @@ def _create_mock_datasource(mocker, id, name) -> Datasource: mocker.patch.object(ds_state, "get_from_dagshub") # Stub out root path so all the content_path/etc work without also mocking out RepoAPI mocker.patch.object(ds_state, "_root_path", return_value="http://example.com") + mocker.patch.object(type(ds_state), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()) ds_state.repoApi = MockRepoAPI("kirill/repo") return Datasource(ds_state) diff --git a/tests/mocks/repo_api.py b/tests/mocks/repo_api.py index d457d161..22b6c94c 100644 --- a/tests/mocks/repo_api.py +++ b/tests/mocks/repo_api.py @@ -113,6 +113,10 @@ def generate_content_api_entry(path, is_dir=False, versioning="dvc") -> ContentA def default_branch(self) -> str: return self._default_branch + @property + def id(self) -> int: + return 1 + def get_connected_storages(self) -> List[StorageAPIEntry]: return self.storages