diff --git a/hare/main.py b/hare/main.py index e8b8171..4f91321 100644 --- a/hare/main.py +++ b/hare/main.py @@ -6,6 +6,7 @@ from os.path import dirname, abspath, isdir, realpath from urllib.request import urlretrieve from zipfile import ZipFile +import numpy as np from hare.brain import AbstractBrain from hare.conversation import Conversation @@ -114,7 +115,7 @@ def train(self): texts.append(' LINEBREAK '.join(conversation.get_all_utterances_for_speaker(speaker))) target.append(label) - self.brain.train(texts,target) + self.brain.train(np.array(texts),np.array(target)) def save(self, location : str): self.brain.save(location) @@ -414,4 +415,4 @@ def load_example_conversations() -> List[Conversation]: speaker, content = line.split('\t') current_conversation.add_utterance(speaker, content) - return conversations \ No newline at end of file + return conversations