From ead2c68e6a41077d7c24ec17a5691976c303d3e6 Mon Sep 17 00:00:00 2001 From: David Bieber Date: Wed, 27 Oct 2021 13:04:28 -0400 Subject: [PATCH] Data generation --- core/data/generation/constants.py | 1 + core/data/generation/generate.py | 189 ++++++++++++++++ core/data/generation/program_generator.py | 248 +++++++++++++++++++++ core/data/generation/python_interpreter.py | 185 +++++++++++++++ 4 files changed, 623 insertions(+) create mode 100644 core/data/generation/constants.py create mode 100644 core/data/generation/generate.py create mode 100644 core/data/generation/program_generator.py create mode 100644 core/data/generation/python_interpreter.py diff --git a/core/data/generation/constants.py b/core/data/generation/constants.py new file mode 100644 index 00000000..76b875cc --- /dev/null +++ b/core/data/generation/constants.py @@ -0,0 +1 @@ +INDENT_STRING = ' ' \ No newline at end of file diff --git a/core/data/generation/generate.py b/core/data/generation/generate.py new file mode 100644 index 00000000..a8640806 --- /dev/null +++ b/core/data/generation/generate.py @@ -0,0 +1,189 @@ +"""Generates Control Flow Programs. + +This file was introduced as part of the Exception IPA-GNN effort, for generating +a new dataset suitable for testing the vanilla IPA-GNN and Exception IPA-GNN. +""" + +import collections +import dataclasses +import os +from typing import Optional, Sequence, Text, Tuple + +from absl import app +from python_graphs import control_flow +import tensorflow as tf +import tqdm + +from core.data import codenet_paths +from core.data import process +from core.data.generation import program_generator +from core.data.generation import python_interpreter + +TFRECORD_PATH = codenet_paths.RAW_CFP_RAISE_DATA_PATH +TFRECORD_PATH = 'tmp.tfrecord' + + +DEFAULT_OPS = ("+=", "-=", "*=") + + +@dataclasses.dataclass +class ArithmeticIfRepeatsConfig: + """Config for ArithmeticIfRepeats ProgramGenerator. + + Attributes: + base: The base to represent the integers in. + length: The number of statements in the generated programs. + num_digits: The number of digits in the values used by the programs. + max_repeat_statements: The maximum number of repeat statements allowed in + a program. + max_repetitions: The maximum number of repetitions a repeat statement may + specify. + repeat_probability: The probability that a given statement is a repeat + statement, provided a repeat statement is possible at that location. + max_if_statements: The maximum number of if statements allowed in a program. + if_probability: The probability that a given statement is an if statement, + provided an if statement is possible at that location. + ifelse_probability: The probability that a given statement is an if-else + statement, provided an if statement is possible at that location. + max_nesting: The maximum depth of nesting permitted, or None if no limit. + max_block_size: The maximum number of statements permitted in a block. + ops: The ops allowed in the generated programs. + encoder_name: The encoder name to use to encode the generated programs. + mod: The value (if any) to mod the intermediate values of the program by + after each step of execution. + output_mod: The value (if any) to mod the final values of the program by. + """ + base: int + length: int + num_digits: int = 1 + max_repeat_statements: Optional[int] = 2 + max_repetitions: int = 9 + repeat_probability: float = 0.1 + max_if_statements: Optional[int] = 2 + if_probability: float = 0.2 + ifelse_probability: float = 0.2 + max_nesting: Optional[int] = None + max_block_size: Optional[int] = 9 + ops: Tuple[Text, ...] = DEFAULT_OPS + encoder_name: Text = "simple" + mod: Optional[int] = 10 + output_mod: Optional[int] = None + + + +def int64_feature(value): + """Constructs a tf.train.Feature for the given int64 value list.""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + + +def bytes_feature(values): + """Constructs a tf.train.Feature for the given str value list.""" + values = [v.encode('utf-8') for v in values] + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + + +def to_tf_example(source, target, steps): + """Constructs a tf.train.Example for the source code.""" + return tf.train.Example(features=tf.train.Features(feature={ + 'source': bytes_feature([source]), + 'target': bytes_feature([target]), + 'steps': int64_feature([steps]), + })) + + +def decode_fn(record_bytes): + features = { + 'source': tf.io.FixedLenFeature([1], dtype=tf.string), + 'target': tf.io.FixedLenFeature([1], dtype=tf.string), + 'steps': tf.io.FixedLenFeature([1], dtype=tf.int64), + } + return tf.io.parse_single_example(record_bytes, features) + + +def load_dataset(tfrecord_paths): + return tf.data.TFRecordDataset( + tfrecord_paths, + compression_type=None, buffer_size=None, num_parallel_reads=32 + ).map(decode_fn) + + +def read(): + for example in load_dataset([TFRECORD_PATH]): + source = example['source'].numpy()[0].decode('utf-8') + target = example['target'].numpy()[0].decode('utf-8') + print(source) + print('---') + # if 'raise' in source: + # print(target) + + +def generate_example_from_python_source(executor, base, python_source, mod, output_mod): + """Generates an example dict from the given statements.""" + cfg = control_flow.get_control_flow_graph(python_source) + python_source_lines = python_source.strip().split("\n") + + values = {"v0": 1} # Assume v0 starts at 1. + try: + values = python_interpreter.evaluate_cfg( + executor, cfg, mod=mod, + initial_values=values, + timeout=200) + error_type = "NoError" + except Exception as e: # pylint: disable=broad-except + error_type = type(e).__name__ + target_output = values["v0"] + + if output_mod is not None: + try: + target_output %= output_mod + except TypeError: + target_output = 1 + + return { + 'human_readable_target_output': str(target_output), + 'error_type': error_type + } + + +def main(argv: Sequence[str]) -> None: + del argv # Unused. + + # if os.path.exists(TFRECORD_PATH): + # return read() + + executor = python_interpreter.ExecExecutor() + counts = collections.Counter() + program_generator_config = ArithmeticIfRepeatsConfig( + base=10, + max_if_statements=5, + length=30, + ) + with tf.io.TFRecordWriter(TFRECORD_PATH) as file_writer: + for _ in tqdm.tqdm(range(50)): + source = program_generator.generate_python_source( + 30, program_generator_config) + print(source) + print() + example = ( + generate_example_from_python_source( + executor, program_generator_config.base, source, + mod=1000, + output_mod=1000, + ) + ) + print(example) + target = example['human_readable_target_output'] + error_type = example['error_type'] + lines = source.split('\n') + steps = process.get_step_limit(lines) + counts[target] += 1 + + if error_type != 'NoError': + target = error_type + record_bytes = to_tf_example(source, target, steps).SerializeToString() + file_writer.write(record_bytes) + print(dict(counts)) + + +if __name__ == '__main__': + app.run(main) diff --git a/core/data/generation/program_generator.py b/core/data/generation/program_generator.py new file mode 100644 index 00000000..d0c129f7 --- /dev/null +++ b/core/data/generation/program_generator.py @@ -0,0 +1,248 @@ +"""Generating and running arithmetic programs with if and repeat statements. + +We use a list of statements to represent a program. Each statement is a list of +an operator and two operands. The standard ops in a program are +, -, *, +if-statements, and a special "repeat" op ("r") that acts as a repeat block in +the program. + +The +, -, and * ops update a variable by modifying it. The first operand +indicates which variable is being updated. The second operand indicates +by how much to modify the variable. + +In the repeat op, the first operand indicates the number of repetitions and the +second op indicates how many statements to repeat. +""" + +import random +from absl import logging # pylint: disable=unused-import + +from core.data.generation import constants + +REPEAT_OP = "r" +IF_OP = "i" +ELSE_OP = "e" +PLACEHOLDER_OP = "_" + + +def generate_python_source(length, config): + """Generates Python code according to the config.""" + statements, unused_hole_statement_index = _generate_statements(length, config) + return _to_python_source(statements, config) + + +def generate_python_source_and_partial_python_source(length, config): + """Generates Python code according to the config.""" + statements, hole_statement_index = _generate_statements(length, config) + partial_statements = statements.copy() + partial_statements[hole_statement_index] = _placeholder_statement() + return (_to_python_source(statements, config), + _to_python_source(partial_statements, config)) + + +def _placeholder_statement(): + return (PLACEHOLDER_OP, 0, 0) + + +def _generate_statements(length, config): + """Generates a list of statements representing a control flow program. + + Args: + length: The number of statements to generate. + config: The ArithmeticRepeatsConfig specifying the properties of the program + to generate. + Returns: + A list of statements, each statement being a 3-tuple (op, operand, operand), + as well as the index of a statement to replace with a hole. + """ + max_value = config.base ** config.num_digits - 1 + + statements = [] + nesting_lines_remaining = [] + nesting_instructions = [] + num_repeats = 0 + num_ifs = 0 + hole_candidates = [] + instruction = None + for statement_index in range(length): + if instruction is None: + current_nesting = len(nesting_lines_remaining) + nesting_permitted = (config.max_nesting is None + or current_nesting < config.max_nesting) + too_many_repeats = (config.max_repeat_statements is not None + and num_repeats > config.max_repeat_statements) + repeat_permitted = nesting_permitted and not ( + too_many_repeats + or statement_index == length - 1 # Last line of program. + or 1 in nesting_lines_remaining # Last line of another block. + ) + too_many_ifs = (config.max_if_statements is not None + and num_ifs > config.max_if_statements) + if_permitted = nesting_permitted and not ( + too_many_ifs + or statement_index == length - 1 # Last line of program. + or 1 in nesting_lines_remaining # Last line of another block. + ) + ifelse_permitted = nesting_permitted and not ( + too_many_ifs + or statement_index >= length - 3 # Need 4 lines for if-else. + or 1 in nesting_lines_remaining # Last line of another block. + or 2 in nesting_lines_remaining # 2nd-to-last line of another block. + or 3 in nesting_lines_remaining # 3rd-to-last line of another block. + ) + op_random = random.random() + is_repeat = repeat_permitted and op_random < config.repeat_probability + is_if = if_permitted and ( + config.repeat_probability + < op_random + < config.repeat_probability + config.if_probability) + is_ifelse = ifelse_permitted and ( + config.repeat_probability + config.if_probability + < op_random + < (config.repeat_probability + + config.if_probability + + config.ifelse_probability)) + + # statements_remaining_* includes current statement. + statements_remaining_in_program = length - statement_index + statements_remaining_in_block = min( + [statements_remaining_in_program] + nesting_lines_remaining) + if config.max_block_size: + max_block_size = min(config.max_block_size, + statements_remaining_in_block) + else: + max_block_size = statements_remaining_in_block + + if is_repeat: + num_repeats += 1 + repetitions = random.randint(2, config.max_repetitions) + # num_statements includes current statement. + num_statements = random.randint(2, max_block_size) + nesting_lines_remaining.append(num_statements) + nesting_instructions.append(None) + # -1 is to not include current statement. + statement = (REPEAT_OP, repetitions, num_statements - 1) + elif is_if: + num_ifs += 1 + # num_statements includes current statement. + num_statements = random.randint(2, max_block_size) + nesting_lines_remaining.append(num_statements) + nesting_instructions.append(None) + threshold = random.randint(0, max_value) # "if v0 > {threshold}:" + # -1 is to not include current statement. + statement = (IF_OP, threshold, num_statements - 1) + elif is_ifelse: + num_ifs += 1 + # num_statements includes current statement. + num_statements = random.randint(4, max_block_size) + # Choose a statement to be the else statement. + else_statement_index = random.randint(2, num_statements - 2) + nesting_lines_remaining.append(else_statement_index) + nesting_instructions.append( + ("else", num_statements - else_statement_index)) + threshold = random.randint(0, max_value) # "if v0 > {threshold}:" + # -1 is to not include current statement. + statement = (IF_OP, threshold, else_statement_index - 1) + else: + op = random.choice(config.ops) + variable_index = 0 # "v0" + operand = random.randint(0, max_value) + statement = (op, variable_index, operand) + hole_candidates.append(statement_index) + else: # instruction is not None + if instruction[0] == "else": + # Insert an else block. + num_statements = instruction[1] + nesting_lines_remaining.append(num_statements) + nesting_instructions.append(None) + # -1 is to not include current statement. + statement = (ELSE_OP, 0, num_statements - 1) + else: + raise ValueError("Unexpected instruction", instruction) + + instruction = None + statements.append(statement) + + # Decrement nesting. + for nesting_index in range(len(nesting_lines_remaining)): + nesting_lines_remaining[nesting_index] -= 1 + while nesting_lines_remaining and nesting_lines_remaining[-1] == 0: + nesting_lines_remaining.pop() + instruction = nesting_instructions.pop() + assert 0 not in nesting_lines_remaining + + hole_statement_index = random.choice(hole_candidates) + + return statements, hole_statement_index + + +def _select_counter_variable(used_variables, config): + del config # Unused. + num_variables = 10 # TODO(dbieber): num_variables is hardcoded. + max_variable = num_variables - 1 + allowed_variables = ( + set(range(1, max_variable + 1)) - set(used_variables)) + return random.choice(list(allowed_variables)) + + +def _to_python_source(statements, config): + """Convert statements into Python source code. + + Repeat statements are rendered as while loops with a counter variable that + tracks the number of iterations remaining. + + Args: + statements: A list of statements. Each statement is a triple containing + (op, operand, operand). + config: An ArithmeticRepeatsConfig. + Returns: + Python source code representing the program. + """ + lines = [] + nesting_lines_remaining = [] + used_variables = [] + for statement in statements: + op, operand1, operand2 = statement + indent = constants.INDENT_STRING * len(nesting_lines_remaining) + if op is REPEAT_OP: + # num_statements doesn't include current statement. + repetitions, num_statements = operand1, operand2 + variable_index = _select_counter_variable(used_variables, config) + line1 = f"{indent}v{variable_index} = {repetitions}" + line2 = f"{indent}while v{variable_index} > 0:" + # +1 is for current statement. + nesting_lines_remaining.append(num_statements + 1) + used_variables.append(variable_index) + line3_indent = constants.INDENT_STRING * len(nesting_lines_remaining) + line3 = f"{line3_indent}v{variable_index} -= 1" + lines.extend([line1, line2, line3]) + elif op is IF_OP: + # num_statements doesn't include current statement. + threshold, num_statements = operand1, operand2 + lines.append(f"{indent}if v0 > {threshold}:") + # +1 is for current statement. + nesting_lines_remaining.append(num_statements + 1) + used_variables.append(None) + elif op is ELSE_OP: + lines.append(f"{indent}else:") + # +1 is for current statement. + num_statements = operand2 + nesting_lines_remaining.append(num_statements + 1) + used_variables.append(None) + elif op is PLACEHOLDER_OP: + lines.append(f"{indent}_ = 0") + elif op == "*=" and operand2 == 0: + line = f"{indent}raise RuntimeError()" + lines.append(line) + else: + variable_index, operand = operand1, operand2 + line = f"{indent}v{variable_index} {op} {operand}" + lines.append(line) + + # Decrement nesting. + for nesting_index in range(len(nesting_lines_remaining)): + nesting_lines_remaining[nesting_index] -= 1 + while nesting_lines_remaining and nesting_lines_remaining[-1] == 0: + nesting_lines_remaining.pop() + used_variables.pop() + + return "\n".join(lines) diff --git a/core/data/generation/python_interpreter.py b/core/data/generation/python_interpreter.py new file mode 100644 index 00000000..83b2f17f --- /dev/null +++ b/core/data/generation/python_interpreter.py @@ -0,0 +1,185 @@ +"""Python interpreter that operates on control flow graphs.""" + +import math +from absl import logging # pylint: disable=unused-import +import astunparse +import gast as ast +import tree + + +class ExecExecutor(object): + """A Python executor that uses exec. + + Potentially unsafe; use only with trusted code. + """ + + def __init__(self): + self.locals = {} + + def execute(self, code): + exec(code, # pylint:disable=exec-used + {'__builtins__': {'True': True, 'False': False, + 'range': range, + 'sqrt': math.sqrt, + 'AssertionError': AssertionError, + 'RuntimeError': RuntimeError, + 'len': len, + }}, + self.locals) + + def get_values(self, mod=None): + """Gets the values (mod `mod`, if applicable) of the executor.""" + values = self.locals.copy() + if mod is not None: + values = tree.map_structure(lambda x: x % mod, values) + return values + + +def evaluate_cfg(executor, cfg, mod=None, initial_values=None, timeout=None): + """Evaluates a Python program given its control flow graph. + + Args: + executor: The executor with which to perform the execution. + cfg: The control flow graph of the program to execute. + mod: The values are computed mod this. + initial_values: Optional dictionary mapping variable names to values. + timeout: Optional maximum number of basic blocks to evaluate before + raising a timeout RuntimeError. + Returns: + A values dictionary mapping variable names to their final values. + Raises: + RuntimeError: If timeout is given and the program runs for more than + `timeout` blocks, a RuntimeError is raised. + """ + executor.locals = {} + block = cfg.start_block + values = initial_values or {} # Default to no initial values. + blocks_evaluated = 0 + while block: + if timeout and blocks_evaluated > timeout: + raise RuntimeError('Evaluation of CFG has timed out.') + block, values = evaluate_until_next_basic_block( + executor, block, mod=mod, values=values) + blocks_evaluated += 1 + return values + + +def evaluate_until_next_basic_block(executor, basic_block, mod, values): + """Takes a single step of control flow graph evaluation. + + Evaluates a single basic block starting from the provided values. Returns + the correct next basic block to step to and the new values of all the + variables. + + Args: + executor: The executor with which to take a step of execution. + basic_block: (control_flow.BasicBlock) A single basic block from the control + flow graph. + mod: The values are computed mod this. + values: A dict mapping variable names to literal Python values. + Returns: + The next basic block to execute and the new mapping from variable names to + values. + """ + values = evaluate_basic_block(executor, basic_block, mod=mod, values=values) + if not basic_block.exits_from_end: + # This is the end of the program. + return None, values + elif len(basic_block.exits_from_end) == 1: + # TODO(dbieber): Modify control_flow.BasicBlock API to have functions + # `has_only_one_exit` and `get_only_exit`. + basic_block = next(iter(basic_block.exits_from_end)) + else: + assert len(basic_block.exits_from_end) == 2 + assert len(basic_block.branches) == 2, basic_block.branches + assert 'vBranch' in values + branch_decision = bool(values['vBranch']) + basic_block = basic_block.branches[branch_decision] + return basic_block, values + + +def evaluate_until_branch_decision(executor, basic_block, mod, values): + """Evaluates a Python program until reaching a branch decision. + + Evaluates one basic block at a time until a branch decision is reached. + Returns the resulting values of the variables, the instructions executed, + and the branch decision. The branch decision is represented as a dict mapping + True/False to the next basic block after the branch decision. + + Args: + executor: The executor with which to perform the execution. + basic_block: A single basic block from the control flow graph. + mod: The values are computed mod this. + values: A dict mapping variable names to literal Python values. + Returns: + A triple (values, instructions, branches). `values` is the resulting values + of the variables. `instructions` is a list of the instructions executed, + and `branches` is the branch decision reached, represented as a dict mapping + True/False to the next basic block after the branch decision. + """ + instructions = [] + + done = False + branches = None + while not done: + # Collect the instructions from the current block. + nodes = basic_block.control_flow_nodes + for node in nodes: + instructions.append(node.instruction) + + # Evaluate the current block. + values = evaluate_basic_block(executor, basic_block, mod=mod, values=values) + + # Determine next block to evaluate or whether to stop. + # TODO(dbieber): Refactor to reduce redundancy with + # evaluate_until_next_basic_block. + if not basic_block.exits_from_end: + # The program has terminated. + done = True + branches = None + elif len(basic_block.exits_from_end) == 1: + # There is no branch decision at this point. Keep evaluating. + basic_block = next(iter(basic_block.exits_from_end)) + else: + # Evaluation has reached a branch decision. + assert len(basic_block.exits_from_end) == 2 + assert len(basic_block.branches) == 2 + assert 'vBranch' in values + done = True + branches = basic_block.branches + + return values, instructions, branches + + +def evaluate_basic_block(executor, basic_block, mod, values): + """Evaluates a single basic block of Python with an executor. + + Args: + executor: The executor with which to perform the execution. + basic_block: A control_flow.BasicBlock of Python statements. + mod: The values are computed mod this. + values: A dictionary mapping variable names to their Python literal values. + Returns: + A dictionary mapping variable names to their final values at the end of + evaluating the basic block. + """ + + for var, value in values.items(): + python_source = f'{var} = {value}' + executor.execute(python_source) + + nodes = basic_block.control_flow_nodes + for index, node in enumerate(nodes): + instruction = node.instruction + ast_node = instruction.node + python_source = astunparse.unparse(ast_node) + + make_branch_decision = (index == len(nodes) - 1 and basic_block.branches) + if make_branch_decision: + python_source = 'vBranch = ' + python_source + + executor.execute(python_source) + + # Extract the values of the v0, v1... variables. + values = executor.get_values(mod=mod) + return values