Skip to content

Commit 9a013fd

Browse files
committed
Add file_name to the returned item in Snips dataset (#2775)
Summary: Pull Request resolved: #2775. Reviewed By: carolineechen. Differential Revision: D40481144. Pulled By: nateanl. fbshipit-source-id: 5d0fb2478767704603a3ec28d74160e7892d4d0e
1 parent 88a8dd4 commit 9a013fd

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

test/torchaudio_unittest/datasets/snips_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _get_mocked_samples(dataset_dir: str, subset: str, seed: int):
5555
transcript, iob, intent = f"{spk}XXX", f"{spk}YYY", f"{spk}ZZZ"
5656
label = "BOS " + transcript + " EOS\tO " + iob + " " + intent
5757
_save_label(label_path, wav_stem, label)
58-
samples.append((waveform, _SAMPLE_RATE, transcript, iob, intent))
58+
samples.append((waveform, _SAMPLE_RATE, wav_stem, transcript, iob, intent))
5959
return samples
6060

6161

@@ -100,12 +100,13 @@ def setUpClass(cls):
100100

101101
def _testSnips(self, dataset, data_samples):
102102
num_samples = 0
103-
for i, (data, sample_rate, transcript, iob, intent) in enumerate(dataset):
103+
for i, (data, sample_rate, file_name, transcript, iob, intent) in enumerate(dataset):
104104
self.assertEqual(data, data_samples[i][0])
105105
assert sample_rate == data_samples[i][1]
106-
assert transcript == data_samples[i][2]
107-
assert iob == data_samples[i][3]
108-
assert intent == data_samples[i][4]
106+
assert file_name == data_samples[i][2]
107+
assert transcript == data_samples[i][3]
108+
assert iob == data_samples[i][4]
109+
assert intent == data_samples[i][5]
109110
num_samples += 1
110111

111112
assert num_samples == len(data_samples)

torchaudio/datasets/snips.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
112112
Path to audio
113113
int:
114114
Sample rate
115+
str:
116+
File name
115117
str:
116118
Transcription of audio
117119
str:
@@ -123,7 +125,7 @@ def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
123125
relpath = os.path.relpath(audio_path, self._path)
124126
file_name = audio_path.with_suffix("").name
125127
transcript, iob, intent = self.labels[file_name]
126-
return relpath, _SAMPLE_RATE, transcript, iob, intent
128+
return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent
127129

128130
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]:
129131
"""Load the n-th sample from the dataset.
@@ -138,6 +140,8 @@ def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]:
138140
Waveform
139141
int:
140142
Sample rate
143+
str:
144+
File name
141145
str:
142146
Transcription of audio
143147
str:

0 commit comments

Comments
 (0)