diff --git a/NgramModel.py b/NgramModel.py index 05217ef..8fc88c1 100644 --- a/NgramModel.py +++ b/NgramModel.py @@ -21,9 +21,11 @@ def __init__(self, n: int, name="", auto_pickle=False): """ self.n = n self.pickle_path = self.pathify(name or self.gen_pickle_name()) - if path.exists(self.pickle_path): # Ask if they intend to overwrite existing pickle - if input(f"overwrite {self.pickle_path}? (Y/N)\n").upper() != "Y": - self.pickle_path = self.pathify(self.gen_pickle_name()) + if ( + path.exists(self.pickle_path) + and input(f"overwrite {self.pickle_path}? (Y/N)\n").upper() != "Y" + ): + self.pickle_path = self.pathify(self.gen_pickle_name()) self.context_options: dict[tuple, Counter[str]] = defaultdict(Counter) # dict [context, Counter of possible tokens] self.num_tweets = 0 @@ -44,12 +46,10 @@ def generate_Ngrams(self, string: str): words = string.split(" ") words = [self.start] * (self.n - 1) + words + [self.end] * (self.n - 1) - list_of_tup = [] - - for i in range(len(words) + 1 - self.n): - list_of_tup.append((tuple(words[i + j] for j in range(self.n - 1)), words[i + self.n - 1])) - - return list_of_tup + return [ + (tuple(words[i + j] for j in range(self.n - 1)), words[i + self.n - 1]) + for i in range(len(words) + 1 - self.n) + ] def backup(self): os.makedirs("models", exist_ok=True) @@ -89,7 +89,7 @@ def get_word_prob(self, context: tuple, token: str): # return self.ngram_count[(context, token)] / context_freq def calculate_freq(self, context: tuple): - freq = sum(freq for freq in self.context_options[context].values()) + freq = sum(self.context_options[context].values()) self.context_freq_cache[self, context] = freq return freq diff --git a/browser.py b/browser.py index 6625360..27af955 100644 --- a/browser.py +++ b/browser.py @@ -30,5 +30,6 @@ def get_pin(url): time.sleep(5) - pin = driver.find_element(By.CSS_SELECTOR, "kbd > code").get_attribute("innerText") - return pin + return driver.find_element(By.CSS_SELECTOR, "kbd > code").get_attribute( + "innerText" + ) diff --git a/main.py b/main.py index 296cd88..fd92439 100644 --- a/main.py +++ b/main.py @@ -54,8 +54,7 @@ def log_and_backup(): if __name__ == "__main__": for i in signal.valid_signals(): - if (i == signal.SIGKILL or - i == signal.SIGSTOP): + if i in [signal.SIGKILL, signal.SIGSTOP]: continue signal.signal(i, exit_gracefully) diff --git a/twitter.py b/twitter.py index ae6d7cc..ba66deb 100644 --- a/twitter.py +++ b/twitter.py @@ -73,7 +73,7 @@ def get_tweets(self, keyword: str) -> list[str]: :return: a list of tweets returned as strings """ search_param["query"] = keyword - to_return = dict() + to_return = {} # noinspection PyBroadException try: to_return = requests.get(Twitter.__search_url, search_param, auth=self.bearer_oauth).json()