
Commit 146146d

feat: add clotho dataset, audio web dataset doc
1 parent 58793ce commit 146146d

File tree

README.md
audio_data_pytorch/datasets/__init__.py
audio_data_pytorch/datasets/audio_web_dataset.py
audio_data_pytorch/datasets/clotho_dataset.py
audio_data_pytorch/utils.py
setup.py

6 files changed: +199 −9 lines changed

README.md

Lines changed: 52 additions & 0 deletions
@@ -33,6 +33,33 @@ WAVDataset(
 )
 ```
 
+
+### AudioWebDataset
+A [`WebDataset`](https://webdataset.github.io/webdataset/) extension for audio data. Assumes that the `.tar` file comes with pairs of `.wav` (or `.flac`) and `.json` data.
+```py
+from audio_data_pytorch import AudioWebDataset
+
+dataset = AudioWebDataset(
+    urls='mywebdataset.tar'
+)
+
+waveform, info = next(iter(dataset))
+
+print(waveform.shape) # torch.Size([2, 480000])
+print(info.keys()) # dict_keys(['text'])
+```
+
+#### Full API:
+```py
+dataset = AudioWebDataset(
+    urls: Union[str, Sequence[str]],
+    transforms: Optional[Callable] = None, # Transforms to apply to audio files
+    batch_size: Optional[int] = None, # Why is batch_size here? See https://webdataset.github.io/webdataset/gettingstarted/#webdataset-and-dataloader
+    shuffle: int = 128, # Shuffle in groups of 128
+    **kwargs, # Forwarded to WebDataset class
+)
+```
+
 ### LJSpeech Dataset
 An unsupervised dataset for LJSpeech with voice-only data.
 ```py
@@ -129,6 +156,31 @@ dataset = YoutubeDataset(
 )
 ```
 
+### Clotho Dataset
+A wrapper for the [Clotho](https://zenodo.org/record/3490684#.Y0VVVOxBwR0) dataset extending `AudioWebDataset`. Requires `pip install py7zr` to decompress the `.7z` archive.
+
+```py
+from audio_data_pytorch import ClothoDataset, Crop, Stereo, Mono
+
+dataset = ClothoDataset(
+    root='./data/',
+    preprocess_sample_rate=48000, # Added to all files during preprocessing
+    preprocess_transforms=nn.Sequential(Crop(48000*10), Stereo()), # Added to all files during preprocessing
+    transforms=Mono() # Added dynamically at iteration time
+)
+```
+
+```py
+dataset = ClothoDataset(
+    root: str, # Path where the dataset is saved
+    split: str = 'train', # Dataset split, one of: 'train', 'valid'
+    preprocess_sample_rate: Optional[int] = None, # Preprocesses dataset to this sample rate
+    preprocess_transforms: Optional[Callable] = None, # Preprocesses dataset with the provided transforms
+    reset: bool = False, # Re-compute preprocessing if `True`
+    **kwargs # Forwarded to `AudioWebDataset`
+)
+```
+
 
 ## Transforms
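A minimal usage sketch (not part of this commit) of how the documented `batch_size` argument fits with a PyTorch `DataLoader`, following the WebDataset guideline linked in the README above; the shard name, batch size, and worker count are placeholder assumptions.

```py
# Sketch only: iterating an AudioWebDataset through a PyTorch DataLoader.
# Per the WebDataset docs referenced in the README, batching happens inside
# the dataset pipeline, so the DataLoader is created with batch_size=None.
from torch.utils.data import DataLoader

from audio_data_pytorch import AudioWebDataset

dataset = AudioWebDataset(
    urls="mywebdataset.tar",  # hypothetical local shard
    batch_size=16,            # assumed to enable batching in the dataset pipeline
)

loader = DataLoader(dataset, batch_size=None, num_workers=2)

for waveforms, infos in loader:
    # waveforms: batched audio tensors, infos: decoded .json metadata
    break
```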

audio_data_pytorch/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .audio_web_dataset import AudioWebDataset, AudioWebDatasetPreprocess
+from .clotho_dataset import ClothoDataset
 from .common_voice_dataset import CommonVoiceDataset
 from .libri_speech_dataset import LibriSpeechDataset
 from .lj_speech_dataset import LJSpeechDataset

audio_data_pytorch/datasets/audio_web_dataset.py

Lines changed: 57 additions & 8 deletions
@@ -3,7 +3,7 @@
 import os
 import re
 import tarfile
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, Sequence, Union
 
 import torchaudio
 from torch import nn
@@ -18,6 +18,58 @@
 """
 
 
+class AudioProcess:
+    def __init__(
+        self,
+        path: str,
+        info: Dict,
+        sample_rate: Optional[int] = None,
+        transforms: Optional[Callable] = None,
+    ):
+        self.path = path
+        self.sample_rate = sample_rate
+        self.transforms = transforms
+        self.info = info
+        self.path_prefix = f"{os.path.splitext(self.path)[0]}_processed"
+        self.wav_dest_path = None
+        self.json_dest_path = None
+
+    def process_wav(self):
+        waveform, rate = torchaudio.load(self.path)
+
+        if exists(self.sample_rate):
+            resample = Resample(source=rate, target=self.sample_rate)
+            waveform = resample(waveform)
+            rate = self.sample_rate
+
+        if exists(self.transforms):
+            waveform = self.transforms(waveform)
+
+        wav_dest_path = f"{self.path_prefix}.wav"
+        print(wav_dest_path)
+        torchaudio.save(wav_dest_path, waveform, rate)
+
+        self.wav_dest_path = wav_dest_path
+        return wav_dest_path
+
+    def process_info(self):
+        json_dest_path = f"{self.path_prefix}.json"
+        with open(json_dest_path, "w") as f:
+            json.dump(self.info, f)
+
+        self.json_dest_path = json_dest_path
+        return json_dest_path
+
+    def __enter__(self):
+        wav_processed_path = self.process_wav()
+        json_processed_path = self.process_info()
+        return wav_processed_path, json_processed_path
+
+    def __exit__(self, *args):
+        os.remove(self.wav_dest_path)
+        os.remove(self.json_dest_path)
+
+
 class AudioWebDatasetPreprocess:
     def __init__(
         self,
@@ -50,12 +102,12 @@ def str_to_tags(self, str: str) -> List[str]:
 
     async def preprocess(self):
        urls, path = self.urls, self.root
-        tarfile_name = os.path.join(path, f"{self.name}.tar")
+        tarfile_name = os.path.join(path, f"{self.name}.tar.gz")
        waveform_id = 0
 
        async with Downloader(urls, path=path) as files:
            async with Decompressor(files, path=path) as folders:
-                with tarfile.open(tarfile_name, "w") as archive:
+                with tarfile.open(tarfile_name, "w:gz") as archive:
                    for folder in tqdm(folders):
                        for wav in tqdm(glob.glob(folder + "/**/*.wav")):
                            waveform, rate = torchaudio.load(wav)
@@ -112,16 +164,13 @@ class AudioWebDataset(WebDataset):
 
     def __init__(
        self,
-        path: Union[str, Sequence[str]],
+        urls: Union[str, Sequence[str]],
        transforms: Optional[Callable] = None,
        batch_size: Optional[int] = None,
-        recursive: bool = True,
        shuffle: int = 128,
        **kwargs,
     ):
-        paths = path if isinstance(path, (list, tuple)) else [path]
-        tars = get_all_tar_filenames(paths, recursive=recursive)
-        super().__init__(urls=tars, **kwargs)
+        super().__init__(urls=urls, **kwargs)
 
        (
            self.shuffle(shuffle)
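The new `AudioProcess` class is a context manager: it loads one audio file, optionally resamples and transforms it, writes temporary `<name>_processed.wav` / `.json` files, returns their paths, and deletes them on exit. A minimal sketch of how it is driven, mirroring `ClothoDataset.preprocess` below; the file names and metadata here are illustrative assumptions, not part of the diff.

```py
# Sketch only: archiving a single clip with the AudioProcess context manager.
# The temporary *_processed.wav/.json files it writes are removed on exit.
import tarfile

from audio_data_pytorch.datasets.audio_web_dataset import AudioProcess

info = {"text": ["a dog barks in the distance"]}  # example metadata

with tarfile.open("example.tar.gz", "w:gz") as archive:
    with AudioProcess(
        path="dog_bark.wav",    # hypothetical source file
        info=info,
        sample_rate=48000,      # optional resampling target
        transforms=None,        # optional preprocessing transforms
    ) as (wav_path, json_path):
        archive.add(wav_path, arcname="000000.wav")
        archive.add(json_path, arcname="000000.json")
```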
audio_data_pytorch/datasets/clotho_dataset.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import os
+import tarfile
+from typing import Callable, List, Optional
+
+import pandas as pd
+from tqdm import tqdm
+
+from ..utils import Decompressor, Downloader, camel_to_snake, run_async
+from .audio_web_dataset import AudioProcess, AudioWebDataset
+
+
+class ClothoDataset(AudioWebDataset):
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        preprocess_sample_rate: Optional[int] = None,
+        preprocess_transforms: Optional[Callable] = None,
+        reset: bool = False,
+        **kwargs,
+    ):
+        self.root = root
+        self.split = self.split_conversion(split)
+        self.preprocess_sample_rate = preprocess_sample_rate
+        self.preprocess_transforms = preprocess_transforms
+
+        if not os.path.exists(self.tar_file_name) or reset:
+            run_async(self.preprocess())
+
+        super().__init__(urls=self.tar_file_name, **kwargs)
+
+    def split_conversion(self, split: str) -> str:
+        return {"train": "development", "valid": "evaluation"}[split]
+
+    @property
+    def urls(self) -> List[str]:
+        return [
+            f"https://zenodo.org/record/4783391/files/clotho_audio_{self.split}.7z",
+            f"https://zenodo.org/record/4783391/files/clotho_captions_{self.split}.csv",
+        ]
+
+    @property
+    def data_path(self) -> str:
+        return os.path.join(self.root, camel_to_snake(self.__class__.__name__))
+
+    @property
+    def tar_file_name(self) -> str:
+        return os.path.join(self.data_path, f"clotho_{self.split}.tar.gz")
+
+    async def preprocess(self):
+        urls, path = self.urls, self.data_path
+        waveform_id = 0
+
+        async with Downloader(urls, path=path) as files:
+            to_decompress = [f for f in files if f.endswith(".7z")]
+            caption_csv_file = [f for f in files if f.endswith(".csv")][0]
+            async with Decompressor(to_decompress, path=path) as folders:
+                captions = pd.read_csv(caption_csv_file)
+                length = len(captions.index)
+
+                with tarfile.open(self.tar_file_name, "w:gz") as archive:
+                    for i, caption in tqdm(captions.iterrows(), total=length):
+                        wav_file_name = caption.file_name
+                        wav_path = os.path.join(folders[0], self.split, wav_file_name)
+                        wav_captions = [caption[f"caption_{i}"] for i in range(1, 6)]
+                        info = dict(text=wav_captions)
+
+                        with AudioProcess(
+                            path=wav_path,
+                            sample_rate=self.preprocess_sample_rate,
+                            transforms=self.preprocess_transforms,
+                            info=info,
+                        ) as (wav, json):
+                            archive.add(wav, arcname=f"{waveform_id:06d}.wav")
+                            archive.add(json, arcname=f"{waveform_id:06d}.json")
+
+                        waveform_id += 1
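Given `preprocess()` above, every archived `.wav` is paired with a `.json` whose `text` field holds the five Clotho captions for that clip. A sketch of loading one sample follows; it is not part of the commit, and the split, root path, and tensor shape are illustrative assumptions.

```py
# Sketch only: reading one preprocessed Clotho sample.
# 'valid' maps to the Clotho "evaluation" split via split_conversion().
from audio_data_pytorch import ClothoDataset

dataset = ClothoDataset(
    root="./data/",
    split="valid",
    preprocess_sample_rate=48000,
)

waveform, info = next(iter(dataset))
print(waveform.shape)  # e.g. torch.Size([1, N]) for a mono clip
print(info["text"])    # list of five caption strings from the captions CSV
```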

audio_data_pytorch/utils.py

Lines changed: 10 additions & 0 deletions
@@ -159,6 +159,10 @@ def is_zip(file_name: str) -> bool:
     return file_name.lower().endswith(".zip")
 
 
+def is_7zip(file_name: str) -> bool:
+    return file_name.lower().endswith(".7z")
+
+
 class Decompressor:
     def __init__(
        self,
@@ -192,6 +196,12 @@ def decompress(self, file_name: str):
        elif is_tar(file_name):
            with tarfile.open(file_name) as archive:
                self.extract_all(archive, path)
+        elif is_7zip(file_name):
+            import py7zr
+
+            print(f"{self.description}: {path}")
+            with py7zr.SevenZipFile(file_name, mode="r") as archive:
+                archive.extractall(path=path)
        else:
            raise ValueError(f"Unsupported file extension: {file_name}")
        return path
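The new branch lets `Decompressor` handle the Clotho `.7z` archives through `py7zr` (installed separately with `pip install py7zr`). A minimal sketch of exercising it on its own, not part of the commit; the archive name and destination path are assumptions.

```py
# Sketch only: decompressing a .7z archive with the updated Decompressor.
from audio_data_pytorch.utils import Decompressor, run_async


async def extract() -> None:
    async with Decompressor(["clotho_audio_evaluation.7z"], path="./data/") as folders:
        print(folders)  # paths of the extracted folders


run_async(extract())
```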

setup.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-data-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.12",
+    version="0.0.13",
     license="MIT",
     description="Audio Data - PyTorch",
     long_description_content_type="text/markdown",
@@ -18,6 +18,7 @@
        "requests",
        "tqdm",
        "aiohttp",
+        "webdataset",
     ],
     classifiers=[
        "Development Status :: 4 - Beta",
