-
Notifications
You must be signed in to change notification settings - Fork 42
Verify decoder outputs #728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
5e06186
Verify decoder outputs
Dan-Flores 0ad8a3c
wip, comppare w tc_public
Dan-Flores e4a6ceb
Test sequential and random frames
Dan-Flores 86c8a4f
Test sequential and random frames
Dan-Flores f4f8973
lint
Dan-Flores dcbc1e8
Only reuse stream_index option, use seek_mode=exact
Dan-Flores 68f96e1
End script after verify_outputs
Dan-Flores 8517aa2
Simplify assertion logic
Dan-Flores 31c6a5f
add print statement when frames match
Dan-Flores c813bfd
Increase verified outputs to num_samples, rename non_seq_frames
Dan-Flores File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
import abc | ||
import json | ||
import subprocess | ||
import typing | ||
import urllib.request | ||
from concurrent.futures import ThreadPoolExecutor, wait | ||
from dataclasses import dataclass | ||
from dataclasses import dataclass, field | ||
from itertools import product | ||
from pathlib import Path | ||
|
||
|
@@ -23,6 +24,7 @@ | |
get_next_frame, | ||
seek_to_pts, | ||
) | ||
from torchcodec._frame import FrameBatch | ||
from torchcodec.decoders import VideoDecoder, VideoStreamMetadata | ||
|
||
torch._dynamo.config.cache_size_limit = 100 | ||
|
@@ -824,6 +826,42 @@ def convert_result_to_df_item( | |
return df_item | ||
|
||
|
||
@dataclass | ||
class DecoderKind: | ||
display_name: str | ||
kind: typing.Type[AbstractDecoder] | ||
default_options: dict[str, str] = field(default_factory=dict) | ||
|
||
|
||
decoder_registry = { | ||
"decord": DecoderKind("DecordAccurate", DecordAccurate), | ||
"decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch), | ||
"torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore), | ||
"torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch), | ||
"torchcodec_core_nonbatch": DecoderKind( | ||
"TorchCodecCoreNonBatch", TorchCodecCoreNonBatch | ||
), | ||
"torchcodec_core_compiled": DecoderKind( | ||
"TorchCodecCoreCompiled", TorchCodecCoreCompiled | ||
), | ||
"torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic), | ||
"torchcodec_public_nonbatch": DecoderKind( | ||
"TorchCodecPublicNonBatch", TorchCodecPublicNonBatch | ||
), | ||
"torchvision": DecoderKind( | ||
# We don't compare against TorchVision's "pyav" backend because it doesn't support | ||
# accurate seeks. | ||
"TorchVision[backend=video_reader]", | ||
TorchVision, | ||
{"backend": "video_reader"}, | ||
), | ||
"torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder), | ||
"opencv": DecoderKind( | ||
"OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"} | ||
), | ||
} | ||
|
||
|
||
def run_benchmarks( | ||
decoder_dict: dict[str, AbstractDecoder], | ||
video_files_paths: list[Path], | ||
|
@@ -986,3 +1024,77 @@ def run_benchmarks( | |
compare = benchmark.Compare(results) | ||
compare.print() | ||
return df_data | ||
|
||
|
||
def verify_outputs(decoders_to_run, video_paths, num_samples): | ||
# Reuse TorchCodecPublic decoder stream_index option, if provided. | ||
options = decoder_registry["torchcodec_public"].default_options | ||
if torchcodec_decoder := next( | ||
( | ||
decoder | ||
for name, decoder in decoders_to_run.items() | ||
if "TorchCodecPublic" in name | ||
), | ||
None, | ||
): | ||
options["stream_index"] = ( | ||
str(torchcodec_decoder._stream_index) | ||
if torchcodec_decoder._stream_index is not None | ||
else "" | ||
) | ||
# Create default TorchCodecPublic decoder to use as a baseline | ||
torchcodec_public_decoder = TorchCodecPublic(**options) | ||
|
||
# Get frames using each decoder | ||
for video_file_path in video_paths: | ||
metadata = get_metadata(video_file_path) | ||
metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps" | ||
print(f"{metadata_label=}") | ||
|
||
# Generate uniformly spaced PTS | ||
duration = metadata.duration_seconds | ||
pts_list = [i * duration / num_samples for i in range(num_samples)] | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically this is uniformly-spaced PTS values, or just evenly-spaced PTS values. Uniformly random would look like |
||
# Get the frames from TorchCodecPublic as the baseline | ||
torchcodec_public_results = decode_and_adjust_frames( | ||
torchcodec_public_decoder, | ||
video_file_path, | ||
num_samples=num_samples, | ||
pts_list=pts_list, | ||
) | ||
|
||
for decoder_name, decoder in decoders_to_run.items(): | ||
print(f"video={video_file_path}, decoder={decoder_name}") | ||
|
||
frames = decode_and_adjust_frames( | ||
decoder, | ||
video_file_path, | ||
num_samples=num_samples, | ||
pts_list=pts_list, | ||
) | ||
for f1, f2 in zip(torchcodec_public_results, frames): | ||
torch.testing.assert_close(f1, f2) | ||
print(f"Results of baseline TorchCodecPublic and {decoder_name} match!") | ||
|
||
|
||
def decode_and_adjust_frames( | ||
decoder, video_file_path, *, num_samples: int, pts_list: list[float] | ||
): | ||
frames = [] | ||
# Decode non-sequential frames using decode_frames function | ||
non_seq_frames = decoder.decode_frames(video_file_path, pts_list) | ||
# TorchCodec's batch APIs return a FrameBatch, so we need to extract the frames | ||
if isinstance(non_seq_frames, FrameBatch): | ||
non_seq_frames = non_seq_frames.data | ||
frames.extend(non_seq_frames) | ||
|
||
# Decode sequential frames using decode_first_n_frames function | ||
seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples) | ||
if isinstance(seq_frames, FrameBatch): | ||
seq_frames = seq_frames.data | ||
frames.extend(seq_frames) | ||
|
||
# Check and convert frames to C,H,W for consistency with other decoders. | ||
if frames[0].shape[-1] == 3: | ||
frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames] | ||
return frames |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means that the reference decoder will be subject to the options that the user provides, such as
seek_mode
. I think we shouldn't try to use the options the user provided, but instead decide what the reference decoder is, and always use that. That means that we probably shouldn't use the default options, but decide what options we want to use. Regardingseek_mode
, I think we should probably useexact
as the reference.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One catch is that given the new 'stream_index' option, we rely on the user to indicate which stream the benchmarks should compare. Without this option, the benchmark for OpenCV vs TorchCodecPublic would not match.
I agree that
exact
should be used as the reference. I've updated the code to only reuse thestream_index
argument.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh! But we want to use the stream_index that the user provided. Good call. :)