# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import functools
import os

from utils import Tools, FilePathBuilder, CodexTokenizer, CodeGenTokenizer, CONSTANTS

class PromptBuilder:
    """Builds second-stage prompts by prepending retrieved code fragments,
    rendered as comment blocks, to each benchmark task's original prompt.
    """

    def __init__(self, query_lines_with_retrieval_results, task_path, log_message, tokenizer):
        """
        :param query_lines_with_retrieval_results: query windows, each carrying a
            'top_k_context' list of (content, sim_score) pairs ordered ascending by relevance.
        :param task_path: path to the benchmark task jsonl file.
        :param log_message: message printed once prompt building finishes.
        :param tokenizer: the tokenizer CLASS (CodexTokenizer or CodeGenTokenizer), not an instance.
        :raises ValueError: if the tokenizer class is not supported.
        """
        self.query_lines_with_retrieval_results = query_lines_with_retrieval_results
        self.log_message = log_message
        if tokenizer == CodexTokenizer:
            self.tokenizer = CodexTokenizer()
            self.max_retrieval_length = 2000  # half of the max length of the model
        elif tokenizer == CodeGenTokenizer:
            self.tokenizer = CodeGenTokenizer()
            self.max_retrieval_length = 1000
        else:
            # previously an unsupported tokenizer silently left self.tokenizer and
            # self.max_retrieval_length unset, failing later with AttributeError
            raise ValueError(f'unsupported tokenizer: {tokenizer}')
        tasks = Tools.load_jsonl(task_path)
        self.tasks_by_task_id = {task['metadata']['task_id']: task for task in tasks}
        self.seperator = '# ' + '-' * 50  # (sic) attribute name kept for compatibility
        self.max_examples = 10  # maximum number of examples to be included in the prompt

    def _render_comment_block(self, metadata, content_lines):
        """Render one retrieved fragment as a commented block and measure it.

        Shared by _make_a_block and _make_an_extended_block (previously duplicated).
        :param metadata: list of window-metadata dicts; each has 'fpath_tuple' and 'repo'.
        :param content_lines: the code lines to embed, without trailing newlines.
        :return: (block_str, token_len) — the rendered text and its token count.
        """
        # the first element of fpath_tuple is the repo name itself; drop it from the shown path
        assert metadata[0]['fpath_tuple'][0] == metadata[0]['repo']
        f_paths = ['/'.join(x['fpath_tuple'][1:]) for x in metadata]
        f_paths_str = '\n'.join(f'# {f_path}' for f_path in f_paths)
        f_path_comment = '# the below code fragment can be found in:'
        content_lines_comment = [f'# {line}' for line in content_lines]
        block_str = '\n'.join(
            [f_path_comment, f_paths_str, self.seperator] + content_lines_comment + [self.seperator]
        ) + '\n'
        token_len = len(self.tokenizer.tokenize(block_str))
        return block_str, token_len

    def _make_a_block(self, retrieved_context):
        """Render a retrieved (content, sim_score) pair exactly as retrieved."""
        content, _sim_score = retrieved_context
        return self._render_comment_block(content['metadata'], content['context'].splitlines())

    def _make_an_extended_block(self, retrieved_context):
        """Like _make_a_block, but re-reads the original file and shifts the
        retrieved window forward by one slice (window_size // slice_size lines),
        clamped to the end of the file.
        """
        content, _sim_score = retrieved_context
        metadata = content['metadata']
        original_code = Tools.read_code(
            os.path.join(FilePathBuilder.repo_base_dir, *metadata[0]['fpath_tuple']))
        code_lines = original_code.splitlines()
        end_line_no = metadata[0]['end_line_no']
        window_size = metadata[0]['window_size']
        slice_size = metadata[0]['slice_size']
        new_end_line_no = min(end_line_no + window_size // slice_size, len(code_lines))
        new_start_line_no = max(0, new_end_line_no - window_size)
        return self._render_comment_block(metadata, code_lines[new_start_line_no:new_end_line_no])

    def _build_prompt(self, mode, prompt, top_k_context):
        """Prepend as many retrieved blocks as fit within the token budget.

        top_k_context is ordered ascending by relevance, so we walk it in reverse
        (most relevant first) and insert at the front of the prepended text, keeping
        the most relevant block closest to the original prompt.
        :return: (new_prompt, chosen_context) — chosen_context is most-relevant-first.
        """
        prepend_context = "# Here are some relevant code fragments from other files of the repo:\n"
        prepend_context += self.seperator + '\n'
        current_token_length = 20  # the length of the head_prompt, same for codex and codegen tokenizer
        prepend_blocks = []
        chosen_context = []
        make_block_func = self._make_an_extended_block if mode == CONSTANTS.rg else self._make_a_block
        for retrieved_context in top_k_context[::-1]:
            if len(chosen_context) >= self.max_examples:
                break
            block_str, token_len = make_block_func(retrieved_context)
            # skip (not stop at) over-budget blocks so one long block
            # does not shadow later short ones
            if current_token_length + token_len < self.max_retrieval_length:
                prepend_blocks.insert(0, block_str)
                current_token_length += token_len
                chosen_context.append(retrieved_context)
        prepend_context += ''.join(prepend_blocks)  # all the blocks already have a line break at the end
        return prepend_context + '\n' + prompt, chosen_context

    def build_2nd_stage_input_file(self, mode):
        """Build the augmented prompt line for every query window.

        :param mode: retrieval mode; CONSTANTS.rg selects extended blocks.
        :return: list of {'prompt', 'metadata'} dicts ready to dump as jsonl.
        """
        new_prompt_lines = []
        for query_line in self.query_lines_with_retrieval_results:
            task_id = query_line['metadata']['task_id']
            task = self.tasks_by_task_id[task_id]
            new_prompt, chosen_context = self._build_prompt(mode, task['prompt'], query_line['top_k_context'])
            # shallow-copy the task metadata: the original code mutated task['metadata']
            # in place, corrupting the cached task dict in self.tasks_by_task_id
            metadata = dict(task['metadata'])
            metadata['query_window'] = {
                'context': query_line['context'],
                'metadata': query_line['metadata'],
            }
            metadata['top_k_context'] = [
                {
                    'context': content['context'],
                    'metadata': content['metadata'],
                    'sim_score': sim_score,
                } for content, sim_score in chosen_context
            ]
            metadata['window_size'] = query_line['metadata']['window_size']
            # guard: every candidate block may have exceeded the token budget,
            # in which case chosen_context is empty (previously an IndexError)
            if chosen_context:
                metadata['slice_size'] = chosen_context[0][0]['metadata'][0]['slice_size']
            new_prompt_lines.append({'prompt': new_prompt, 'metadata': metadata})
        print('done! ' + self.log_message)
        return new_prompt_lines
| 114 | + |
class BuildPromptWrapper:
    """Drives PromptBuilder over a set of repos: resolves the retrieval-result
    file for each repo, builds the augmented prompts, and dumps them as jsonl.
    """

    def __init__(self, vectorizer, benchmark, repos, window_size, slice_size, tokenizer):
        """
        :param vectorizer: 'one-gram' or 'ada002'; selects the vector path builder.
        :param benchmark: one of the CONSTANTS benchmark names; selects the task file.
        :param repos: iterable of repo names to process.
        :param window_size: retrieval window size in lines.
        :param slice_size: number of slices per window.
        :param tokenizer: tokenizer CLASS forwarded to PromptBuilder.
        :raises ValueError: on an unknown vectorizer or benchmark.
        """
        if vectorizer == 'one-gram':
            self.vector_path_builder = FilePathBuilder.one_gram_vector_path
        elif vectorizer == 'ada002':
            self.vector_path_builder = FilePathBuilder.ada002_vector_path
        else:
            # previously an unknown vectorizer silently left the attribute unset
            raise ValueError(f'unsupported vectorizer: {vectorizer}')
        self.max_top_k = 20  # number of retrieval results expected in each results file
        self.repos = repos
        self.window_size = window_size
        self.slice_size = slice_size
        if benchmark == CONSTANTS.line_benchmark:
            self.task_path = FilePathBuilder.random_line_completion_benchmark
        elif benchmark == CONSTANTS.api_benchmark:
            self.task_path = FilePathBuilder.api_completion_benchmark
        elif benchmark == CONSTANTS.short_api_benchmark:
            self.task_path = FilePathBuilder.short_api_completion_benchmark
        elif benchmark == CONSTANTS.short_line_benchmark:
            self.task_path = FilePathBuilder.short_random_line_completion_benchmark
        else:
            # previously an unknown benchmark silently left task_path unset
            raise ValueError(f'unsupported benchmark: {benchmark}')
        self.benchmark = benchmark
        self.tokenizer = tokenizer

    def _run(self, mode, query_window_path_builder, output_file_path):
        """Build prompts for every repo and dump the combined lines to jsonl.

        :param query_window_path_builder: callable (repo, window_size) -> query window path.
        """
        workers = []
        for repo in self.repos:
            query_window_path = query_window_path_builder(repo, self.window_size)
            query_line_path = self.vector_path_builder(query_window_path)
            repo_window_path = FilePathBuilder.repo_windows_path(repo, self.window_size, self.slice_size)
            repo_embedding_path = self.vector_path_builder(repo_window_path)
            retrieval_results = FilePathBuilder.retrieval_results_path(
                query_line_path, repo_embedding_path, self.max_top_k)

            query_lines_with_retrieval_results = Tools.load_pickle(retrieval_results)
            log_message = f'repo: {repo}, window: {self.window_size}, slice: {self.slice_size}'
            workers.append(PromptBuilder(
                query_lines_with_retrieval_results, self.task_path, log_message, self.tokenizer))
        lines = []
        for worker in workers:
            lines += worker.build_2nd_stage_input_file(mode)
        Tools.dump_jsonl(lines, output_file_path)

    def build_first_search_prompt(self, mode, output_path):
        """Build prompts whose query windows come from the search-first stage."""
        query_line_path_temp = functools.partial(FilePathBuilder.search_first_window_path, self.benchmark, mode)
        self._run(mode, query_line_path_temp, output_path)

    def build_prediction_prompt(self, mode, prediction_path, output_path):
        """Build prompts whose query windows come from a previous generation's predictions."""
        query_line_path_temp = functools.partial(
            FilePathBuilder.gen_first_window_path, self.benchmark, mode, prediction_path)
        self._run(mode, query_line_path_temp, output_path)
| 162 | + |