Skip to content

Commit 07872f7

Browse files
Support Audio feature for TAR archives in sequential access (huggingface#3129)
* Add test fixture for TAR WAV file
* Add test iter_archive
* Test dataset with Audio feature for TAR archive
* Add Audio method to decode from bytes instead of path
* Add Audio support for bytes besides path
* Fix docstring
* Remove archived attribute from test audio with TAR archive
* Remove archived attribute from Audio feature
* Implement Audio.encode_example
* Call Audio.encode_example from encode_nested_example
* Fix docs
* Enhance Audio.decode_example to accept a string
* Fix docs
* Implement private Audio._storage_dtype to specify cached dtype
* Change Audio._storage_dtype dynamically when encoding a string
* Update test of Audio instantiation
* Set ArrowWriter.schema property dynamically calculated from features
* Update ArrowWriter.write_examples_on_file
* Update ArrowWriter._build_writer
* Fix code quality
* Replace _schema with schema and condition on schema in ArrowWriter
* Add test for MP3 TAR audio file
* Refactor Audio decode_example
* Pass raw bytes to torchaudio.load
* Revert "Pass raw bytes to torchaudio.load"
  This reverts commit c973209.
* Pass format to load in _decode_example_with_torchaudio
* Fix filename extension in test
* Fix Audio tests CI
* Fix Audio tests CI
* Fix audio test CI by checking out PR HEAD commit instead of merge commit
* Change default Audio storage dtype to string
* Rename Audio decode functions
* Refactor Audio decode_example
* Force CI re-run
* Refactor and rename
* Fix docstring
* Fix docstrings
1 parent 6eda9e4 commit 07872f7

File tree

5 files changed

+193
-47
lines changed

5 files changed

+193
-47
lines changed

.github/workflows/test-audio.yml

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
name: Test audio
22

33
on:
4-
push:
5-
branches:
6-
- master
74
pull_request:
85
branches:
96
- master
@@ -12,15 +9,17 @@ jobs:
129
test:
1310
runs-on: ubuntu-latest
1411
steps:
12+
- name: Install OS dependencies
13+
run: |
14+
sudo apt-get update
15+
sudo apt-get install libsndfile1 sox
1516
- uses: actions/checkout@v2
17+
with:
18+
ref: ${{ github.event.pull_request.head.sha }}
1619
- name: Set up Python
1720
uses: actions/setup-python@v2
1821
with:
1922
python-version: "3.6"
20-
- name: Install OS dependencies
21-
run: |
22-
sudo apt-get update
23-
sudo apt-get install libsndfile1 sox
2423
- name: Install dependencies
2524
run: |
2625
python -m pip install --upgrade pip

src/datasets/arrow_writer.py

+29-24
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def __init__(
207207
raise ValueError("At least one of path and stream must be provided.")
208208
if features is not None:
209209
self._features = features
210-
self._schema = pa.schema(features.type)
210+
self._schema = None
211211
elif schema is not None:
212212
self._schema: pa.Schema = schema
213213
self._features = Features.from_arrow_schema(self._schema)
@@ -222,9 +222,7 @@ def __init__(
222222
self._hasher = KeyHasher("")
223223

224224
self._check_duplicates = check_duplicates
225-
226-
if disable_nullable and self._schema is not None:
227-
self._schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in self._schema)
225+
self._disable_nullable = disable_nullable
228226

229227
self._path = path
230228
if stream is None:
@@ -269,6 +267,7 @@ def close(self):
269267
self.stream.close() # This also closes self.pa_writer if it is opened
270268

271269
def _build_writer(self, inferred_schema: pa.Schema):
270+
schema = self.schema
272271
inferred_features = Features.from_arrow_schema(inferred_schema)
273272
if self._features is not None:
274273
if self.update_features: # keep original features if they match, or update them
@@ -279,21 +278,27 @@ def _build_writer(self, inferred_schema: pa.Schema):
279278
if inferred_field == fields[name]:
280279
inferred_features[name] = self._features[name]
281280
self._features = inferred_features
282-
self._schema: pa.Schema = inferred_schema
281+
schema: pa.Schema = inferred_schema
283282
else:
284283
self._features = inferred_features
285-
self._schema: pa.Schema = inferred_schema
284+
schema: pa.Schema = inferred_schema
286285
if self.disable_nullable:
287-
self._schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in self._schema)
286+
schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in schema)
288287
if self.with_metadata:
289-
self._schema = self._schema.with_metadata(
290-
self._build_metadata(DatasetInfo(features=self._features), self.fingerprint)
291-
)
292-
self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
288+
schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint))
289+
self._schema = schema
290+
self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)
293291

294292
@property
295293
def schema(self):
296-
return self._schema if self._schema is not None else []
294+
_schema = (
295+
self._schema
296+
if self._schema is not None
297+
else (pa.schema(self._features.type) if self._features is not None else None)
298+
)
299+
if self._disable_nullable and _schema is not None:
300+
_schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in _schema)
301+
return _schema if _schema is not None else []
297302

298303
@staticmethod
299304
def _build_metadata(info: DatasetInfo, fingerprint: Optional[str] = None) -> Dict[str, str]:
@@ -312,18 +317,18 @@ def write_examples_on_file(self):
312317

313318
# Since current_examples contains (example, key) tuples
314319
cols = (
315-
[col for col in self._schema.names if col in self.current_examples[0][0]]
316-
+ [col for col in self.current_examples[0][0].keys() if col not in self._schema.names]
317-
if self._schema
320+
[col for col in self.schema.names if col in self.current_examples[0][0]]
321+
+ [col for col in self.current_examples[0][0].keys() if col not in self.schema.names]
322+
if self.schema
318323
else self.current_examples[0][0].keys()
319324
)
320325

321-
schema = None if self.pa_writer is None and self.update_features else self._schema
322-
try_schema = self._schema if self.pa_writer is None and self.update_features else None
326+
schema = None if self.pa_writer is None and self.update_features else self.schema
327+
try_schema = self.schema if self.pa_writer is None and self.update_features else None
323328
arrays = []
324329
inferred_types = []
325330
for col in cols:
326-
col_type = schema.field(col).type if schema is not None else None
331+
col_type = schema.field(col).type if schema else None
327332
col_try_type = try_schema.field(col).type if try_schema is not None and col in try_schema.names else None
328333
typed_sequence = OptimizedTypedSequence(
329334
[row[0][col] for row in self.current_examples], type=col_type, try_type=col_try_type, col=col
@@ -339,7 +344,7 @@ def write_examples_on_file(self):
339344
)
340345
arrays.append(pa_array)
341346
inferred_types.append(inferred_type)
342-
schema = pa.schema(zip(cols, inferred_types)) if self.pa_writer is None else self._schema
347+
schema = pa.schema(zip(cols, inferred_types)) if self.pa_writer is None else self.schema
343348
table = pa.Table.from_arrays(arrays, schema=schema)
344349
self.write_table(table)
345350
self.current_examples = []
@@ -420,11 +425,11 @@ def write_batch(
420425
"""
421426
if batch_examples and len(next(iter(batch_examples.values()))) == 0:
422427
return
423-
schema = None if self.pa_writer is None and self.update_features else self._schema
424-
try_schema = self._schema if self.pa_writer is None and self.update_features else None
428+
schema = None if self.pa_writer is None and self.update_features else self.schema
429+
try_schema = self.schema if self.pa_writer is None and self.update_features else None
425430
typed_sequence_examples = {}
426431
for col in sorted(batch_examples.keys()):
427-
col_type = schema.field(col).type if schema is not None else None
432+
col_type = schema.field(col).type if schema else None
428433
col_try_type = try_schema.field(col).type if try_schema is not None and col in try_schema.names else None
429434
typed_sequence = OptimizedTypedSequence(batch_examples[col], type=col_type, try_type=col_try_type, col=col)
430435
typed_sequence_examples[col] = typed_sequence
@@ -459,8 +464,8 @@ def finalize(self, close_stream=True):
459464
self.hkey_record = []
460465
self.write_examples_on_file()
461466
if self.pa_writer is None:
462-
if self._schema is not None:
463-
self._build_writer(self._schema)
467+
if self.schema:
468+
self._build_writer(self.schema)
464469
else:
465470
raise ValueError("Please pass `features` or at least one example when writing data")
466471
self.pa_writer.close()

src/datasets/features/audio.py

+63-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22
from dataclasses import dataclass, field
3+
from io import BytesIO
34
from typing import Any, ClassVar, Optional
45

56
import pyarrow as pa
@@ -11,50 +12,98 @@
1112
class Audio:
1213
"""Audio Feature to extract audio data from an audio file.
1314
15+
Input: The Audio feature accepts as input:
16+
- A :obj:`str`: Absolute path to the audio file (i.e. random access is allowed).
17+
- A :obj:`dict` with the keys:
18+
19+
- path: String with the path of the audio file relative to the archive file.
20+
- bytes: Bytes content of the audio file.
21+
22+
This is useful for archived files with sequential access.
23+
1424
Args:
1525
sampling_rate (:obj:`int`, optional): Target sampling rate. If `None`, the native sampling rate is used.
16-
mono (:obj:`bool`, default ```True``): Whether to convert the audio signal to mono by averaging samples across channels.
26+
mono (:obj:`bool`, default ``True``): Whether to convert the audio signal to mono by averaging samples across
27+
channels.
1728
"""
1829

1930
sampling_rate: Optional[int] = None
2031
mono: bool = True
32+
_storage_dtype: str = "string"
2133
id: Optional[str] = None
2234
# Automatically constructed
2335
dtype: ClassVar[str] = "dict"
2436
pa_type: ClassVar[Any] = None
2537
_type: str = field(default="Audio", init=False, repr=False)
2638

2739
def __call__(self):
28-
return pa.string()
40+
return (
41+
pa.struct({"path": pa.string(), "bytes": pa.binary()}) if self._storage_dtype == "struct" else pa.string()
42+
)
43+
44+
def encode_example(self, value):
45+
"""Encode example into a format for Arrow.
46+
47+
Args:
48+
value (:obj:`str` or :obj:`dict`): Data passed as input to Audio feature.
49+
50+
Returns:
51+
:obj:`str` or :obj:`dict`
52+
"""
53+
if isinstance(value, dict):
54+
self._storage_dtype = "struct"
55+
return value
2956

3057
def decode_example(self, value):
3158
"""Decode example audio file into audio data.
3259
3360
Args:
34-
value: Audio file path.
61+
value (:obj:`str` or :obj:`dict`): Either a string with the absolute audio file path or a dictionary with
62+
keys:
63+
64+
- path: String with relative audio file path.
65+
- bytes: Bytes of the audio file.
3566
3667
Returns:
3768
dict
3869
"""
39-
# TODO: backard compatibility for users without audio dependencies
40-
array, sampling_rate = (
41-
self._decode_example_with_torchaudio(value)
42-
if value.endswith(".mp3")
43-
else self._decode_example_with_librosa(value)
44-
)
45-
return {"path": value, "array": array, "sampling_rate": sampling_rate}
70+
path, file = (value["path"], BytesIO(value["bytes"])) if isinstance(value, dict) else (value, None)
71+
if path.endswith("mp3"):
72+
array, sampling_rate = self._decode_mp3(file if file else path)
73+
else:
74+
if file:
75+
array, sampling_rate = self._decode_non_mp3_file_like(file)
76+
else:
77+
array, sampling_rate = self._decode_non_mp3_path_like(path)
78+
return {"path": path, "array": array, "sampling_rate": sampling_rate}
4679

47-
def _decode_example_with_librosa(self, value):
80+
def _decode_non_mp3_path_like(self, path):
4881
try:
4982
import librosa
5083
except ImportError as err:
5184
raise ImportError("To support decoding audio files, please install 'librosa'.") from err
5285

53-
with xopen(value, "rb") as f:
86+
with xopen(path, "rb") as f:
5487
array, sampling_rate = librosa.load(f, sr=self.sampling_rate, mono=self.mono)
5588
return array, sampling_rate
5689

57-
def _decode_example_with_torchaudio(self, value):
90+
def _decode_non_mp3_file_like(self, file):
91+
try:
92+
import librosa
93+
import soundfile as sf
94+
except ImportError as err:
95+
raise ImportError("To support decoding audio files, please install 'librosa'.") from err
96+
97+
array, sampling_rate = sf.read(file)
98+
array = array.T
99+
if self.mono:
100+
array = librosa.to_mono(array)
101+
if self.sampling_rate and self.sampling_rate != sampling_rate:
102+
array = librosa.resample(array, sampling_rate, self.sampling_rate, res_type="kaiser_best")
103+
sampling_rate = self.sampling_rate
104+
return array, sampling_rate
105+
106+
def _decode_mp3(self, path_or_file):
58107
try:
59108
import torchaudio
60109
import torchaudio.transforms as T
@@ -65,7 +114,7 @@ def _decode_example_with_torchaudio(self, value):
65114
except RuntimeError as err:
66115
raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err
67116

68-
array, sampling_rate = torchaudio.load(value)
117+
array, sampling_rate = torchaudio.load(path_or_file, format="mp3")
69118
if self.sampling_rate and self.sampling_rate != sampling_rate:
70119
if not hasattr(self, "_resampler"):
71120
self._resampler = T.Resample(sampling_rate, self.sampling_rate)

src/datasets/features/features.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def encode_nested_example(schema, obj):
849849
return list(obj)
850850
# Object with special encoding:
851851
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
852-
elif isinstance(schema, (ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)):
852+
elif isinstance(schema, (Audio, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)):
853853
return schema.encode_example(obj)
854854
# Other objects should be directly convertible to a native Arrow type (like Translation)
855855
return obj
@@ -961,7 +961,8 @@ class Features(dict):
961961
:class:`datasets.Sequence`.
962962
963963
- a :class:`Array2D`, :class:`Array3D`, :class:`Array4D` or :class:`Array5D` feature for multidimensional arrays
964-
- a :class:`datasets.Audio` stores the path to an audio file and can extract audio data from it
964+
- an :class:`Audio` feature to store the absolute path to an audio file or a dictionary with the relative path
965+
to an audio file ("path" key) and its bytes content ("bytes" key). This feature extracts the audio data.
965966
- :class:`datasets.Translation` and :class:`datasets.TranslationVariableLanguages`, the two features specific to Machine Translation
966967
"""
967968

0 commit comments

Comments
 (0)