Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dl/stream export project #9159

Open
wants to merge 5 commits into
base: dl/stream-export
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Changed

- Optimized memory usage for projects with YOLO and COCO formats
(<https://github.com/cvat-ai/cvat/pull/9159>)
218 changes: 77 additions & 141 deletions cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from operator import add
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Literal, NamedTuple, Optional, Union
from typing import Any, Callable, Generator, Literal, NamedTuple, Optional, Union

import attr
import datumaro as dm
import datumaro.util
import defusedxml.ElementTree as ET
Expand Down Expand Up @@ -267,6 +268,7 @@ class Frame(NamedTuple):
shapes: Sequence[CommonData.Shape]
labels: Sequence[CommonData.Label]
subset: str
task_id: int

class Label(NamedTuple):
id: int
Expand Down Expand Up @@ -491,7 +493,8 @@ def get_frame(idx):
labeled_shapes=[],
tags=[],
shapes=[],
labels={}
labels={},
task_id=self._db_task.id,
)
return frames[frame]

Expand Down Expand Up @@ -1639,7 +1642,7 @@ def map_label(name, parent=''): return label_cat.find(name, parent)[0]
return self.convert_annotations(cvat_frame_anno, label_attrs, map_label)


class CvatTaskOrJobDataExtractor(dm.SubsetBase, CVATDataExtractorMixin):
class CvatDataExtractorBase(dm.SubsetBase, CVATDataExtractorMixin):
def __init__(
self,
instance_data: CommonData,
Expand All @@ -1653,44 +1656,50 @@ def __init__(
dm.SubsetBase.__init__(
self,
media_type=dm.Image if dimension == DimensionType.DIM_2D else dm.PointCloud,
subset=instance_meta['subset'],
subset=instance_meta['subset'] if "subset" in instance_meta else None,
)
CVATDataExtractorMixin.__init__(self, **kwargs)

self._categories = self._load_categories(instance_meta['labels'])
self._user = self._load_user_info(instance_meta) if dimension == DimensionType.DIM_3D else {}
self._dimension = dimension
self._format_type = format_type
self._instance_data = instance_data
self._include_images = include_images
self._grouped_by_frame = list(instance_data.group_by_frame(include_empty=True))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably, we can call this function in the __iter__ method for streaming cases, otherwise it will keep all the annotations in memory.

Copy link
Contributor Author

@Eldies Eldies Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, these annotations will be kept in memory during __iter__ anyway, and I did not find another simple way to get number of items for project.
Also, I am redesigning this PR a bit because I missed that on streaming export items are iterated more than once for multi-subset projects

self._instance_meta = instance_meta

if isinstance(instance_data, TaskData):
db_tasks = [instance_data.db_instance]
elif isinstance(instance_data, JobData):
db_tasks = [instance_data.db_instance.segment.task]
elif isinstance(instance_data, ProjectData):
db_tasks = instance_data.tasks
else:
assert False

if dimension == DimensionType.DIM_3D or include_images:
if isinstance(instance_data, TaskData):
db_task = instance_data.db_instance
elif isinstance(instance_data, JobData):
db_task = instance_data.db_instance.segment.task
else:
assert False

self._media_provider = MEDIA_PROVIDERS_BY_DIMENSION[dimension](
{0: MediaSource(db_task)}
if self._dimension == DimensionType.DIM_3D or include_images:
self._media_provider = MEDIA_PROVIDERS_BY_DIMENSION[self._dimension](
{
task.id: MediaSource(task)
for task in db_tasks
}
)

def __iter__(self):
instance_meta = self._instance_data.meta[self._instance_data.META_FIELD]
is_video = instance_meta['mode'] == 'interpolation'
ext = ''
if is_video:
ext = TaskFrameProvider.VIDEO_FRAME_EXT
self._ext_per_task: dict[int, str] = {
task.id: TaskFrameProvider.VIDEO_FRAME_EXT if is_video else ''
for task in db_tasks
for is_video in [task.mode == 'interpolation']
}

for frame_data in self._instance_data.group_by_frame(include_empty=True):
def __iter__(self):
for frame_data in self._iterate_frame_data():
dm_media_args = {
'path': frame_data.name + ext,
'ext': ext or frame_data.name.rsplit(osp.extsep, maxsplit=1)[1],
'path': frame_data.name + self._ext_per_task[frame_data.task_id],
'ext': self._ext_per_task[frame_data.task_id] or frame_data.name.rsplit(osp.extsep, maxsplit=1)[1],
}
if self._dimension == DimensionType.DIM_3D:
dm_media: dm.PointCloud = self._media_provider.get_media_for_frame(
0, frame_data.id, **dm_media_args
frame_data.task_id, frame_data.id, **dm_media_args
)

if not self._include_images:
Expand All @@ -1703,24 +1712,12 @@ def __iter__(self):
dm_media_args['size'] = (frame_data.height, frame_data.width)
if self._include_images:
dm_media: dm.Image = self._media_provider.get_media_for_frame(
0, frame_data.idx, **dm_media_args
frame_data.task_id, frame_data.idx, **dm_media_args
)
else:
dm_media = dm.Image.from_file(**dm_media_args)

# do not keep parsed lazy list data after this iteration
frame_data = frame_data._replace(
labeled_shapes=[
(
shape._replace(points=shape.points.lazy_copy())
if isinstance(shape.points, LazyList) and not shape.points.is_parsed
else shape
)
for shape in frame_data.labeled_shapes
]
)

dm_anno = self._read_cvat_anno(frame_data, instance_meta['labels'])
dm_anno = self._read_cvat_anno(frame_data, self._instance_meta['labels'])

dm_attributes = {'frame': frame_data.frame}

Expand All @@ -1738,7 +1735,7 @@ def __iter__(self):
dm_attributes["createdAt"] = self._user["createdAt"]
dm_attributes["updatedAt"] = self._user["updatedAt"]
dm_attributes["labels"] = []
for (idx, (_, label)) in enumerate(instance_meta['labels']):
for (idx, (_, label)) in enumerate(self._instance_meta['labels']):
dm_attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]})
dm_attributes["track_id"] = -1

Expand All @@ -1753,7 +1750,30 @@ def __iter__(self):
yield dm_item

def __len__(self):
    """Return the number of dataset items (one per grouped frame)."""
    # Diff residue removed: the superseded `len(self._instance_data)` line made
    # this return unreachable. The grouping is computed once in __init__, so the
    # cached list gives the item count without re-iterating the instance data.
    return len(self._grouped_by_frame)

@property
def is_stream(self) -> bool:
    """Always True: items are produced lazily by ``__iter__`` rather than
    materialized up front, so consumers should treat this source as a stream."""
    return True

def _iterate_frame_data(self) -> Generator[CommonData.Frame | ProjectData.Frame, None, None]:
    """Yield per-frame annotation data; subclasses must implement this hook,
    which ``__iter__`` consumes to build dataset items."""
    raise NotImplementedError()


class CvatTaskOrJobDataExtractor(CvatDataExtractorBase):
def _iterate_frame_data(self):
    """Yield grouped frames, swapping unparsed lazy point lists for lazy copies.

    Replacing ``shape.points`` with a lazy copy means any point data parsed
    while handling the current frame is not kept in memory afterwards.
    """
    for frame in self._grouped_by_frame:
        refreshed_shapes = []
        for shape in frame.labeled_shapes:
            points = shape.points
            if isinstance(points, LazyList) and not points.is_parsed:
                shape = shape._replace(points=points.lazy_copy())
            refreshed_shapes.append(shape)
        yield frame._replace(labeled_shapes=refreshed_shapes)

def _read_cvat_anno(self, cvat_frame_anno: CommonData.Frame, labels: list):
categories = self.categories()
Expand All @@ -1767,112 +1787,26 @@ def map_label(name, parent=''): return label_cat.find(name, parent)[0]
return self.convert_annotations(cvat_frame_anno,
label_attrs, map_label, self._format_type, self._dimension)

@property
def is_stream(self) -> bool:
return True


class CVATProjectDataExtractor(dm.DatasetBase, CVATDataExtractorMixin):
def __init__(
self,
project_data: ProjectData,
*,
include_images: bool = False,
format_type: str = None,
dimension: DimensionType = DimensionType.DIM_2D,
**kwargs
):
dm.DatasetBase.__init__(
self, media_type=dm.Image if dimension == DimensionType.DIM_2D else dm.PointCloud
)
CVATDataExtractorMixin.__init__(self, **kwargs)

self._categories = self._load_categories(project_data.meta[project_data.META_FIELD]['labels'])
self._user = self._load_user_info(project_data.meta[project_data.META_FIELD]) if dimension == DimensionType.DIM_3D else {}
self._dimension = dimension
self._format_type = format_type

if self._dimension == DimensionType.DIM_3D or include_images:
self._media_provider = MEDIA_PROVIDERS_BY_DIMENSION[self._dimension](
{
task.id: MediaSource(task)
for task in project_data.tasks
}
)

ext_per_task: dict[int, str] = {
task.id: TaskFrameProvider.VIDEO_FRAME_EXT if is_video else ''
for task in project_data.tasks
for is_video in [task.mode == 'interpolation']
}

dm_items: list[dm.DatasetItem] = []
for frame_data in project_data.group_by_frame(include_empty=True):
dm_media_args = {
'path': frame_data.name + ext_per_task[frame_data.task_id],
'ext': ext_per_task[frame_data.task_id] or frame_data.name.rsplit(osp.extsep, maxsplit=1)[1],
}
if self._dimension == DimensionType.DIM_3D:
dm_media: dm.PointCloud = self._media_provider.get_media_for_frame(
frame_data.task_id, frame_data.id, **dm_media_args
)

if not include_images:
dm_media_args["extra_images"] = [
dm.Image.from_file(path=osp.basename(image.path))
for image in dm_media.extra_images
]
dm_media = dm.PointCloud.from_file(**dm_media_args)
else:
dm_media_args['size'] = (frame_data.height, frame_data.width)
if include_images:
dm_media: dm.Image = self._media_provider.get_media_for_frame(
frame_data.task_id, frame_data.idx, **dm_media_args
class CVATProjectDataExtractor(CvatDataExtractorBase):
def _iterate_frame_data(self):
for frame_data in self._grouped_by_frame:
# do not keep parsed lazy list data after this iteration
yield attr.evolve(
frame_data,
labeled_shapes=[
(
attr.evolve(shape, points=shape.points.lazy_copy())
if isinstance(shape.points, LazyList) and not shape.points.is_parsed
else shape
)
else:
dm_media = dm.Image.from_file(**dm_media_args)

dm_anno = self._read_cvat_anno(frame_data, project_data.meta[project_data.META_FIELD]['labels'])

dm_attributes = {'frame': frame_data.frame}

if self._dimension == DimensionType.DIM_2D:
dm_item = dm.DatasetItem(
id=osp.splitext(frame_data.name)[0],
annotations=dm_anno, media=dm_media,
subset=frame_data.subset,
attributes=dm_attributes,
)
elif self._dimension == DimensionType.DIM_3D:
if format_type == "sly_pointcloud":
dm_attributes["name"] = self._user["name"]
dm_attributes["createdAt"] = self._user["createdAt"]
dm_attributes["updatedAt"] = self._user["updatedAt"]
dm_attributes["labels"] = []
for (idx, (_, label)) in enumerate(project_data.meta[project_data.META_FIELD]['labels']):
dm_attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]})
dm_attributes["track_id"] = -1

dm_item = dm.DatasetItem(
id=osp.splitext(osp.split(frame_data.name)[-1])[0],
annotations=dm_anno, media=dm_media,
subset=frame_data.subset,
attributes=dm_attributes,
)

dm_items.append(dm_item)

self._items = dm_items
for shape in frame_data.labeled_shapes
],
)

def categories(self):
return self._categories

def __iter__(self):
yield from self._items

def __len__(self):
return len(self._items)


def GetCVATDataExtractor(
instance_data: Union[ProjectData, CommonData],
Expand All @@ -1891,9 +1825,11 @@ def GetCVATDataExtractor(
else:
return CvatTaskOrJobDataExtractor(instance_data, **kwargs)


class CvatImportError(Exception):
    """Base exception for errors raised while importing a dataset into CVAT."""
    pass


@attrs
class CvatDatasetNotFoundError(CvatImportError):
message: str = ""
Expand Down
7 changes: 2 additions & 5 deletions cvat/apps/dataset_manager/formats/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from cvat.apps.dataset_manager.bindings import (
GetCVATDataExtractor,
NoMediaInAnnotationFileError,
ProjectData,
detect_dataset,
import_dm_annotations,
)
Expand All @@ -24,8 +23,7 @@
@exporter(name="COCO", ext="ZIP", version="1.0")
def _export(dst_file, temp_dir, instance_data, save_images=False):
    """Export *instance_data* in COCO instances format as a ZIP archive.

    Diff residue removed: the superseded per-instance-type ``dataset_cls``
    selection lines are dropped; a streaming dataset is used unconditionally so
    frames are processed one at a time instead of held in memory.
    """
    with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
        dataset = StreamDataset.from_extractors(extractor, env=dm_env)
        dataset.export(temp_dir, "coco_instances", save_media=save_images, merge_images=False)

    make_zip_archive(temp_dir, dst_file)
Expand All @@ -52,8 +50,7 @@ def _import(src_file, temp_dir, instance_data, load_data_callback=None, **kwargs
@exporter(name="COCO Keypoints", ext="ZIP", version="1.0")
def _export(dst_file, temp_dir, instance_data, save_images=False):
with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
dataset_cls = Dataset if isinstance(instance_data, ProjectData) else StreamDataset
dataset = dataset_cls.from_extractors(extractor, env=dm_env)
dataset = StreamDataset.from_extractors(extractor, env=dm_env)
dataset.export(
temp_dir, "coco_person_keypoints", save_media=save_images, merge_images=False
)
Expand Down
6 changes: 2 additions & 4 deletions cvat/apps/dataset_manager/formats/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def _export_common(
**kwargs,
):
with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
dataset_cls = Dataset if isinstance(instance_data, ProjectData) else StreamDataset
dataset = dataset_cls.from_extractors(extractor, env=dm_env)
dataset = StreamDataset.from_extractors(extractor, env=dm_env)
dataset.export(temp_dir, format_name, save_media=save_images, **kwargs)

make_zip_archive(temp_dir, dst_file)
Expand Down Expand Up @@ -111,8 +110,7 @@ def _export_yolo_ultralytics_oriented_boxes(*args, **kwargs):
@exporter(name="Ultralytics YOLO Segmentation", ext="ZIP", version="1.0")
def _export_yolo_ultralytics_segmentation(dst_file, temp_dir, instance_data, *, save_images=False):
with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
dataset_cls = Dataset if isinstance(instance_data, ProjectData) else StreamDataset
dataset = dataset_cls.from_extractors(extractor, env=dm_env)
dataset = StreamDataset.from_extractors(extractor, env=dm_env)
dataset = dataset.transform("masks_to_polygons")
dataset.export(temp_dir, "yolo_ultralytics_segmentation", save_media=save_images)

Expand Down
Loading