From 1cc6831a75816e4cfd1f8303db332988e33ac8ed Mon Sep 17 00:00:00 2001 From: Sharan Sairamesh Date: Tue, 10 Jun 2025 08:52:18 -0700 Subject: [PATCH 1/2] Created new context strategy function with base64 function (#265) Summary: Pull Request resolved: https://github.com/facebookresearch/AugLy/pull/265 Created a context interface, and strategy parent class for base64 function encoding and future types of encoding. Lot of code is reused from diffs D75792134 and D75792239. Differential Revision: D76028127 --- .../text_tests/expected_metadata.json | 14 +-- .../tests/text_tests/functional_unit_test.py | 37 +++---- .../tests/text_tests/transforms_unit_test.py | 65 +++++++----- augly/text/__init__.py | 14 +-- augly/text/augmenters/__init__.py | 9 +- augly/text/augmenters/base64.py | 36 +++++++ augly/text/augmenters/encode_base64.py | 86 ---------------- augly/text/augmenters/encode_text_context.py | 21 ++++ augly/text/augmenters/encode_text_strategy.py | 99 +++++++++++++++++++ augly/text/augmenters/utils.py | 5 + augly/text/functional.py | 40 +++++--- augly/text/intensity.py | 21 +++- augly/text/transforms.py | 21 ++-- 13 files changed, 292 insertions(+), 176 deletions(-) create mode 100644 augly/text/augmenters/base64.py delete mode 100644 augly/text/augmenters/encode_base64.py create mode 100644 augly/text/augmenters/encode_text_context.py create mode 100644 augly/text/augmenters/encode_text_strategy.py diff --git a/augly/tests/assets/expected_metadata/text_tests/expected_metadata.json b/augly/tests/assets/expected_metadata/text_tests/expected_metadata.json index 7cc600be..76cc1461 100644 --- a/augly/tests/assets/expected_metadata/text_tests/expected_metadata.json +++ b/augly/tests/assets/expected_metadata/text_tests/expected_metadata.json @@ -82,18 +82,20 @@ "src_length": 1 } ], - "encode_base64": [ + "encode_text": [ { "dst_length": 1, "input_type": "list", "intensity": 100.0, - "name": "encode_base64", + "name": "encode_text", "src_length": 1, - "granularity": "all", "aug_min": 1, - "aug_max": 10, - "aug_p": 0.3, - "n": 1 + "aug_max": 1, + "aug_p": 1.0, + "n": 1, + "p": 1.0, + "encoder": "base64", + "method": "sentence" } ], "get_baseline": [ diff --git a/augly/tests/text_tests/functional_unit_test.py b/augly/tests/text_tests/functional_unit_test.py index 5b75a6bd..f45ffe50 100644 --- a/augly/tests/text_tests/functional_unit_test.py +++ b/augly/tests/text_tests/functional_unit_test.py @@ -10,7 +10,9 @@ import unittest from augly import text as txtaugs +from augly.text.augmenters.utils import Encoding from augly.utils import FUN_FONTS_GREEK_PATH +from nlpaug.util import Method class FunctionalTextUnitTest(unittest.TestCase): @@ -51,34 +53,23 @@ def test_contractions(self) -> None: augmented_words[0] == "I would call him but I don't know where he's gone" ) - def test_encode_base64_all(self) -> None: - augmented_words = txtaugs.encode_base64("Hello, world!") - self.assertTrue(augmented_words[0] == "SGVsbG8sIHdvcmxkIQ==") - - def test_encode_base64_word(self) -> None: - random.seed(42) # Set seed for reproducibility - augmented_words_word = txtaugs.encode_base64( - "Hello, world!", granularity="word", aug_min=1, aug_max=1, aug_p=1.0 + def test_encode_text_base64_sentence(self) -> None: + augmented_words = txtaugs.encode_text( + "Hello, world!", 1, 1, 1.0, Method.SENTENCE, Encoding.BASE64 ) - self.assertEqual(augmented_words_word[0], "SGVsbG8=, world!") + self.assertEqual(augmented_words[0], "SGVsbG8sIHdvcmxkIQ==") - def test_encode_base64_char(self) -> None: - random.seed(42) - augmented_words_char = txtaugs.encode_base64( - "Hello, world!", granularity="char", aug_min=1, aug_max=2, aug_p=1.0 + def test_encode_text_base64_word(self) -> None: + augmented_words_word = txtaugs.encode_text( + "Hello, world!", 1, 1, 1.0, Method.WORD, Encoding.BASE64 ) - self.assertEqual(augmented_words_char[0], "SA==ellbw== LA== wbw==rlZA== IQ==") + self.assertEqual(augmented_words_word[0], "SGVsbG8=, world!") - def test_encode_base64_general(self) -> None: - random.seed(42) - augmented_words_low_p = txtaugs.encode_base64( - "Hello, world!", granularity="word", aug_min=1, aug_max=2, aug_p=0.1 - ) - random.seed(42) - augmented_words_high_p = txtaugs.encode_base64( - "Hello, world!", granularity="word", aug_min=1, aug_max=2, aug_p=0.9 + def test_encode_text_base64_char(self) -> None: + augmented_words_char = txtaugs.encode_text( + "Hello, world!", 1, 1, 1.0, Method.CHAR, Encoding.BASE64 ) - self.assertTrue(len(augmented_words_high_p[0]) > len(augmented_words_low_p[0])) + self.assertEqual(augmented_words_char[0], "SA==ello LA== dw==orld IQ==") def test_get_baseline(self) -> None: augmented_baseline = txtaugs.get_baseline(self.texts) diff --git a/augly/tests/text_tests/transforms_unit_test.py b/augly/tests/text_tests/transforms_unit_test.py index d7b6a63f..18bb0dd6 100644 --- a/augly/tests/text_tests/transforms_unit_test.py +++ b/augly/tests/text_tests/transforms_unit_test.py @@ -14,7 +14,9 @@ from typing import Any, Dict, List from augly import text as txtaugs +from augly.text.augmenters.utils import Encoding from augly.utils import TEXT_METADATA_PATH +from nlpaug.util import Method def are_equal_metadata( @@ -136,57 +138,68 @@ def test_Compose(self) -> None: are_equal_metadata(self.metadata, self.expected_metadata["compose"]), ) - def test_EncodeBase64(self) -> None: - augmented_text = txtaugs.EncodeBase64( - granularity="all", aug_min=1, aug_max=10, aug_p=0.3, n=1, p=1.0 + def test_EncodeText_Base64_Sentence(self) -> None: + augmented_text = txtaugs.EncodeTextTransform( + aug_min=1, + aug_max=1, + aug_p=1.0, + method=Method.SENTENCE, + encoder=Encoding.BASE64, + n=1, + p=1.0, )( ["Hello, world!"], metadata=self.metadata, ) self.assertTrue(augmented_text[0] == "SGVsbG8sIHdvcmxkIQ==") + self.expected_metadata["encode_text"][0]["encoder"] = Encoding.BASE64 self.assertTrue( - are_equal_metadata(self.metadata, self.expected_metadata["encode_base64"]) + are_equal_metadata(self.metadata, self.expected_metadata["encode_text"]) ) - def test_EncodeBase64_Word(self) -> None: + def test_EncodeText_Base64_Word(self) -> None: self.metadata = [] - random.seed(42) - augmented_text = txtaugs.EncodeBase64( - granularity="word", aug_min=1, aug_max=1, aug_p=1.0, n=1, p=1.0 + augmented_text = txtaugs.EncodeTextTransform( + aug_min=1, + aug_max=1, + aug_p=1.0, + method=Method.WORD, + encoder=Encoding.BASE64, + n=1, + p=1.0, )( ["Hello, world!"], metadata=self.metadata, ) self.assertEqual(augmented_text[0], "SGVsbG8=, world!") - expected_metadata = deepcopy(self.expected_metadata["encode_base64"]) - expected_metadata[0]["granularity"] = "word" - expected_metadata[0]["aug_p"] = 1.0 - expected_metadata[0]["aug_max"] = 1 - expected_metadata[0]["intensity"] = 100.0 + metadata_expected = deepcopy(self.expected_metadata["encode_text"]) + metadata_expected[0]["method"] = "word" + metadata_expected[0]["encoder"] = Encoding.BASE64 + self.assertTrue(are_equal_metadata(self.metadata, metadata_expected)) - self.assertTrue(are_equal_metadata(self.metadata, expected_metadata)) - - def test_EncodeBase64_Char(self) -> None: + def test_EncodeText_Base64_Char(self) -> None: self.metadata = [] - random.seed(42) - augmented_text = txtaugs.EncodeBase64( - granularity="char", aug_min=1, aug_max=2, aug_p=1.0, n=1, p=1.0 + augmented_text = txtaugs.EncodeTextTransform( + aug_min=1, + aug_max=1, + aug_p=1.0, + method=Method.CHAR, + encoder=Encoding.BASE64, + n=1, + p=1.0, )( ["Hello, world!"], metadata=self.metadata, ) - self.assertEqual(augmented_text[0], "SA==ebA==lo LA== wbw==rlZA== IQ==") - - expected_metadata = deepcopy(self.expected_metadata["encode_base64"]) - expected_metadata[0]["granularity"] = "char" - expected_metadata[0]["aug_p"] = 1.0 - expected_metadata[0]["aug_max"] = 2 - expected_metadata[0]["intensity"] = 100.0 + self.assertEqual(augmented_text[0], "SA==ello LA== wocg==ld IQ==") + expected_metadata = deepcopy(self.expected_metadata["encode_text"]) + expected_metadata[0]["method"] = "char" + expected_metadata[0]["encoder"] = Encoding.BASE64 self.assertTrue(are_equal_metadata(self.metadata, expected_metadata)) def test_GetBaseline(self) -> None: diff --git a/augly/text/__init__.py b/augly/text/__init__.py index 2f96d506..5bdd66a2 100644 --- a/augly/text/__init__.py +++ b/augly/text/__init__.py @@ -12,7 +12,7 @@ apply_lambda, change_case, contractions, - encode_base64, + encode_text, get_baseline, insert_punctuation_chars, insert_text, @@ -32,9 +32,10 @@ ) from augly.text.intensity import ( apply_lambda_intensity, + base64_intensity, change_case_intensity, contractions_intensity, - encode_base64_intensity, + encode_text_intensity, get_baseline_intensity, insert_punctuation_chars_intensity, insert_text_intensity, @@ -56,7 +57,7 @@ ApplyLambda, ChangeCase, Contractions, - EncodeBase64, + EncodeTextTransform, GetBaseline, InsertPunctuationChars, InsertText, @@ -81,7 +82,7 @@ "ApplyLambda", "ChangeCase", "Contractions", - "EncodeBase64", + "EncodeTextTransform", "GetBaseline", "InsertPunctuationChars", "InsertText", @@ -101,7 +102,7 @@ "apply_lambda", "change_case", "contractions", - "encode_base64", + "encode_text", "get_baseline", "insert_punctuation_chars", "insert_text", @@ -119,9 +120,10 @@ "split_words", "swap_gendered_words", "apply_lambda_intensity", + "base64_intensity", "change_case_intensity", "contractions_intensity", - "encode_base64_intensity", + "encode_text_intensity", "get_baseline_intensity", "insert_punctuation_chars_intensity", "insert_text_intensity", diff --git a/augly/text/augmenters/__init__.py b/augly/text/augmenters/__init__.py index 8e8c0bfd..6662a3af 100644 --- a/augly/text/augmenters/__init__.py +++ b/augly/text/augmenters/__init__.py @@ -7,11 +7,13 @@ # pyre-unsafe +from augly.text.augmenters.base64 import Base64 from augly.text.augmenters.baseline import BaselineAugmenter from augly.text.augmenters.bidirectional import BidirectionalAugmenter from augly.text.augmenters.case import CaseAugmenter from augly.text.augmenters.contraction import ContractionAugmenter -from augly.text.augmenters.encode_base64 import EncodeBase64 +from augly.text.augmenters.encode_text_context import EncodeText +from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation from augly.text.augmenters.fun_fonts import FunFontsAugmenter from augly.text.augmenters.insert_text import InsertTextAugmenter from augly.text.augmenters.insertion import InsertionAugmenter @@ -22,13 +24,14 @@ from augly.text.augmenters.word_replacement import WordReplacementAugmenter from augly.text.augmenters.words_augmenter import WordsAugmenter - __all__ = [ + "Base64", "BaselineAugmenter", "BidirectionalAugmenter", "CaseAugmenter", "ContractionAugmenter", - "EncodeBase64", + "EncodeText", + "EncodeTextAugmentation", "FunFontsAugmenter", "InsertTextAugmenter", "InsertionAugmenter", diff --git a/augly/text/augmenters/base64.py b/augly/text/augmenters/base64.py new file mode 100644 index 00000000..706bece9 --- /dev/null +++ b/augly/text/augmenters/base64.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import codecs + +from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation +from augly.text.augmenters.utils import Encoding +from nlpaug.util import Method + + +class Base64(EncodeTextAugmentation): + def __init__( + self, + aug_min: int, + aug_max: int, + aug_p: float, + method: Method, + ): + super().__init__( + name="Base64", + aug_min=aug_min, + aug_max=aug_max, + aug_p=aug_p, + encoder=Encoding.BASE64, + method=str(method), + ) + assert 0 <= aug_min <= aug_max + assert 0 <= aug_p <= 1 + + def encode(self, input_string: str) -> str: + encoded_bytes = codecs.encode(input_string.encode("utf-8"), "base64") + return encoded_bytes.decode("utf-8").strip() diff --git a/augly/text/augmenters/encode_base64.py b/augly/text/augmenters/encode_base64.py deleted file mode 100644 index b88d97c8..00000000 --- a/augly/text/augmenters/encode_base64.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import codecs - -from augly.text.augmenters.utils import detokenize, get_aug_idxes, tokenize -from nlpaug.augmenter.word import Augmenter # @manual -from nlpaug.util import Action, Method # @manual - - -class EncodeBase64(Augmenter): - def __init__(self, granularity="all", aug_min=1, aug_max=10, aug_p=0.3): - - assert granularity in ["char", "word", "all"] - assert 0 <= aug_min <= aug_max - assert 0 <= aug_p <= 1 - - self.granularity = granularity - super().__init__( - name="EncodeBase64", - action=Action.SUBSTITUTE, - method=Method.WORD, - aug_min=aug_min, - aug_max=aug_max, - aug_p=aug_p, - ) - - def clean(self, data): - - if isinstance(data, list): - return [self.clean(d) for d in data] - elif isinstance(data, str): - return data - elif data is None: - return "" - else: - return str(data) - - def encode_text(self, input_string: str) -> str: - if not isinstance(input_string, str): - raise TypeError("Input must be a string") - - encoded_bytes = codecs.encode(input_string.encode("utf-8"), "base64") - return encoded_bytes.decode("utf-8").strip() - - def substitute(self, data) -> str: - if self.granularity == "all": - return self.encode_text(data) - - tokens = tokenize(data) - if not tokens: - return "" - - if self.granularity == "word": - aug_word_cnt = self._generate_aug_cnt( - len(tokens), self.aug_min, self.aug_max, self.aug_p - ) - aug_word_idxes = set( - get_aug_idxes( - self, tokens, list(range(len(tokens))), aug_word_cnt, Method.WORD - ) - ) - for i, token in enumerate(tokens): - if i in aug_word_idxes: - tokens[i] = self.encode_text(token) - - elif self.granularity == "char": - for t_i, token in enumerate(tokens): - chars = list(token) - aug_char_cnt = self._generate_aug_cnt( - len(chars), self.aug_min, self.aug_max, self.aug_p - ) - aug_char_idxes = set( - get_aug_idxes( - self, chars, list(range(len(chars))), aug_char_cnt, Method.CHAR - ) - ) - for c_i, char in enumerate(chars): - if c_i in aug_char_idxes: - chars[c_i] = self.encode_text(char) - tokens[t_i] = "".join(chars) - return detokenize(tokens) diff --git a/augly/text/augmenters/encode_text_context.py b/augly/text/augmenters/encode_text_context.py new file mode 100644 index 00000000..35c476fc --- /dev/null +++ b/augly/text/augmenters/encode_text_context.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +from typing import List, Union + +from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation + + +class EncodeText: + def __init__(self, encoder: EncodeTextAugmentation): + self.encoder = encoder + + def augmenter(self, input_string: Union[List[str], str]) -> Union[List[str], str]: + return self.encoder.augment(input_string, 1) diff --git a/augly/text/augmenters/encode_text_strategy.py b/augly/text/augmenters/encode_text_strategy.py new file mode 100644 index 00000000..4bdf2ca0 --- /dev/null +++ b/augly/text/augmenters/encode_text_strategy.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from abc import abstractmethod +from typing import List, Union + +from augly.text.augmenters.utils import detokenize, Encoding, get_aug_idxes, tokenize +from nlpaug.augmenter.word import Augmenter +from nlpaug.util import Action, Method + + +class EncodeTextAugmentation(Augmenter): + def __init__( + self, + name: str, + aug_min: int, + aug_max: int, + aug_p: float, + encoder: Encoding = Encoding.BASE64, + method: str = Method.SENTENCE, + ): + super().__init__( + name=name, + aug_min=aug_min, + aug_max=aug_max, + aug_p=aug_p, + action=Action.SUBSTITUTE, + method=method, + ) + + self.encoder = encoder + self.method = method + + @classmethod + def clean(cls, data: Union[str, List[str], None]) -> Union[str, List[str]]: + if isinstance(data, str): + return data + elif isinstance(data, list): + cleaned_data = [cls.clean(d) for d in data] + if all(isinstance(d, str) for d in cleaned_data): + # pyre-ignore + return cleaned_data + return "".join(str(d) for d in cleaned_data) + elif data is None: + return "" + else: + return str(data) + + @classmethod + def is_duplicate(cls, dataset: List[str], data: str) -> bool: + return data in dataset + + @abstractmethod + def encode(self, input_string: str) -> str: + raise NotImplementedError + + def substitute(self, data: str) -> str: + if self.method == Method.SENTENCE: + return self.encode(data) + + tokens = tokenize(data) + if not tokens: + return "" + + if self.method == Method.WORD: + augment_count = self._generate_aug_cnt( + len(tokens), self.aug_min, self.aug_max, self.aug_p + ) + to_augment = set( + get_aug_idxes( + self, tokens, list(range(len(tokens))), augment_count, Method.WORD + ) + ) + for i, token in enumerate(tokens): + if i in to_augment: + tokens[i] = self.encode(token) + + elif self.method == Method.CHAR: + for token_idx, token in enumerate(tokens): + chars = list(token) + augment_count = self._generate_aug_cnt( + len(chars), self.aug_min, self.aug_max, self.aug_p + ) + to_augment = set( + get_aug_idxes( + self, chars, list(range(len(chars))), augment_count, Method.CHAR + ) + ) + for char_idx, char in enumerate(chars): + if char_idx in to_augment: + chars[char_idx] = self.encode(char) + tokens[token_idx] = "".join(chars) + return detokenize(tokens) diff --git a/augly/text/augmenters/utils.py b/augly/text/augmenters/utils.py index 31297afd..080736f4 100644 --- a/augly/text/augmenters/utils.py +++ b/augly/text/augmenters/utils.py @@ -8,6 +8,7 @@ # pyre-unsafe import re +from enum import Enum from typing import List, Optional, Tuple import regex @@ -269,3 +270,7 @@ def get_aug_idxes( aug_idxes = augmenter.sample(priority_idxes, aug_cnt) return aug_idxes + + +class Encoding(Enum): + BASE64 = "base64" diff --git a/augly/text/functional.py b/augly/text/functional.py index 9ac688bd..da3c6492 100644 --- a/augly/text/functional.py +++ b/augly/text/functional.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from augly.text import augmenters as a, utils as txtutils +from augly.text.augmenters.utils import Encoding from augly.utils import ( CONTRACTIONS_MAPPING, FUN_FONTS_PATH, @@ -18,6 +19,7 @@ MISSPELLING_DICTIONARY_PATH, UNICODE_MAPPING_PATH, ) +from nlpaug.util import Method def apply_lambda( @@ -167,22 +169,21 @@ def contractions( return aug_texts -def encode_base64( +def encode_text( texts: Union[str, List[str]], - granularity: str = "all", - aug_min: int = 1, - aug_max: int = 10, - aug_p: float = 0.3, + aug_min: int, + aug_max: int, + aug_p: float, + method: Method, + encoder: Encoding, n: int = 1, + p: float = 1.0, metadata: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, List[str]]: - """ - Encodes text into base64, with options for different granularity levels - - @param texts: a string or a list of text documents to be augmented + """Alters text based on encoding function and based on different granularity levels + such as the entire text, specific words or characters. - @param granularity: Level at which to apply base64 encoding. - Options: 'char', 'word', or 'all' (entire text). + @param texts: A string or a list of text documents to be augmented @param aug_min: Minimum number of units (words/chars) to augment. @@ -190,8 +191,15 @@ def encode_base64( @param aug_p: Probability of augmenting each unit. + param method: Level at which to apply base64 encoding. + Options: 'char', 'word', or 'sentence' (entire text). + + param encoder: Specific function that determines type of encoding performed + @param n: Number of augmentations to be performed. + @param p: + @param metadata: if set to be a list, metadata about the function execution including its name, the source & dest length, etc. will be appended to the inputted list. If set to None, no metadata will be appended or returned @@ -200,15 +208,17 @@ def encode_base64( """ func_kwargs = txtutils.get_func_kwargs(metadata, locals()) - base64_aug = a.EncodeBase64(granularity, aug_min, aug_max, aug_p) - if not isinstance(texts, list): texts = [texts] - aug_texts = base64_aug.augment(texts) + if encoder == Encoding.BASE64: + encoder_strategy = a.Base64(aug_min, aug_max, aug_p, method) + # pyre-ignore + encoder_context = a.EncodeText(encoder_strategy) + aug_texts = encoder_context.augmenter(texts) txtutils.get_metadata( metadata=metadata, - function_name="encode_base64", + function_name="encode_text", aug_texts=aug_texts, **func_kwargs, ) diff --git a/augly/text/intensity.py b/augly/text/intensity.py index e2121bf9..fcb45117 100644 --- a/augly/text/intensity.py +++ b/augly/text/intensity.py @@ -9,12 +9,22 @@ from typing import Any, Dict, List, Optional, Union +from augly.text.augmenters.utils import Encoding + +from nlpaug import Method + def apply_lambda_intensity(aug_function: str, **kwargs) -> float: intensity_func = globals().get(f"{aug_function}_intensity") return intensity_func(**kwargs) if intensity_func else 100.0 +def base64_intensity(method: Method, aug_p: float, aug_max: int, **kwargs) -> float: + return ( + 100.0 if method == Method.SENTENCE else replace_intensity_helper(aug_p, aug_max) + ) + + def change_case_intensity(granularity: str, cadence: float, **kwargs) -> float: return char_insertion_intensity_helper(granularity, cadence) @@ -23,10 +33,15 @@ def contractions_intensity(aug_p: float, **kwargs) -> float: return aug_p * 100.0 -def encode_base64_intensity( - granularity: str = "all", aug_p: float = 0.3, aug_max: int = 10, **kwargs +def encode_text_intensity( + encoder: Encoding, method: Method, aug_p: float, aug_max: int, **kwargs ) -> float: - return 100.0 if granularity == "all" else replace_intensity_helper(aug_p, aug_max) + if encoder == Encoding.BASE64: + return base64_intensity(method, aug_p, aug_max) + else: + raise NotImplementedError( + f"Intensity function for encoder {encoder} is not implemented" + ) def get_baseline_intensity(**kwargs) -> float: diff --git a/augly/text/transforms.py b/augly/text/transforms.py index dbb1fcfc..ecd2efea 100644 --- a/augly/text/transforms.py +++ b/augly/text/transforms.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from augly.text import functional as F +from augly.text.augmenters.utils import Encoding from augly.utils import ( CONTRACTIONS_MAPPING, FUN_FONTS_PATH, @@ -19,6 +20,7 @@ MISSPELLING_DICTIONARY_PATH, UNICODE_MAPPING_PATH, ) +from nlpaug.util import Method """ @@ -270,13 +272,14 @@ def apply_transform( return F.contractions(texts, metadata=metadata, **aug_kwargs) -class EncodeBase64(BaseTransform): +class EncodeTextTransform(BaseTransform): def __init__( self, - granularity: str = "all", - aug_min: int = 1, - aug_max: int = 10, - aug_p: float = 0.3, + aug_min: int, + aug_max: int, + aug_p: float, + method: Method, + encoder: Encoding, n: int = 1, p: float = 1.0, ): @@ -292,13 +295,15 @@ def __init__( @param n: number of augmentations to be performed for each text - @param p: the probability of the transform being applied; default value is 1.0 + @param p: the probability of the overall being applied to each unit (word/char). + default value is 1.0 """ super().__init__(p) - self.granularity = granularity self.aug_min = aug_min self.aug_max = aug_max self.aug_p = aug_p + self.method = method + self.encoder = encoder self.n = n def apply_transform( @@ -323,7 +328,7 @@ def apply_transform( if not texts: return texts - return F.encode_base64(texts, metadata=metadata, **aug_kwargs) + return F.encode_text(texts, metadata=metadata, **aug_kwargs) class GetBaseline(BaseTransform): From eb38d45549187d29ebf87499fc28016af2c6414d Mon Sep 17 00:00:00 2001 From: Sharan Sairamesh Date: Tue, 10 Jun 2025 09:18:25 -0700 Subject: [PATCH 2/2] Added leetspeak function encoding to encode_text (#264) Summary: Pull Request resolved: https://github.com/facebookresearch/AugLy/pull/264 Added leetspeak text encoding function to the enode_text interface and tested comprehensively. Word method level test covers character level test as leetspeak works on character level Differential Revision: D76058699 --- .../tests/text_tests/functional_unit_test.py | 48 +++++++++++------- .../tests/text_tests/transforms_unit_test.py | 48 ++++++++++++++++-- augly/text/augmenters/__init__.py | 2 + augly/text/augmenters/leetspeak.py | 50 +++++++++++++++++++ augly/text/augmenters/utils.py | 1 + augly/text/functional.py | 3 +- augly/text/intensity.py | 8 +++ 7 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 augly/text/augmenters/leetspeak.py diff --git a/augly/tests/text_tests/functional_unit_test.py b/augly/tests/text_tests/functional_unit_test.py index f45ffe50..a7ce94c1 100644 --- a/augly/tests/text_tests/functional_unit_test.py +++ b/augly/tests/text_tests/functional_unit_test.py @@ -38,39 +38,39 @@ def test_apply_lambda(self) -> None: augmented_apply_lambda = txtaugs.apply_lambda(self.texts) self.assertTrue(augmented_apply_lambda[0] == self.texts[0]) - def test_change_case(self) -> None: - augmented_words = txtaugs.change_case(self.texts[0], cadence=3.0, case="upper") - self.assertTrue( - augmented_words[0] - == "THE quick brown 'FOX' couldn't jump OVER the green, GRASSY hill.", - ) - - def test_contractions(self) -> None: - augmented_words = txtaugs.contractions( - "I would call him but I do not know where he has gone", aug_p=0.7 - ) - self.assertTrue( - augmented_words[0] == "I would call him but I don't know where he's gone" - ) - - def test_encode_text_base64_sentence(self) -> None: + def test_base64_sentence(self) -> None: augmented_words = txtaugs.encode_text( "Hello, world!", 1, 1, 1.0, Method.SENTENCE, Encoding.BASE64 ) self.assertEqual(augmented_words[0], "SGVsbG8sIHdvcmxkIQ==") - def test_encode_text_base64_word(self) -> None: + def test_base64_word(self) -> None: augmented_words_word = txtaugs.encode_text( "Hello, world!", 1, 1, 1.0, Method.WORD, Encoding.BASE64 ) self.assertEqual(augmented_words_word[0], "SGVsbG8=, world!") - def test_encode_text_base64_char(self) -> None: + def test_base64_char(self) -> None: augmented_words_char = txtaugs.encode_text( "Hello, world!", 1, 1, 1.0, Method.CHAR, Encoding.BASE64 ) self.assertEqual(augmented_words_char[0], "SA==ello LA== dw==orld IQ==") + def test_change_case(self) -> None: + augmented_words = txtaugs.change_case(self.texts[0], cadence=3.0, case="upper") + self.assertTrue( + augmented_words[0] + == "THE quick brown 'FOX' couldn't jump OVER the green, GRASSY hill.", + ) + + def test_contractions(self) -> None: + augmented_words = txtaugs.contractions( + "I would call him but I do not know where he has gone", aug_p=0.7 + ) + self.assertTrue( + augmented_words[0] == "I would call him but I don't know where he's gone" + ) + def test_get_baseline(self) -> None: augmented_baseline = txtaugs.get_baseline(self.texts) self.assertTrue( @@ -272,6 +272,18 @@ def test_insert_zero_width_chars(self) -> None: ], ) + def test_leetspeak_sentence(self) -> None: + augmented_words = txtaugs.encode_text( + "Hello, world!", 1, 1, 1.0, Method.SENTENCE, Encoding.LEETSPEAK + ) + self.assertEqual(augmented_words[0], "h3110, w0r1d!") + + def test_leetspeak_word(self) -> None: + augmented_words = txtaugs.encode_text( + "Hello, world!", 1, 1, 1.0, Method.WORD, Encoding.LEETSPEAK + ) + self.assertEqual(augmented_words[0], "h3110, world!") + def test_merge_words(self) -> None: augmented_split_words = txtaugs.merge_words(self.texts, aug_word_p=0.3, n=1) self.assertTrue( diff --git a/augly/tests/text_tests/transforms_unit_test.py b/augly/tests/text_tests/transforms_unit_test.py index 18bb0dd6..c67f2c84 100644 --- a/augly/tests/text_tests/transforms_unit_test.py +++ b/augly/tests/text_tests/transforms_unit_test.py @@ -138,7 +138,7 @@ def test_Compose(self) -> None: are_equal_metadata(self.metadata, self.expected_metadata["compose"]), ) - def test_EncodeText_Base64_Sentence(self) -> None: + def test_Base64_Sentence(self) -> None: augmented_text = txtaugs.EncodeTextTransform( aug_min=1, aug_max=1, @@ -158,7 +158,7 @@ def test_EncodeText_Base64_Sentence(self) -> None: are_equal_metadata(self.metadata, self.expected_metadata["encode_text"]) ) - def test_EncodeText_Base64_Word(self) -> None: + def test_Base64_Word(self) -> None: self.metadata = [] augmented_text = txtaugs.EncodeTextTransform( @@ -180,7 +180,7 @@ def test_EncodeText_Base64_Word(self) -> None: metadata_expected[0]["encoder"] = Encoding.BASE64 self.assertTrue(are_equal_metadata(self.metadata, metadata_expected)) - def test_EncodeText_Base64_Char(self) -> None: + def test_Base64_Char(self) -> None: self.metadata = [] augmented_text = txtaugs.EncodeTextTransform( @@ -291,6 +291,48 @@ def test_InsertZeroWidthChars(self) -> None: ), ) + def test_LeetSpeak_Sentence(self) -> None: + augmented_text = txtaugs.EncodeTextTransform( + aug_min=1, + aug_max=1, + aug_p=1.0, + method=Method.SENTENCE, + encoder=Encoding.LEETSPEAK, + n=1, + p=1.0, + )( + ["Hello, world!"], + metadata=self.metadata, + ) + + self.assertTrue(augmented_text[0] == "h3110, w0r1d!") + self.expected_metadata["encode_text"][0]["encoder"] = Encoding.LEETSPEAK + self.assertTrue( + are_equal_metadata(self.metadata, self.expected_metadata["encode_text"]) + ) + + def test_Leetspeak_Word(self) -> None: + self.metadata = [] + + augmented_text = txtaugs.EncodeTextTransform( + aug_min=1, + aug_max=1, + aug_p=1.0, + method=Method.WORD, + encoder=Encoding.LEETSPEAK, + n=1, + p=1.0, + )( + ["Hello, world!"], + metadata=self.metadata, + ) + self.assertEqual(augmented_text[0], "h3110, world!") + + metadata_expected = deepcopy(self.expected_metadata["encode_text"]) + metadata_expected[0]["method"] = "word" + metadata_expected[0]["encoder"] = Encoding.LEETSPEAK + self.assertTrue(are_equal_metadata(self.metadata, metadata_expected)) + def test_MergeWords(self) -> None: aug_merge_words = txtaugs.MergeWords(aug_word_p=0.3)( self.texts, metadata=self.metadata diff --git a/augly/text/augmenters/__init__.py b/augly/text/augmenters/__init__.py index 6662a3af..ef7c2884 100644 --- a/augly/text/augmenters/__init__.py +++ b/augly/text/augmenters/__init__.py @@ -17,6 +17,7 @@ from augly.text.augmenters.fun_fonts import FunFontsAugmenter from augly.text.augmenters.insert_text import InsertTextAugmenter from augly.text.augmenters.insertion import InsertionAugmenter +from augly.text.augmenters.leetspeak import LeetSpeak from augly.text.augmenters.letter_replacement import LetterReplacementAugmenter from augly.text.augmenters.text_replacement import TextReplacementAugmenter from augly.text.augmenters.typo import TypoAugmenter @@ -35,6 +36,7 @@ "FunFontsAugmenter", "InsertTextAugmenter", "InsertionAugmenter", + "LeetSpeak", "LetterReplacementAugmenter", "WordsAugmenter", "TextReplacementAugmenter", diff --git a/augly/text/augmenters/leetspeak.py b/augly/text/augmenters/leetspeak.py new file mode 100644 index 00000000..70131a7c --- /dev/null +++ b/augly/text/augmenters/leetspeak.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation +from augly.text.augmenters.utils import Encoding +from nlpaug.util import Method + + +class LeetSpeak(EncodeTextAugmentation): + def __init__( + self, + aug_min: int, + aug_max: int, + aug_p: float, + method: Method, + ): + super().__init__( + name="LeetSpeak", + aug_min=aug_min, + aug_max=aug_max, + aug_p=aug_p, + encoder=Encoding.LEETSPEAK, + method=str(method), + ) + assert 0 <= aug_min <= aug_max + assert 0 <= aug_p <= 1 + + def encode(self, input_string: str) -> str: + leet_map = { + "a": ["4", "@"], + "b": ["8"], + "e": ["3"], + "g": ["6"], + "i": ["1", "!"], + "l": ["1"], + "o": ["0"], + "s": ["5", "$"], + "t": ["7", "+"], + "z": ["2"], + } + input_string = input_string.lower() + return "".join( + random.choice(leet_map.get(char, [char])) for char in input_string + ) diff --git a/augly/text/augmenters/utils.py b/augly/text/augmenters/utils.py index 080736f4..59066ac1 100644 --- a/augly/text/augmenters/utils.py +++ b/augly/text/augmenters/utils.py @@ -274,3 +274,4 @@ def get_aug_idxes( class Encoding(Enum): BASE64 = "base64" + LEETSPEAK = "leetspeak" diff --git a/augly/text/functional.py b/augly/text/functional.py index da3c6492..aadcb8df 100644 --- a/augly/text/functional.py +++ b/augly/text/functional.py @@ -212,7 +212,8 @@ def encode_text( texts = [texts] if encoder == Encoding.BASE64: encoder_strategy = a.Base64(aug_min, aug_max, aug_p, method) - # pyre-ignore + else: + encoder_strategy = a.LeetSpeak(aug_min, aug_max, aug_p, method) encoder_context = a.EncodeText(encoder_strategy) aug_texts = encoder_context.augmenter(texts) diff --git a/augly/text/intensity.py b/augly/text/intensity.py index fcb45117..38478939 100644 --- a/augly/text/intensity.py +++ b/augly/text/intensity.py @@ -38,6 +38,8 @@ def encode_text_intensity( ) -> float: if encoder == Encoding.BASE64: return base64_intensity(method, aug_p, aug_max) + elif encoder == Encoding.LEETSPEAK: + return leetspeak_intensity(method, aug_p, aug_max) else: raise NotImplementedError( f"Intensity function for encoder {encoder} is not implemented" @@ -76,6 +78,12 @@ def insert_zero_width_chars_intensity( return char_insertion_intensity_helper(granularity, cadence) +def leetspeak_intensity(method: Method, aug_p: float, aug_max: int, **kwargs) -> float: + return ( + 100.0 if method == Method.SENTENCE else replace_intensity_helper(aug_p, aug_max) + ) + + def merge_words_intensity(aug_word_p: float, aug_word_max: int, **kwargs) -> float: return replace_intensity_helper(aug_word_p, aug_word_max)