diff --git a/mldata/dataset.py b/mldata/dataset.py index 914a3c1..42cfef2 100644 --- a/mldata/dataset.py +++ b/mldata/dataset.py @@ -1,17 +1,267 @@ -# -*- coding: utf-8 -*- +"""Datasets store the data used for experiments.""" +from itertools import accumulate +import hashlib +import numpy as np -class Dataset(list): - info = {} +BUFFER_SIZE = 1000 - def __init__(self, data=[]): - super(Dataset, self).__init__(data) +class Dataset(): + """Interface to interact with physical dataset -class LazyDataset(Dataset): - def __init__(self, lazy_functions): - super(LazyDataset, self).__init__() - self.lazy_functions = lazy_functions + A `Dataset` presents a unified access to data, independent of the + implementation details such as laziness. + + Parameters + ---------- + meta_data : Metadata + data : array_like + target : array_like + + Attributes + ---------- + meta_data : Metadata + Information about the data. See `MetaData` documentation for more info. + data : array_like + The array of data to train on. + target : array_like, optional + The array of target to use for supervised learning. `target` should + be `None` when the dataset doesn't support supervised learning. + + """ + def __init__(self, meta_data, data, target=None): + assert len(data) == meta_data.nb_examples,\ + "The metadata ``nb_examples`` is inconsistent with the length of "\ + "the dataset." + assert len(data) == meta_data.splits[-1] or\ + len(data) == sum(meta_data.splits),\ + "The metadata ``splits`` is inconsistent with the length of "\ + "the dataset." + self.data = data + self.target = target + self.meta_data = meta_data + + def __len__(self): + return self.meta_data.nb_examples + + def __hash__(self): + """ Hash function used for versioning.""" + hasher = hashlib.md5() + for l in self.data: + hasher.update(np.array(l)) + if self.target is not None: + for l in self.target: + hasher.update(np.array(l)) + return hasher.hexdigest()[:8] def __iter__(self): - return self.lazy_functions['__iter__']() + """Provide an iterator handling if the Dataset has a target.""" + #todo: retest efficiency of this buffering in python3. With zip being now lazy, it might not be better than the vanilla iter. + buffer = min(BUFFER_SIZE, len(self)) + + if self.target is not None: + for idx in range(0, len(self.data), buffer): + stop = min(idx + buffer, len(self)) + for ex, tg in zip(self.data[idx:stop], + self.target[idx:stop]): + yield (ex, tg) + else: + for idx in range(0, len(self.data), buffer): + stop = min(idx + buffer, len(self)) + for ex in self.data[idx:stop]: + yield (ex,) + + def __getitem__(self, key): + """Get the entry specified by the key. + + Parameters + ---------- + key : numpy-like key + The `key` can be a single integer, a slice or a tuple defining + coordinates. Can be treated as a NumPy key. + + Returns + ------- + (array_like, array_like) or (array_like,) + Return the element specified by the key. It can be an array or + simply a scalar of the type defined by the data [and target + arrays]. + The returned values are put in a tuple (data, target) or (data,). + + """ + if self.target is not None: + return self.data[key], self.target[key] + else: + return self.data[key], + + def _split_iterators(self, start, end, minibatch_size=1): + """ Iterate on a split. + + Parameters + ---------- + start : int + Id of the first element of the split. + end : int + Id of the next element after the last. 
+        minibatch_size : int
+            The number of examples yielded at each iteration. Default: 1.
+
+        """
+        buffer = min(BUFFER_SIZE, end - start)
+
+        if self.target is not None:
+            for idx in range(start, end, buffer):
+                stop = min(idx+buffer, end)
+                for i in range(idx, stop, minibatch_size):
+                    j = min(stop, i+minibatch_size)
+                    yield (self.data[i:j], self.target[i:j].reshape((1, -1)))
+        else:
+            for idx in range(start, end, buffer):
+                stop = min(idx+buffer, end)
+                for i in range(idx, stop, minibatch_size):
+                    j = min(stop, i+minibatch_size)
+                    yield (self.data[i:j],)
+
+    def get_splits_iterators(self, minibatch_size=1):
+        """ Create one iterator per split.
+
+        Each returned iterator iterates over the corresponding split. For
+        example, if the ``Metadata`` specifies ``splits`` of (10, 20, 30),
+        ``get_splits_iterators`` returns three iterators: one for the first
+        ten examples, another for the next ten and a third for the last ten.
+
+        Parameters
+        ----------
+        minibatch_size : int
+            The size of the minibatches yielded at each iteration.
+
+        Returns
+        -------
+        list of iterators
+            One iterator per split, in the order given by ``splits``.
+
+        """
+        sp = self._normalize_splits()
+
+        itors = [self._split_iterators(start, end, minibatch_size) for
+                 (start, end) in zip([0] + sp, sp)]
+        return itors
+
+    def get_splits(self):
+        """ Get the dataset's arrays, one per split.
+
+        WARNING: This method will try to load the entire dataset in memory.
+
+        Returns
+        -------
+        tuple of tuple of array
+            The data and targets sliced into multiple subarrays:
+            ``((data1, target1), (data2, target2), ...)``
+
+        """
+        sp = self._normalize_splits()
+        indices = zip([0]+sp, sp)
+
+        if self.target is not None:
+            return tuple((self.data[slice(*s)], self.target[slice(*s)])
+                         for s in indices)
+        else:
+            return tuple((self.data[slice(*s)],) for s in indices)
+
+    def apply(self):
+        """Apply the preprocess specified in the associated metadata.
+
+        This method simply applies the function given in the metadata (the
+        identity by default) to the dataset. This function is expected to
+        work on the data and the targets, leaving the rest intact. Still, as
+        long as the result is a `Dataset`, `apply` will work.
+
+        Returns
+        -------
+        Dataset
+            The preprocessed dataset.
+
+        """
+        ds = self.meta_data.preprocess(self)
+        assert isinstance(ds, Dataset)
+        return ds
+
+    def _normalize_splits(self):
+        """ Return the splits in their cumulative form."""
+        sp = list(self.meta_data.splits)
+
+        # Normalize the splits: if they are given as per-split counts,
+        # convert them to cumulative indices.
+        if sum(sp) == len(self):
+            sp = list(accumulate(sp))
+        assert sp[-1] == len(self), "The splits couldn't be normalized."
+
+        return sp
+
+
+class Metadata():
+    """Keep track of information about a dataset.
+
+    An instance of this class is required to build a `Dataset`. It stores
+    information such as the dataset's name, its splits, etc.
+
+    A single `Dataset` can have multiple metadata files specifying different
+    splits or a special pre-processing that needs to be applied. The
+    philosophy is to have a single physical copy of the dataset with
+    different views that can be created on the fly as needed.
+
+    Attributes
+    ----------
+    name : str
+        The name of the `Dataset`. Default: "Default".
+    nb_examples : int
+        The number of examples in the dataset (including all splits).
+        Default: 0.
+    dictionary : Dictionary
+        _Not yet implemented_
+        Gives a mapping of words (str) to ids (int). Used only when the
+        dataset has been saved as an array of numbers instead of text.
+        Default: None.
+    splits : tuple of int
+        Specifies the splits used by this view of the dataset. Default: ().
+ The numbers can be either the number of the last examples in each + subsets or the number of examples in each categories. + preprocess : function or None + A function that is callable on a `Dataset` to preprocess the data. + The function cannot be a lambda function since those can't be pickled. + Default: identity function. + hash : str + The hash of the linked ``Dataset``. Default: "". + + """ + def __init__(self): + self.name = "Default" + self.nb_examples = 0 + self.dictionary = None + self.splits = () + self.preprocess = default_preprocess + self.hash = "" + + +def default_preprocess(dset): + return dset + + +class Dictionary: + """Word / integer association list + + This dictionary is used in `Metadata` for NLP problems. This class + ensures O(1) conversion from id to word and O(log n) conversion from word to + id. + + Notes + ----- + The class is *not yet implemented*. + + Plans are for the dictionary to be implemented as a list of words + alphabetically ordered with the index of the word being its id. A method + implements a binary search over the words in order to retrieve its id. + + """ + + def __init__(self): + raise NotImplementedError("The class Dictionary is not yet " + "implemented.") diff --git a/mldata/dataset_store.py b/mldata/dataset_store.py index 17d7dc8..e6abfa5 100644 --- a/mldata/dataset_store.py +++ b/mldata/dataset_store.py @@ -1,48 +1,271 @@ - +""" Manages dataset read/write operations.""" +#todo: Remove precise versions of datasets and manage dependencies. import os +import pickle as pk + import h5py import numpy as np -import itertools -import types -import mldata -import mldata.utils -from mldata.dataset import Dataset, LazyDataset +from SMARTdata.mldata.utils import config as cfg +from SMARTdata.mldata.dataset import Dataset, Metadata + + +def load(dset_name, version_name="baseDataset", lazy=False): + """ Load a dataset given its name. + + The load function will load the ``Dataset`` ``name`` provided it exists in + one of the datasets folders. This function allows reading of files which + are bigger than available memory using ``h5py``. + + Parameters + ---------- + name : str + The name of the dataset to load. The first match from the list of + dataset folder will be used, thus allowing private copy of a dataset. + version_name : str + If this is a special version of a dataset, use this name to indicate + it. Default: "baseDataset". + lazy : bool + If set to ``True``, the dataset will be read with ``h5py`` without + loading the whole dataset in memory. If set to ``False``, the file is + mapped in memory. Default: False. + + Returns + ------- + Dataset + Return the loaded dataset, if it exists. Else, return ``None``. + + Raises + ------ + LookupError + If the dataset ``dset_name`` does not exist, a ``LookupError`` is + raised. + + """ + path = None + if cfg.dataset_exists(dset_name): + path = cfg.get_dataset_path(dset_name) + else: + raise LookupError("This dataset does not exist.") + return _load_from_file(dset_name + '_' + version_name, path, lazy) + + +def save(dataset, version_name="baseDataset"): + """ Save the dataset, manages versions. -from mldata.utils.constants import DATASETS_FOLDER -from mldata.utils.utils import buffered_iter + A ``Dataset`` is saved according to its name and the ``version_name`` + provided. The ``version_name`` is used to denote different view of the + data, either using the ``preprocess`` field of a ``Metadata`` class or by + saving a new version of the dataset (with a different hash). 
+    The first method is the most compact, while the second is more efficient
+    when loading a dataset.
+
-def supervised_factory(examples, targets):
-    def lazy_iter():
-        for e, t in itertools.izip(buffered_iter(examples), buffered_iter(targets)):
-            yield e, t
+    To save a dataset using the preprocessing method, the dataset *must not*
+    contain the preprocessed data, but the original dataset on which the
+    preprocess is applied.
+
-    lazy_functions = {
-        '__iter__': lazy_iter,
-    }
+    The dataset is split across two files:
+
-    return lazy_functions
+    - the data file ``[hash].data``
+    - the metadata file ``[dataset name]_[dataset version].meta``
+
-def load(path_or_name, lazy=False):
-    path = path_or_name
-    if not os.path.isfile(path):
-        if (path_or_name + ".h5") not in os.listdir(DATASETS_FOLDER):
-            print "Unknown dataset: '{0}'".format(path_or_name)
-            return
+    This function will replace a metadata file with the same version name
+    without prompting.
+
-        path = os.path.join(DATASETS_FOLDER, path_or_name + ".h5")
+    Parameters
+    ----------
+    dataset : Dataset
+        The dataset to be saved.
+    version_name : str
+        If this is a special version of a dataset, use this name to indicate
+        it. Default: "baseDataset".
+
-    return _load_from_file(path, lazy)
+    """
+    dset_name = dataset.meta_data.name
+
-def _load_from_file(path, lazy=False):
-    dataset = Dataset()
-
-    if not lazy:
-        with h5py.File(path, mode='r') as f:
-            dataset = Dataset(itertools.izip(f['input'][()], f['output'][()]))
+    if not cfg.dataset_exists(dset_name):
+        cfg.add_dataset(dset_name)
+
+    dset_path = cfg.get_dataset_path(dset_name)
+
+    dset_hash = dataset.__hash__()
+    dataset.meta_data.hash = dset_hash  # ensures the metadata hash is up to date
+
+    dset_file = dataset.meta_data.hash + ".data"
+    _save_dataset(dataset, dset_path, dset_file)
+
+    meta_file = dset_name + '_' + version_name + ".meta"
+    _save_metadata(dataset.meta_data, dset_path, meta_file)
+
+
+def _load_from_file(name, path, lazy):
+    """ Call ``h5py`` to load the file."""
+    metadata = None
+    try:
+        with open(os.path.join(path, name) + '.meta', 'rb') as f:
+            metadata = pk.load(f)
+    except FileNotFoundError:
+        raise LookupError("This dataset/version pair does not exist: " + name)
+
+    datasetFile = None
+    file_to_load = os.path.join(path, metadata.hash + ".data")
+    if lazy:
+        datasetFile = h5py.File(file_to_load, mode='r', driver=None)
     else:
-        f = h5py.File(path, mode='r')
-        lazy_functions = supervised_factory(f['input'], f['output'])
-        dataset = LazyDataset(lazy_functions)
+        datasetFile = h5py.File(file_to_load, mode='r', driver='core')
+
+    data = datasetFile['/']["data"]
+    target = None
+    try:
+        target = datasetFile['/']["targets"]
+    except KeyError:
+        pass
+    dset = Dataset(metadata, data, target)
+    dset._fileHandle = h5pyFileWrapper(datasetFile)
+    return dset
+
+
+def _save_dataset(dataset, path, filename):
+    """ Call ``h5py`` to write the dataset.
+
+    Save the data (and targets, if any) into the dataset folder. The file is
+    only written if it does not already exist.
+
+    Parameters
+    ----------
+    dataset : Dataset
+    path : str
+    filename : str
+
+    """
+    if filename not in os.listdir(path):
+        fullname = os.path.join(path, filename)
+        with h5py.File(fullname, mode='w') as f:
+            f.create_dataset('data', data=dataset.data)
+            if dataset.target is not None:
+                f.create_dataset('targets', data=dataset.target)
+
+
+def _save_metadata(metadata, path, filename):
+    """ Pickle the metadata.
+
+    Parameters
+    ----------
+    metadata : Metadata
+    path : str
+    filename : str
+
+    """
+    #todo: A dataset could be orphaned if overwritten by another metadata file. This needs to be checked in a future version.
+    if filename not in os.listdir(path):
+        with open(os.path.join(path, filename), 'wb') as f:
+            pk.dump(metadata, f, pk.HIGHEST_PROTOCOL)
+
+
+def CSV_importer(filepath,
+                 name,
+                 splits,
+                 target_column=None,
+                 dtype=np.float64,
+                 comments='#',
+                 delimiter=' ',
+                 converters=None,
+                 skiprows=0,
+                 usecols=None):
+    """ Import a CSV-like text file into a ``Dataset``.
+
+    From the ``filepath`` of a delimited text file, create a ``Dataset``
+    which can then be saved on disk. This importer supports only numeric
+    inputs (int, float, boolean values).
+
+    Parameters
+    ----------
+    filepath : str
+        The path of the file to be imported.
+    name : str
+        The name used to store the ``Dataset`` on disk.
+    splits : tuple of int
+        Gives the splits of the dataset, like (train, valid, test). Each
+        integer is the id of the last example of a sub-dataset plus one.
+        For example, if there are 8000 examples with 5000 in the training set,
+        2000 in the validation set and 1000 in the test set, the splits would
+        be ``(5000, 7000, 8000)``.
+        An alternative form where each number is the count of examples in the
+        corresponding subset is also supported.
+    target_column : int, optional
+        The column number of the target. If no target is provided, set to
+        ``None``. Default: None.
+    dtype : data-type, optional
+        Data-type of the resulting array; default: float. If this is a record
+        data-type, the resulting array will be 1-dimensional, and each row will
+        be interpreted as an element of the array. In this case, the number of
+        columns used must match the number of fields in the data-type.
+    comments : str, optional
+        The character used to indicate the start of a comment; default: '#'.
+    delimiter : str, optional
+        The string used to separate values. Default: ' ' (a single space).
+    converters : dict, optional
+        A dictionary mapping column number to a function that will convert that
+        column to a float. E.g., if column 0 is a date string:
+        ``converters = {0: datestr2num}``. Converters can also be used to
+        provide a default value for missing data:
+        ``converters = {3: lambda s: float(s.strip() or 0)}``. Default: None.
+    skiprows : int, optional
+        Skip the first ``skiprows`` lines; default: 0.
+    usecols : sequence, optional
+        Which columns to read, with 0 being the first. For example,
+        ``usecols = (1, 4, 5)`` will extract the 2nd, 5th and 6th columns. The
+        default, None, results in all columns being read.
+
+    Returns
+    -------
+    Dataset
+        A ``Dataset`` with default values for ``Metadata``.
+
+    """
+    data = np.loadtxt(filepath, dtype, comments, delimiter,
+                      converters, skiprows, usecols)
+
+    meta = Metadata()
+    meta.name = name
+    meta.splits = splits
+    assert len(data) == splits[-1] or \
+        len(data) == sum(splits),\
+        "The dataset read is not consistent with the splits given."
+    meta.nb_examples = len(data)
+
+    dset = None
+    if target_column is not None:
+        targets = data[:, target_column].reshape((-1, 1))
+        examples = data[:, list(range(0, target_column)) +
+                        list(range(target_column+1, data.shape[1]))]
+        dset = Dataset(meta, examples, targets)
+    else:
+        dset = Dataset(meta, data)
+
+    dset.meta_data.hash = dset.__hash__()
+
+    return dset
+
+
+def remove(name):
+    """ Remove a dataset from the datasets folder.
+
+    Parameters
+    ----------
+    name : str
+        Name of the dataset to delete.
+
+    """
+    cfg.remove_dataset(name)
+
+
+class h5pyFileWrapper:
+    """ Close the wrapped h5py file handle when a ``Dataset`` is destroyed."""
+    def __init__(self, file):
+        self.file = file

-    return dataset
+    def __del__(self):
+        self.file.close()
diff --git a/mldata/utils/config.py b/mldata/utils/config.py
new file mode 100644
index 0000000..c897e7c
--- /dev/null
+++ b/mldata/utils/config.py
@@ -0,0 +1,106 @@
+""" Manages the configuration file for MLData."""
+import configparser
+import os
+from os.path import expanduser, join
+
+from shutil import rmtree
+
+CONFIGFILE = join(expanduser("~"), '.mldataConfig')
+
+def add_dataset(dataset_name):
+    """ Add a dataset to the index."""
+    path = os.path.join(_load_path(), dataset_name)
+
+    if not os.path.isdir(path):
+        os.mkdir(path)
+
+    cp = _load_config()
+    cp['datasets'][dataset_name] = path
+    _save_config(cp)
+
+def remove_dataset(dataset_name):
+    """ Remove a dataset from the index."""
+    path = os.path.join(_load_path(), dataset_name)
+
+    if os.path.isdir(path):  # Does the path exist?
+        rmtree(path, ignore_errors=True)
+
+    cp = _load_config()
+    cp.remove_option('datasets', dataset_name)
+    _save_config(cp)
+
+def get_dataset_path(dataset_name):
+    """ Retrieve the dataset path.
+
+    Parameters
+    ----------
+    dataset_name : str
+        Name of a dataset.
+
+    Returns
+    -------
+    str
+        The path where ``dataset_name`` is saved.
+
+    Raises
+    ------
+    KeyError
+        If the path specified in the config file does not exist on the system.
+
+    """
+    cp = _load_config()
+    path = cp['datasets'][dataset_name]
+    if not os.path.isdir(path):
+        raise KeyError("Wrong path in .mldataConfig.")
+    else:
+        return path
+
+def dataset_exists(dataset_name):
+    """ Check if the dataset exists."""
+    return _load_config().has_option('datasets', dataset_name)
+
+def _save_config(config):
+    """ Save a config file at the default config file location."""
+    with open(CONFIGFILE, 'w') as f:
+        config.write(f)
+
+
+def _load_config():
+    """ Load the configuration file for MLData, creating a default one if needed."""
+    if not os.path.exists(CONFIGFILE):
+        _create_default_config()
+    cfg = configparser.ConfigParser()
+    cfg.read(CONFIGFILE)
+    return cfg
+
+def _create_default_config():
+    """ Build and save a default config file for MLData.
+
+    The default config is saved as ``.mldataConfig`` in the ``$HOME`` folder
+    or its equivalent. It stores the location of dataset files and keeps an
+    index of accessible datasets.
+
+    """
+    cp = configparser.ConfigParser()
+    path = join(expanduser("~"), '.datasets')
+    if not os.path.isdir(path):
+        os.mkdir(path)
+    cp['config'] = {'path': path}
+    cp['datasets'] = {}
+    _save_config(cp)
+    with open(CONFIGFILE, 'a') as f:
+        f.write("# Datasets path shouldn't be changed manually.\n")
+
+def _load_path():
+    """ Load the datasets root path from the config file.
+
+    Returns
+    -------
+    str
+        The path of the folder where the dataset folders are stored.
+
+    """
+    cp = _load_config()
+    path = cp['config']['path']
+    assert os.path.isdir(path), "Configured path is not a valid directory."
+ return path diff --git a/mldata/utils/utils.py b/mldata/utils/utils.py deleted file mode 100644 index 662a006..0000000 --- a/mldata/utils/utils.py +++ /dev/null @@ -1,6 +0,0 @@ - - -def buffered_iter(arr, buffer_size=1000): - for idx in xrange(0, len(arr), buffer_size): - for e in arr[idx:idx+buffer_size]: - yield e \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..b4c3a7d --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,97 @@ +import copy +from itertools import chain + +import numpy as np +import nose.tools as nt + +from SMARTdata.mldata.dataset import Dataset, Metadata + + +class Dataset_test: + @classmethod + def setup_class(self): + self.dataSmall = np.random.random((30, 5)) + self.dataLarge = np.random.random((3000, 5)) + self.targetSmall = np.random.random((30, 1)) + self.targetLarge = np.random.random((3000, 1)) + + self.metadataS = Metadata() + self.metadataS.splits = (10, 20, 30) + self.metadataS.nb_examples = 30 + self.metadataL = Metadata() + self.metadataL.splits = (1000, 1000, 1000) + self.metadataL.nb_examples = 3000 + self.dsetS = Dataset(self.metadataS, self.dataSmall, self.targetSmall) + self.dsetL = Dataset(self.metadataL, self.dataLarge, self.targetLarge) + + def test_Dataset(self): + dset = Dataset(self.metadataS, self.dataSmall) + nt.assert_equal(dset.meta_data, self.metadataS) + nt.assert_true(np.array_equal(dset.data, self.dataSmall)) + nt.assert_is_none(dset.target) + + dsetS = Dataset(self.metadataS, self.dataSmall, self.targetSmall) + nt.assert_is_not_none(dsetS.target) + + def test_hash(self): + nt.assert_equal(self.dsetS.__hash__(), self.dsetS.__hash__()) + + dset = Dataset(self.metadataS, self.dataSmall, self.targetSmall) + nt.assert_equal(dset.__hash__(), self.dsetS.__hash__()) + + dset2 = Dataset(self.metadataS, self.dataSmall) + nt.assert_not_equal(dset2.__hash__(), dset.__hash__()) + + dset3 = Dataset(self.metadataL, self.dataLarge) + nt.assert_not_equal(dset2.__hash__(), dset3.__hash__()) + nt.assert_not_equal(dset3.__hash__(), dset.__hash__()) + + meta = Metadata() + meta.name = "AnotherName" + meta.splits = (10, 10, 10) # alternative split form + meta.nb_examples = 30 + dset4 = Dataset(meta, self.dataSmall) + nt.assert_equal(dset4.__hash__(), dset2.__hash__()) + nt.assert_not_equal(dset4.__hash__(), dset3.__hash__()) + + def test_len(self): + nt.assert_equal(len(self.dsetS), len(self.dsetS.data)) + nt.assert_equal(len(self.dsetS), self.dsetS.meta_data.nb_examples) + + def test_preprocess(self): + data2 = self.dataSmall * 2 + meta = copy.deepcopy(self.metadataS) + meta.preprocess = double_dset + dset2 = Dataset(meta, self.dataSmall, self.targetSmall) + dset2 = dset2.apply() + nt.assert_true(np.array_equal(data2, dset2.data)) + + def test_iter(self): + # With targets + dt, tg = [[z[i] for z in self.dsetS] for i in [0, 1]] + nt.assert_true(np.array_equal(np.array(dt), self.dataSmall)) + # Without targets + dset = Dataset(self.metadataS, self.dataSmall) + nt.assert_true(np.array_equal(np.array([z[0] for z in dset]), + self.dataSmall)) + + def test_get(self): + for i in range(len(self.dataSmall)): + nt.assert_true(np.array_equal(self.dataSmall[i], self.dsetS[i][0])) + + def test_get_splits_iterators(self): + citer = chain.from_iterable(self.dsetS.get_splits_iterators()) + for a, b in zip(citer, self.dsetS): + d1 = a[0] + d2 = [b[0]] + nt.assert_true(np.array_equal(d1,d2)) + + sp = 
self.dsetL.meta_data.splits + for splitn, it in zip(sp, self.dsetL.get_splits_iterators()): + nt.assert_equal(sum(1 for _ in it), splitn) + + +def double_dset(dset): + """ Basic preprocessing function. """ + return Dataset(dset.meta_data, dset.data * 2, dset.target * 2) + diff --git a/tests/test_dataset_store.py b/tests/test_dataset_store.py index 17e3c50..6fec26f 100644 --- a/tests/test_dataset_store.py +++ b/tests/test_dataset_store.py @@ -1,61 +1,71 @@ -from ipdb import set_trace as dbg - import os -import tempfile -import hashlib + import numpy as np -import itertools -import time -from functools import partial - -from numpy.testing import (assert_equal, - assert_almost_equal, - assert_array_equal, - assert_array_almost_equal, - assert_raises) - - -import mldata -import mldata.dataset_store as dataset_store - -DATA_DIR = os.path.join(os.path.realpath(mldata.__path__[0]), "..", "tests", "data") - -def load_mnist(lazy): - """ - Load mnist dataset from a hdf5 file and test if it matches mlpython's one. - """ - dataset_name = 'mnist' - - start = time.time() - import mlpython.datasets.store as mlstore - mldatasets = mlstore.get_classification_problem(dataset_name, load_to_memory= (not lazy)) - print "mlpython version loaded ({0:.2f}sec).".format(time.time() - start) - - start = time.time() - dataset_name = os.path.join(os.environ['MLPYTHON_DATASET_REPO'], dataset_name + ".h5") - dataset = mldata.dataset_store.load(dataset_name, lazy=lazy) - print "mldata version loaded ({0:.2f}sec).".format(time.time() - start) - - print "Comparing first 1000..." - count = 0 - for (e1, t1), (e2, t2) in itertools.izip(dataset, itertools.chain(*mldatasets)): - #print t1, t2 - assert_array_almost_equal(e1, e2) - assert_equal(t1, t2) - - count += 1 - if count >= 1000: - break - - -def test_load_mnist(): - """ - Load mnist dataset from a hdf5 file and test if it matches mlpython's one. - """ - load_mnist(lazy=False) - -def test_load_mnist_lazy(): - """ - Lazy load mnist dataset from a hdf5 file and test if it matches mlpython's one. 
- """ - load_mnist(lazy=True) +import nose.tools as nt + +import SMARTdata.mldata.dataset_store as ds + +RND_MATRIX = np.random.random((100, 10)) + + +def setup_module(): + np.savetxt("test.csv", RND_MATRIX) + + +def teardown_module(): + os.remove("test.csv") + ds.remove("test_dset") + + +def test_CSV_importer(): + dset = ds.CSV_importer("test.csv", + "test_dset", + (70, 90, 100), + 0) + + nt.assert_true(np.array_equal(RND_MATRIX[:, 1:], dset.data)) + + +def test_save_load(): + dset = ds.CSV_importer("test.csv", + "test_dset", + (70, 90, 100), + 0) + dset_nt = ds.CSV_importer("test.csv", + "test_dset", + (70, 90, 100)) + ds.save(dset, "v1") + ds.save(dset_nt, "noTarget") + dset2 = ds.load("test_dset", "v1") + dset_nt2 = ds.load("test_dset", "noTarget") + + nt.assert_equal(dset.__hash__(), dset2.__hash__()) + nt.assert_equal(dset.meta_data.name, dset2.meta_data.name) + nt.assert_equal(dset.meta_data.dictionary, dset2.meta_data.dictionary) + nt.assert_equal(dset.meta_data.nb_examples, dset2.meta_data.nb_examples) + nt.assert_equal(dset.meta_data.splits, dset2.meta_data.splits) + nt.assert_equal(dset2.meta_data.hash, dset2.__hash__()) + + ndata = np.array(dset2.data) + dset2.data = ndata * 2 + + ds.save(dset2, version_name="v2") + dset3 = ds.load("test_dset", "v2",lazy=True) + + nt.assert_not_equal(dset3.__hash__(), dset.__hash__()) + nt.assert_equal(dset3.meta_data.hash, dset3.__hash__()) + nt.assert_equal(dset.meta_data.name, dset3.meta_data.name) + nt.assert_equal(dset.meta_data.dictionary, dset3.meta_data.dictionary) + nt.assert_equal(dset.meta_data.nb_examples, dset3.meta_data.nb_examples) + nt.assert_equal(dset.meta_data.splits, dset3.meta_data.splits) + + # handle missing datasets + with nt.assert_raises(LookupError): + ds.load("inexistant_dataset") + + with nt.assert_raises(LookupError): + ds.load("test_dset", "v3") + + nt.assert_is_none(dset_nt2.target) + + diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py new file mode 100644 index 0000000..f8fffa2 --- /dev/null +++ b/tests/utils/test_config.py @@ -0,0 +1,44 @@ +import os + +import nose.tools as nt + +from SMARTdata.mldata.utils import config as cfg + + +def setup_module(): + # save current config file + if os.path.isfile(cfg.CONFIGFILE): + os.rename(cfg.CONFIGFILE, cfg.CONFIGFILE + ".bak") + + +def teardown_module(): + # restore config file + os.rename(cfg.CONFIGFILE + ".bak", cfg.CONFIGFILE) + + +def test_load_config(): + cf = cfg._load_config() + path = os.path.join(os.path.expanduser("~"), '.datasets') + + nt.assert_equal(path, cf['config']['path']) + nt.assert_equal(path, cfg._load_path()) + nt.assert_true(cf.has_section('datasets')) + + +def test_add_remove(): + cfg.add_dataset("test_dataset") + nt.assert_true(cfg.dataset_exists("test_dataset")) + + nt.assert_equal(cfg.get_dataset_path("test_dataset"), + os.path.join(cfg._load_path(), "test_dataset")) + path = cfg.get_dataset_path("test_dataset") + nt.assert_true(os.path.isdir(path)) + + cfg.remove_dataset("test_dataset") + nt.assert_false(cfg.dataset_exists("test_dataset")) + nt.assert_false(os.path.isdir(path)) + + + + +
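
Reviewer note: a minimal usage sketch of the workflow this diff introduces (import, version, save, reload, iterate), mirroring tests/test_dataset_store.py. The import path follows the one used in the tests, and the file and dataset names ("example.csv", "example_dset", "v1") are illustrative only.

import numpy as np

import SMARTdata.mldata.dataset_store as ds

# Create a small whitespace-delimited file to import; column 0 is the target.
np.savetxt("example.csv", np.random.random((100, 10)))

# Import the file, then persist it under a version name.
dset = ds.CSV_importer("example.csv", "example_dset", splits=(70, 90, 100),
                       target_column=0)
ds.save(dset, version_name="v1")

# Reload it lazily (h5py-backed) and iterate over the train/valid/test splits.
reloaded = ds.load("example_dset", "v1", lazy=True)
train_it, valid_it, test_it = reloaded.get_splits_iterators(minibatch_size=10)
for data_batch, target_batch in train_it:
    pass  # feed the minibatches to a model here

# Clean up the dataset folder registered in ~/.mldataConfig.
ds.remove("example_dset")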