Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/generate/generate_masked_fill_in_blank_qa/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Generate Masked Fill-in-blank QAs

Generates masked fill-in-blank question-answer pairs from a knowledge graph.
The pipeline reads input documents, chunks them, builds a knowledge graph,
partitions it into quintuples (node, edge, node, edge, node), rephrases each
quintuple into a coherent text, and masks one entity name with `___` to form
a question whose answer is the masked name.

Run with:

```bash
python3 -m graphgen.run \
    --config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The README file currently only contains "TODO". To make this example useful for other developers and users, please add a brief description of what this feature does, what the configuration options mean, and how to run it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
python3 -m graphgen.run \
--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Pipeline configuration for generating masked fill-in-blank QAs:
# read -> chunk -> build_kg -> partition (quintuple) -> generate.
global_params:
  working_dir: cache # directory for intermediate/cache artifacts
  graph_backend: networkx # graph database backend, support: kuzu, networkx
  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

nodes:
  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies:
      - read_files
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024 # chunk size for text splitting
      chunk_overlap: 100 # chunk overlap for text splitting

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      - chunk_documents
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies:
      - build_kg
    params:
      method: quintuple # a quintuple is a (node, edge, node, edge, node) subgraph

  - id: generate
    op_name: generate
    type: map_batch
    dependencies:
      - partition
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true # save output
    params:
      method: masked_fill_in_blank # also supported: atomic, aggregated, multi_hop, cot, vqa, true_false
      data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs
6 changes: 6 additions & 0 deletions graphgen/bases/base_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,10 @@ def format_generation_results(
{"role": "assistant", "content": answer},
]
}

if output_data_format == "QA_pairs":
return {
"question": question,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a trailing whitespace after question,. Please remove it to maintain code style consistency.

Suggested change
"question": question,
"question": question,

"answer": answer,
}
raise ValueError(f"Unknown output data format: {output_data_format}")
6 changes: 6 additions & 0 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AtomicGenerator,
CoTGenerator,
FillInBlankGenerator,
MaskedFillInBlankGenerator,
MultiAnswerGenerator,
MultiChoiceGenerator,
MultiHopGenerator,
Expand All @@ -30,6 +31,8 @@
DFSPartitioner,
ECEPartitioner,
LeidenPartitioner,
QuintuplePartitioner,
TriplePartitioner,
)
from .reader import (
CSVReader,
Expand Down Expand Up @@ -73,6 +76,7 @@
"QuizGenerator": ".generator",
"TrueFalseGenerator": ".generator",
"VQAGenerator": ".generator",
"MaskedFillInBlankGenerator": ".generator",
# KG Builder
"LightRAGKGBuilder": ".kg_builder",
"MMKGBuilder": ".kg_builder",
Expand All @@ -86,6 +90,8 @@
"DFSPartitioner": ".partitioner",
"ECEPartitioner": ".partitioner",
"LeidenPartitioner": ".partitioner",
"TriplePartitioner": ".partitioner",
"QuintuplePartitioner": ".partitioner",
# Reader
"CSVReader": ".reader",
"JSONReader": ".reader",
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .quiz_generator import QuizGenerator
from .true_false_generator import TrueFalseGenerator
from .vqa_generator import VQAGenerator
from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
134 changes: 134 additions & 0 deletions graphgen/models/generator/masked_fill_in_blank_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import random
import re
from typing import Any, Optional

from graphgen.bases import BaseGenerator
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger

random.seed(42)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting a global random seed with random.seed(42) is generally discouraged as it affects the entire application's random number generation, which can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local random.Random instance within your class, for example in the __init__ method, and use that for random operations like random.choice on line 103.



class MaskedFillInBlankGenerator(BaseGenerator):
"""
Masked Fill-in-blank Generator follows a TWO-STEP process:
1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning.
2. mask: Randomly select a node from the input nodes, and then mask the name of the node in the rephrased text.
"""

@staticmethod
def build_prompt(
    batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
) -> str:
    """
    Build the REPHRASE prompt for a (nodes, edges) batch.

    Nodes are rendered as "<index>. <name>: <description>" and edges as
    "<index>. <src> -- <dst>: <description>"; the prompt template language
    is chosen by detecting the main language of the rendered text.

    :param batch: tuple of (nodes, edges); each node is (name, attrs) and
        each edge is (src, dst, attrs), where attrs carries "description".
    :return: the formatted rephrasing prompt string.
    """
    nodes, edges = batch
    entities_str = "\n".join(
        [
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        ]
    )
    relations_str = "\n".join(
        [
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        ]
    )
    language = detect_main_language(entities_str + relations_str)

    # TODO(add_context): optionally append the original source chunks
    # (fetched from text_chunks_storage by source_id) to the prompt.
    # The earlier commented-out draft was removed; track it in an issue.
    prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
        entities=entities_str, relationships=relations_str
    )
    return prompt

@staticmethod
def parse_rephrased_text(response: str) -> Optional[str]:
"""
Parse the rephrased text from the response.
:param response:
:return: rephrased text
"""
rephrased_match = re.search(
r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
)
if rephrased_match:
rephrased_text = rephrased_match.group(1).strip()
else:
logger.warning("Failed to parse rephrased text from response: %s", response)
return None
return rephrased_text.strip('"').strip("'")

@staticmethod
def parse_response(response: str) -> dict:
pass
Comment on lines +78 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The parse_response method is defined as an abstract method in the BaseGenerator class but is implemented with pass here. Additionally, the return type hint dict is incompatible with the base class's list[dict]. Since this method is not used in the overridden generate method, it should either be implemented correctly or raise NotImplementedError to adhere to the abstract base class contract.

Suggested change
@staticmethod
def parse_response(response: str) -> dict:
pass
@staticmethod
def parse_response(response: str) -> list[dict]:
raise NotImplementedError("This method is not used in MaskedFillInBlankGenerator as it overrides the `generate` method.")


async def generate(
    self,
    batch: tuple[
        list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
    ],
) -> list[dict]:
    """
    Generate masked fill-in-blank QA pairs from one quintuple batch.

    Two steps:
    1. REPHRASE: prompt the LLM (via `build_prompt`) to rewrite the batch's
       nodes and edges into a coherent text; parse it with
       `parse_rephrased_text`.
    2. MASK: pick one of the three nodes at random and replace every
       occurrence of its entity name in the rephrased text with "___".

    :param batch: (nodes, edges) — exactly 3 nodes and 2 edges (a quintuple).
    :return: a single-element list [{"question": masked_text, "answer": name}],
        or [] when rephrasing or masking fails.
    """
    rephrasing_prompt = self.build_prompt(batch)
    response = await self.llm_client.generate_answer(rephrasing_prompt)
    context = self.parse_rephrased_text(response)
    if not context:
        # LLM response did not contain a parsable <rephrased_text> block.
        return []

    nodes, edges = batch

    # NOTE(review): `assert` is stripped under `python -O`; consider raising
    # ValueError for input validation instead.
    assert len(nodes) == 3, (
        "MaskedFillInBlankGenerator currently only supports quintuples that has 3 nodes, "
        f"but got {len(nodes)} nodes."
    )
    assert len(edges) == 2, (
        "MaskedFillInBlankGenerator currently only supports quintuples that has 2 edges, "
        f"but got {len(edges)} edges."
    )

    node1, node2, node3 = nodes
    # NOTE(review): this uses the module-global `random` (seeded at import
    # time); a per-instance random.Random would avoid mutating global state.
    mask_node = random.choice([node1, node2, node3])
    # Normalize the stored entity name before building the mask pattern.
    mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
    mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)

    match = re.search(mask_pattern, context)
    if match:
        # Use the surface form actually found in the text as the answer:
        # it may differ in case from the stored entity name.
        gth = match.group(0)
        masked_context = mask_pattern.sub("___", context)
    else:
        logger.debug(
            "Regex Match Failed!\n"
            "Expected name of node: %s\n"
            "Actual context: %s\n",
            mask_node_name,
            context,
        )
        return []

    logger.debug("masked_context: %s", masked_context)
    qa_pairs = {
        "question": masked_context,
        "answer": gth,
    }
    return [qa_pairs]
2 changes: 2 additions & 0 deletions graphgen/models/partitioner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from .dfs_partitioner import DFSPartitioner
from .ece_partitioner import ECEPartitioner
from .leiden_partitioner import LeidenPartitioner
from .quintuple_partitioner import QuintuplePartitioner
from .triple_partitioner import TriplePartitioner
74 changes: 74 additions & 0 deletions graphgen/models/partitioner/quintuple_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

random.seed(42)


class QuintuplePartitioner(BasePartitioner):
    """
    Quintuple partitioner: splits the graph into distinct quintuples
    (node, edge, node, edge, node).

    1. Isolated nodes are ignored automatically (they yield no usable edges).
    2. Within each connected component, quintuples are yielded in BFS order.

    Randomness comes from a private ``random.Random`` instance so the
    process-wide ``random`` state is never mutated; pass ``seed=<int>`` via
    kwargs to override the reproducible default of 42.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        # Local RNG: reproducible without touching the global `random` module.
        rng = random.Random(kwargs.get("seed", 42))

        nodes = [n[0] for n in g.get_all_nodes()]
        rng.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for start_node in nodes:
            if start_node in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([start_node])
            visited_nodes.add(start_node)

            while queue:
                u = queue.popleft()

                # collect all neighbors connected to node u via unused edges
                available_neighbors = []
                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))
                    if edge_key not in used_edges:
                        available_neighbors.append(v)

                    # standard BFS queue maintenance
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)

                rng.shuffle(available_neighbors)

                # every two neighbors paired with the center node u creates one quintuple
                # Note: If available_neighbors has an odd length, the remaining edge
                # stays unused for now. It may be matched into a quintuple later
                # when its other endpoint is processed as a center node.
                for i in range(0, len(available_neighbors) // 2 * 2, 2):
                    v1 = available_neighbors[i]
                    v2 = available_neighbors[i + 1]

                    used_edges.add(frozenset((u, v1)))
                    used_edges.add(frozenset((u, v2)))

                    v1_s, v2_s = sorted((v1, v2))

                    yield Community(
                        id=f"{v1_s}-{u}-{v2_s}",
                        nodes=[v1_s, u, v2_s],
                        edges=[tuple(sorted((v1_s, u))), tuple(sorted((u, v2_s)))],
                    )
58 changes: 58 additions & 0 deletions graphgen/models/partitioner/triple_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

random.seed(42)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting a global random seed with random.seed(42) is generally discouraged as it affects the entire application's random number generation. This can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local random.Random instance within your class, for example in the __init__ method, and use that for random operations like random.shuffle.



class TriplePartitioner(BasePartitioner):
    """
    Triple partitioner: splits the graph into distinct triples
    (node, edge, node).

    1. Isolated nodes are ignored automatically (they yield no edges).
    2. Within each connected component, triples are yielded in BFS order.

    Randomness comes from a private ``random.Random`` instance so the
    process-wide ``random`` state is never mutated; pass ``seed=<int>`` via
    kwargs to override the reproducible default of 42.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        # Local RNG: reproducible without touching the global `random` module.
        rng = random.Random(kwargs.get("seed", 42))

        nodes = [n[0] for n in g.get_all_nodes()]
        rng.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for start_node in nodes:
            if start_node in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([start_node])
            visited_nodes.add(start_node)

            while queue:
                u = queue.popleft()

                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))

                    # if this edge has not been used, a new triple has been found
                    if edge_key not in used_edges:
                        used_edges.add(edge_key)

                        # sort the endpoints to ensure the uniqueness of the ID
                        u_sorted, v_sorted = sorted((u, v))
                        yield Community(
                            id=f"{u_sorted}-{v_sorted}",
                            nodes=[u_sorted, v_sorted],
                            edges=[(u_sorted, v_sorted)],
                        )

                    # continue to BFS
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)
4 changes: 4 additions & 0 deletions graphgen/operators/generate/generate_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def __init__(
self.llm_client,
num_of_questions=generate_kwargs.get("num_of_questions", 5),
)
elif self.method == "masked_fill_in_blank":
from graphgen.models import MaskedFillInBlankGenerator

self.generator = MaskedFillInBlankGenerator(self.llm_client)
elif self.method == "true_false":
from graphgen.models import TrueFalseGenerator

Expand Down
Loading