Skip to content

Add encoding and bits_per_sample option to save function #1226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/backend/sox_io/common.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
def name_func(func, _, params):
    """Build a readable test-case name for a parameterized test.

    The name is the wrapped function's ``__name__`` followed by the string
    form of every positional parameter in ``params.args``, all joined with
    underscores. The second argument is ignored.
    """
    suffix = "_".join(map(str, params.args))
    return f'{func.__name__}_{suffix}'


def get_enc_params(dtype):
    """Return the ``(encoding, bits_per_sample)`` pair for a dtype name.

    Args:
        dtype: One of ``'float32'``, ``'int32'``, ``'int16'`` or ``'uint8'``.

    Returns:
        A 2-tuple of the encoding name and the bit depth.

    Raises:
        ValueError: If ``dtype`` is not one of the names listed above.
    """
    known = (
        ('float32', ('PCM_F', 32)),
        ('int32', ('PCM_S', 32)),
        ('int16', ('PCM_S', 16)),
        ('uint8', ('PCM_U', 8)),
    )
    # Compare with == (rather than hashing into a dict) so that any object
    # supporting equality against str behaves exactly as in an if-chain.
    for candidate, enc_params in known:
        if dtype == candidate:
            return enc_params
    raise ValueError(f'Unexpected dtype: {dtype}')
4 changes: 3 additions & 1 deletion test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
)
from .common import (
name_func,
get_enc_params,
)


@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate)
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)
740 changes: 302 additions & 438 deletions test/torchaudio_unittest/backend/sox_io/save_test.py

Large diffs are not rendered by default.

22 changes: 14 additions & 8 deletions test/torchaudio_unittest/backend/sox_io/torchscript_test.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
)
from .common import (
name_func,
get_enc_params,
)


@@ -35,8 +36,12 @@ def py_save_func(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
torchaudio.save(
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)


@skipIfNoExec('sox')
@@ -102,15 +107,16 @@ def test_save_wav(self, dtype, sample_rate, num_channels):
torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path)

expected = get_wav_data(dtype, num_channels)
expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
enc, bps = get_enc_params(dtype)

py_save_func(py_path, expected, sample_rate, True, None)
ts_save_func(ts_path, expected, sample_rate, True, None)
py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)

py_data, py_sr = load_wav(py_path)
ts_data, ts_sr = load_wav(ts_path)
py_data, py_sr = load_wav(py_path, normalize=False)
ts_data, ts_sr = load_wav(ts_path, normalize=False)

self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr)
@@ -131,8 +137,8 @@ def test_save_flac(self, sample_rate, num_channels, compression_level):
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')

py_save_func(py_path, expected, sample_rate, True, compression_level)
ts_save_func(ts_path, expected, sample_rate, True, compression_level)
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)

# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav'
12 changes: 8 additions & 4 deletions test/torchaudio_unittest/common_utils/sox_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import subprocess
import warnings

@@ -32,6 +33,7 @@ def gen_audio_file(
command = [
'sox',
'-V3', # verbose
'--no-dither', # disable automatic dithering
'-R',
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
@@ -61,21 +63,23 @@ def gen_audio_file(
]
if attenuation is not None:
command += ['vol', f'-{attenuation}dB']
print(' '.join(command))
print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True)


def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
*, encoding=None, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', '-V3', '-R', str(src_path)]
command = ['sox', '-V3', '--no-dither', '-R', str(src_path)]
if encoding is not None:
command += ['--encoding', str(encoding)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
command += ['--compression', str(compression)]
command += [dst_path]
print(' '.join(command))
print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True)


191 changes: 128 additions & 63 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import warnings
from typing import Tuple, Optional

import torch
@@ -152,26 +151,6 @@ def load(
filepath, frame_offset, num_frames, normalize, channels_first, format)


@torch.jit.unused
def _save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format, dtype)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)


@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
filepath: str,
@@ -180,30 +159,11 @@ def save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.

Note:
Supported formats are;

* WAV, AMB

* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer

* MP3
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB

To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.

Args:
filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated
@@ -215,32 +175,137 @@ def save(
compression (Optional[float]): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command.

* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``.
| ``8`` is default and highest compression.
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression
| and lowest quality. Default: ``3``.
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.

``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.

``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.

See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional): Output audio format.
This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
dtype (str, optional): Output tensor dtype.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
``dtype`` parameter is only effective for ``float32`` Tensor.
format (str, optional): Override the audio format.
When the ``filepath`` argument is a path-like object, the audio format is inferred from
the file extension. If the file extension is missing or different, you can specify the
correct format with this argument.

When ``filepath`` argument is file-like object, this argument is required.

Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are:

- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)

Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.

``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32``
- ``"PCM_F"`` if dtype is ``float32``

- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise

``"sph"`` format;
- the default value is ``"PCM_S"``

bits_per_sample (int, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.

Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;

``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``

- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``

``"flac"`` format;
- the default value is ``24``

``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``

``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``

Supported formats/encodings/bit depth/compression are;

``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law

Note: Default encoding/bit depth is determined by the dtype of the input Tensor.

``"mp3"``
Fixed bit rate (such as 128kbps) and variable bit rate compression.
Default: VBR with high quality.

``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)

``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps

``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law

``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s

Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
"""
if src.dtype == torch.float32 and dtype is None:
warnings.warn(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
return
if hasattr(filepath, 'write'):
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression,
format, encoding, bits_per_sample)
return
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, dtype)
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)


@_mod_utils.requires_module('torchaudio._torchaudio')
1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ set(
sox/utils.cpp
sox/effects.cpp
sox/effects_chain.cpp
sox/types.cpp
)

if(BUILD_TRANSDUCER)
Loading