diff --git a/examples/generate/generate_masked_fill_in_blank_qa/README.md b/examples/generate/generate_masked_fill_in_blank_qa/README.md
new file mode 100644
index 00000000..3251d5b9
--- /dev/null
+++ b/examples/generate/generate_masked_fill_in_blank_qa/README.md
@@ -0,0 +1,2 @@
+# Generate Masked Fill-in-blank QAs
+# TODO
diff --git a/examples/generate/generate_masked_fill_in_blank_qa/generate_masked_fill_in_blank.sh b/examples/generate/generate_masked_fill_in_blank_qa/generate_masked_fill_in_blank.sh
new file mode 100644
index 00000000..c974bffa
--- /dev/null
+++ b/examples/generate/generate_masked_fill_in_blank_qa/generate_masked_fill_in_blank.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
diff --git a/examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml b/examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
new file mode 100644
index 00000000..f6b0a63b
--- /dev/null
+++ b/examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
@@ -0,0 +1,54 @@
+global_params:
+  working_dir: cache
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
+
+nodes:
+  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples
+
+  - id: chunk_documents
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read_files
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk_documents
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: partition
+    op_name: partition
+    type: aggregate
+    dependencies:
+      - build_kg
+    params:
+      method: quintuple
+
+  - id: generate
+    op_name: generate
+    type: map_batch
+    dependencies:
+      - partition
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true # save output
+    params:
+      method: masked_fill_in_blank # atomic, aggregated, multi_hop, cot, vqa
+      data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs
diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py
index eb204535..b83be604 100644
--- a/graphgen/bases/base_generator.py
+++ b/graphgen/bases/base_generator.py
@@ -74,4 +74,10 @@ def format_generation_results(
                 {"role": "assistant", "content": answer},
             ]
         }
+
+        if output_data_format == "QA_pairs":
+            return {
+                "question": question,
+                "answer": answer,
+            }
         raise ValueError(f"Unknown output data format: {output_data_format}")
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index 95ccd1ae..8bd2c9d8 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -15,6 +15,7 @@
     AtomicGenerator,
     CoTGenerator,
     FillInBlankGenerator,
+    MaskedFillInBlankGenerator,
     MultiAnswerGenerator,
     MultiChoiceGenerator,
     MultiHopGenerator,
@@ -30,6 +31,8 @@
     DFSPartitioner,
     ECEPartitioner,
     LeidenPartitioner,
+    QuintuplePartitioner,
+    TriplePartitioner,
 )
 from .reader import (
     CSVReader,
@@ -73,6 +76,7 @@
     "QuizGenerator": ".generator",
     "TrueFalseGenerator": ".generator",
     "VQAGenerator": ".generator",
+    "MaskedFillInBlankGenerator": ".generator",
     # KG Builder
     "LightRAGKGBuilder": ".kg_builder",
     "MMKGBuilder": ".kg_builder",
@@ -86,6 +90,8 @@
     "DFSPartitioner": ".partitioner",
     "ECEPartitioner": ".partitioner",
     "LeidenPartitioner": ".partitioner",
+    "TriplePartitioner": ".partitioner",
+    "QuintuplePartitioner": ".partitioner",
     # Reader
     "CSVReader": ".reader",
     "JSONReader": ".reader",
diff --git a/graphgen/models/generator/__init__.py b/graphgen/models/generator/__init__.py
index 8562c34b..6fd25629 100644
--- a/graphgen/models/generator/__init__.py
+++ b/graphgen/models/generator/__init__.py
@@ -8,3 +8,4 @@
 from .quiz_generator import QuizGenerator
 from .true_false_generator import TrueFalseGenerator
 from .vqa_generator import VQAGenerator
+from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
diff --git a/graphgen/models/generator/masked_fill_in_blank_generator.py b/graphgen/models/generator/masked_fill_in_blank_generator.py
new file mode 100644
index 00000000..bc6f06c8
--- /dev/null
+++ b/graphgen/models/generator/masked_fill_in_blank_generator.py
@@ -0,0 +1,134 @@
+import random
+import re
+from typing import Any, Optional
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import AGGREGATED_GENERATION_PROMPT
+from graphgen.utils import detect_main_language, logger
+
+random.seed(42)
+
+
+class MaskedFillInBlankGenerator(BaseGenerator):
+    """
+    Masked Fill-in-blank Generator follows a TWO-STEP process:
+    1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning.
+    2. mask: Randomly select a node from the input nodes, and then mask the name of the node in the rephrased text.
+    """
+
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """
+        Build prompts for REPHRASE.
+        :param batch
+        :return:
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relations_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relations_str)
+
+        # TODO: configure add_context
+        # if add_context:
+        #     original_ids = [
+        #         node["source_id"].split("<SEP>")[0] for node in _process_nodes
+        #     ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
+        #     original_ids = list(set(original_ids))
+        #     original_text = await text_chunks_storage.get_by_ids(original_ids)
+        #     original_text = "\n".join(
+        #         [
+        #             f"{index + 1}. {text['content']}"
+        #             for index, text in enumerate(original_text)
+        #         ]
+        #     )
+        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
+            entities=entities_str, relationships=relations_str
+        )
+        return prompt
+
+    @staticmethod
+    def parse_rephrased_text(response: str) -> Optional[str]:
+        """
+        Parse the rephrased text from the response.
+        :param response:
+        :return: rephrased text
+        """
+        rephrased_match = re.search(
+            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
+        )  # NOTE(review): tag literals restored after paste-mangling; confirm against the ANSWER_REPHRASING template
+        if rephrased_match:
+            rephrased_text = rephrased_match.group(1).strip()
+        else:
+            logger.warning("Failed to parse rephrased text from response: %s", response)
+            return None
+        return rephrased_text.strip('"').strip("'")
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        pass
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> list[dict]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        rephrasing_prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(rephrasing_prompt)
+        context = self.parse_rephrased_text(response)
+        if not context:
+            return []
+
+        nodes, edges = batch
+
+        assert len(nodes) == 3, (
+            "MaskedFillInBlankGenerator currently only supports quintuples that has 3 nodes, "
+            f"but got {len(nodes)} nodes."
+        )
+        assert len(edges) == 2, (
+            "MaskedFillInBlankGenerator currently only supports quintuples that has 2 edges, "
+            f"but got {len(edges)} edges."
+        )
+
+        node1, node2, node3 = nodes
+        mask_node = random.choice([node1, node2, node3])
+        mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
+        mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
+
+        match = re.search(mask_pattern, context)
+        if match:
+            gth = match.group(0)
+            masked_context = mask_pattern.sub("___", context)
+        else:
+            logger.debug(
+                "Regex Match Failed!\n"
+                "Expected name of node: %s\n"
+                "Actual context: %s\n",
+                mask_node_name,
+                context,
+            )
+            return []
+
+        logger.debug("masked_context: %s", masked_context)
+        qa_pairs = {
+            "question": masked_context,
+            "answer": gth,
+        }
+        return [qa_pairs]
diff --git a/graphgen/models/partitioner/__init__.py b/graphgen/models/partitioner/__init__.py
index 2e1bcb68..4306f247 100644
--- a/graphgen/models/partitioner/__init__.py
+++ b/graphgen/models/partitioner/__init__.py
@@ -3,3 +3,5 @@
 from .dfs_partitioner import DFSPartitioner
 from .ece_partitioner import ECEPartitioner
 from .leiden_partitioner import LeidenPartitioner
+from .quintuple_partitioner import QuintuplePartitioner
+from .triple_partitioner import TriplePartitioner
diff --git a/graphgen/models/partitioner/quintuple_partitioner.py b/graphgen/models/partitioner/quintuple_partitioner.py
new file mode 100644
index 00000000..7d570b28
--- /dev/null
+++ b/graphgen/models/partitioner/quintuple_partitioner.py
@@ -0,0 +1,74 @@
+import random
+from collections import deque
+from typing import Any, Iterable, Set
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+random.seed(42)
+
+
+class QuintuplePartitioner(BasePartitioner):
+    """
+    quintuple Partitioner that partitions the graph into multiple distinct quintuple (node, edge, node, edge, node).
+    1. Automatically ignore isolated points.
+    2. In each connected component, yield quintuples in the order of BFS.
+    """
+
+    def partition(
+        self,
+        g: BaseGraphStorage,
+        **kwargs: Any,
+    ) -> Iterable[Community]:
+        nodes = [n[0] for n in g.get_all_nodes()]
+        random.shuffle(nodes)
+
+        visited_nodes: Set[str] = set()
+        used_edges: Set[frozenset[str]] = set()
+
+        for seed in nodes:
+            if seed in visited_nodes:
+                continue
+
+            # start BFS in a connected component
+            queue = deque([seed])
+            visited_nodes.add(seed)
+
+            while queue:
+                u = queue.popleft()
+
+                # collect all neighbors connected to node u via unused edges
+                available_neighbors = []
+                for v in g.get_neighbors(u):
+                    edge_key = frozenset((u, v))
+                    if edge_key not in used_edges:
+                        available_neighbors.append(v)
+
+                    # standard BFS queue maintenance
+                    if v not in visited_nodes:
+                        visited_nodes.add(v)
+                        queue.append(v)
+
+                random.shuffle(available_neighbors)
+
+                # every two neighbors paired with the center node u creates one quintuple
+                # Note: If available_neighbors has an odd length, the remaining edge
+                # stays unused for now. It may be matched into a quintuple later
+                # when its other endpoint is processed as a center node.
+                for i in range(0, len(available_neighbors) // 2 * 2, 2):
+                    v1 = available_neighbors[i]
+                    v2 = available_neighbors[i + 1]
+
+                    edge1 = frozenset((u, v1))
+                    edge2 = frozenset((u, v2))
+
+                    used_edges.add(edge1)
+                    used_edges.add(edge2)
+
+                    v1_s, v2_s = sorted((v1, v2))
+
+                    yield Community(
+                        id=f"{v1_s}-{u}-{v2_s}",
+                        nodes=[v1_s, u, v2_s],
+                        edges=[tuple(sorted((v1_s, u))), tuple(sorted((u, v2_s)))],
+                    )
diff --git a/graphgen/models/partitioner/triple_partitioner.py b/graphgen/models/partitioner/triple_partitioner.py
new file mode 100644
index 00000000..2bdfe8d5
--- /dev/null
+++ b/graphgen/models/partitioner/triple_partitioner.py
@@ -0,0 +1,58 @@
+import random
+from collections import deque
+from typing import Any, Iterable, Set
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+random.seed(42)
+
+
+class TriplePartitioner(BasePartitioner):
+    """
+    Triple Partitioner that partitions the graph into multiple distinct triples (node, edge, node).
+    1. Automatically ignore isolated points.
+    2. In each connected component, yield triples in the order of BFS.
+    """
+
+    def partition(
+        self,
+        g: BaseGraphStorage,
+        **kwargs: Any,
+    ) -> Iterable[Community]:
+        nodes = [n[0] for n in g.get_all_nodes()]
+        random.shuffle(nodes)
+
+        visited_nodes: Set[str] = set()
+        used_edges: Set[frozenset[str]] = set()
+
+        for seed in nodes:
+            if seed in visited_nodes:
+                continue
+
+            # start BFS in a connected component
+            queue = deque([seed])
+            visited_nodes.add(seed)
+
+            while queue:
+                u = queue.popleft()
+
+                for v in g.get_neighbors(u):
+                    edge_key = frozenset((u, v))
+
+                    # if this edge has not been used, a new triple has been found
+                    if edge_key not in used_edges:
+                        used_edges.add(edge_key)
+
+                        # use the edge name to ensure the uniqueness of the ID
+                        u_sorted, v_sorted = sorted((u, v))
+                        yield Community(
+                            id=f"{u_sorted}-{v_sorted}",
+                            nodes=[u_sorted, v_sorted],
+                            edges=[(u_sorted, v_sorted)],
+                        )
+
+                    # continue to BFS
+                    if v not in visited_nodes:
+                        visited_nodes.add(v)
+                        queue.append(v)
diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py
index 1868a50e..18ce1d43 100644
--- a/graphgen/operators/generate/generate_service.py
+++ b/graphgen/operators/generate/generate_service.py
@@ -71,6 +71,10 @@ def __init__(
                 self.llm_client,
                 num_of_questions=generate_kwargs.get("num_of_questions", 5),
             )
+        elif self.method == "masked_fill_in_blank":
+            from graphgen.models import MaskedFillInBlankGenerator
+
+            self.generator = MaskedFillInBlankGenerator(self.llm_client)
         elif self.method == "true_false":
             from graphgen.models import TrueFalseGenerator
 
diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py
index dfadf8da..b55c5159 100644
--- a/graphgen/operators/partition/partition_service.py
+++ b/graphgen/operators/partition/partition_service.py
@@ -28,7 +28,7 @@ def __init__(
         self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
 
         method = partition_kwargs["method"]
-        self.method_params = partition_kwargs["method_params"]
+        self.method_params = partition_kwargs.get("method_params", {})
 
         if method == "bfs":
             from graphgen.models import BFSPartitioner
@@ -57,6 +57,14 @@ def __init__(
             if self.method_params.get("anchor_ids")
             else None,
         )
+        elif method == "triple":
+            from graphgen.models import TriplePartitioner
+
+            self.partitioner = TriplePartitioner()
+        elif method == "quintuple":
+            from graphgen.models import QuintuplePartitioner
+
+            self.partitioner = QuintuplePartitioner()
         else:
             raise ValueError(f"Unsupported partition method: {method}")
 