diff --git a/textblob/classifiers.py b/textblob/classifiers.py index 782bbebc..742e837c 100644 --- a/textblob/classifiers.py +++ b/textblob/classifiers.py @@ -76,9 +76,22 @@ def basic_extractor(document, train_set): :param document: The text to extract features from. Can be a string or an iterable. :param list train_set: Training data set, a list of tuples of the form - ``(words, label)``. + ``(words, label)`` OR an iterable of strings. """ - word_features = _get_words_from_dataset(train_set) + + try: + el_zero = next(iter(train_set)) #Infer input from first element. + except StopIteration: + return {} + if isinstance(el_zero, basestring): + word_features = [w for w in chain([el_zero],train_set)] + else: + try: + assert(isinstance(el_zero[0], basestring)) + word_features = _get_words_from_dataset(chain([el_zero],train_set)) + except: + raise ValueError('train_set is proabably malformed.') + tokens = _get_document_tokens(document) features = dict(((u'contains({0})'.format(word), (word in tokens)) for word in word_features)) @@ -123,6 +136,7 @@ def __init__(self, train_set, feature_extractor=basic_extractor, format=None, ** self.train_set = self._read_data(train_set, format) else: # train_set is a list of tuples self.train_set = train_set + self._word_set = _get_words_from_dataset(self.train_set) #Keep a hidden set of unique words. self.train_features = None def _read_data(self, dataset, format=None): @@ -166,7 +180,7 @@ def extract_features(self, text): ''' # Feature extractor may take one or two arguments try: - return self.feature_extractor(text, self.train_set) + return self.feature_extractor(text, self._word_set) except (TypeError, AttributeError): return self.feature_extractor(text) @@ -260,6 +274,7 @@ def update(self, new_data, *args, **kwargs): ``(text, label)``. """ self.train_set += new_data + self._word_set.update(_get_words_from_dataset(new_data)) self.train_features = [(self.extract_features(d), c) for d, c in self.train_set] try: