2 changes: 1 addition & 1 deletion graphgen/bases/datatypes.py
@@ -15,7 +15,7 @@ def from_dict(key: str, data: dict) -> "Chunk":
         return Chunk(
             id=key,
             content=data.get("content", ""),
-            type=data.get("type", "unknown"),
+            type=data.get("type", "text"),
             metadata={k: v for k, v in data.items() if k != "content"},
         )
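
For illustration, a short sketch of what the new default means for callers (inputs are hypothetical):

from graphgen.bases import Chunk

# Without an explicit "type", chunks are now treated as text rather than "unknown".
plain = Chunk.from_dict("doc-1", {"content": "BRCA1 is a tumor suppressor gene."})
assert plain.type == "text"

# Typed chunks keep their type, and every key except "content" lands in metadata.
protein = Chunk.from_dict("doc-2", {"type": "protein", "protein_caption": {"Entry": "P38398"}})
assert protein.type == "protein" and "protein_caption" in protein.metadata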

19 changes: 19 additions & 0 deletions graphgen/configs/protein_qa_config.yaml
@@ -0,0 +1,19 @@
read:
  input_file: resources/input_examples/protein_qa_demo.json # input file path; supports json, jsonl, txt, pdf. See resources/input_examples for examples
  anchor_type: protein # get protein information from chunks
split:
  chunk_size: 1024 # chunk size for text splitting
  chunk_overlap: 100 # chunk overlap for text splitting
search: # web search configuration
  enabled: false # whether to enable web search
  search_types: ["google"] # search engine types; supported: google, bing, uniprot, wikipedia
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
  enabled: false
partition: # graph partition configuration
  method: anchor_bfs # partition method
  method_params:
    anchor_type: protein # node type used to select anchor nodes
    max_units_per_community: 10 # max units per community (set to 1 for atomic partition: one node or edge per community)
generate:
  mode: protein_qa # atomic, aggregated, multi_hop, cot, vqa
  data_format: ChatML # Alpaca, ShareGPT, ChatML
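
A minimal sketch of consuming this config (assumes PyYAML; the loader here is illustrative, not GraphGen's actual CLI):

import yaml

with open("graphgen/configs/protein_qa_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# The anchor type declared under `read` should match the one the partitioner selects on.
assert (
    config["read"]["anchor_type"]
    == config["partition"]["method_params"]["anchor_type"]
)
print(config["generate"]["mode"])  # protein_qa
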
138 changes: 37 additions & 101 deletions graphgen/graphgen.py
@@ -16,8 +16,7 @@
     Tokenizer,
 )
 from graphgen.operators import (
-    build_mm_kg,
-    build_text_kg,
+    build_kg,
     chunk_documents,
     generate_qas,
     init_llm,
@@ -96,109 +95,46 @@ async def insert(self, read_config: Dict, split_config: Dict):
         new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-        new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
-        new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}

         await self.full_docs_storage.upsert(new_docs)

-        async def _insert_text_docs(text_docs):
-            if len(text_docs) == 0:
-                logger.warning("All text docs are already in the storage")
-                return
-            logger.info("[New Docs] inserting %d text docs", len(text_docs))
-            # Step 2.1: Split chunks and filter existing ones
-            inserting_chunks = await chunk_documents(
-                text_docs,
-                split_config["chunk_size"],
-                split_config["chunk_overlap"],
-                self.tokenizer_instance,
-                self.progress_bar,
-            )
-
-            _add_chunk_keys = await self.chunks_storage.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-
-            if len(inserting_chunks) == 0:
-                logger.warning("All text chunks are already in the storage")
-                return
-
-            logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
-            await self.chunks_storage.upsert(inserting_chunks)
-
-            # Step 2.2: Extract entities and relations from text chunks
-            logger.info("[Text Entity and Relation Extraction] processing ...")
-            _add_entities_and_relations = await build_text_kg(
-                llm_client=self.synthesizer_llm_client,
-                kg_instance=self.graph_storage,
-                chunks=[
-                    Chunk(id=k, content=v["content"], type="text")
-                    for k, v in inserting_chunks.items()
-                ],
-                progress_bar=self.progress_bar,
-            )
-            if not _add_entities_and_relations:
-                logger.warning("No entities or relations extracted from text chunks")
-                return
-
-            await self._insert_done()
-            return _add_entities_and_relations
-
-        async def _insert_multi_modal_docs(mm_docs):
-            if len(mm_docs) == 0:
-                logger.warning("No multi-modal documents to insert")
-                return
-
-            logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))
-
-            # Step 3.1: Transform multi-modal documents into chunks and filter existing ones
-            inserting_chunks = await chunk_documents(
-                mm_docs,
-                split_config["chunk_size"],
-                split_config["chunk_overlap"],
-                self.tokenizer_instance,
-                self.progress_bar,
-            )
+        if len(new_docs) == 0:
+            logger.warning("All documents are already in the storage")
+            return

-            _add_chunk_keys = await self.chunks_storage.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
+        inserting_chunks = await chunk_documents(
+            new_docs,
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
+            self.tokenizer_instance,
+            self.progress_bar,
+        )

-            if len(inserting_chunks) == 0:
-                logger.warning("All multi-modal chunks are already in the storage")
-                return
+        _add_chunk_keys = await self.chunks_storage.filter_keys(
+            list(inserting_chunks.keys())
+        )
+        inserting_chunks = {
+            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+        }

-            logger.info(
-                "[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
-            )
-            await self.chunks_storage.upsert(inserting_chunks)

-            # Step 3.2: Extract multi-modal entities and relations from chunks
-            logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
-            _add_entities_and_relations = await build_mm_kg(
-                llm_client=self.synthesizer_llm_client,
-                kg_instance=self.graph_storage,
-                chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
-                progress_bar=self.progress_bar,
-            )
-            if not _add_entities_and_relations:
-                logger.warning(
-                    "No entities or relations extracted from multi-modal chunks"
-                )
-                return
-            await self._insert_done()
-            return _add_entities_and_relations

-        # Step 2: Insert text documents
-        await _insert_text_docs(new_text_docs)
-        # Step 3: Insert multi-modal documents
-        await _insert_multi_modal_docs(new_mm_docs)
+        if len(inserting_chunks) == 0:
+            logger.warning("All chunks are already in the storage")
+            return

+        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.chunks_storage.upsert(inserting_chunks)

+        _add_entities_and_relations = await build_kg(
+            llm_client=self.synthesizer_llm_client,
+            kg_instance=self.graph_storage,
+            chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
+            anchor_type=read_config.get("anchor_type", None),
+            progress_bar=self.progress_bar,
+        )
+        if not _add_entities_and_relations:
+            logger.warning("No entities or relations extracted from text chunks")

Copilot AI (Oct 24, 2025):

The warning message refers to 'text chunks' but this code path handles all chunk types (both text and multi-modal). The message should be updated to 'No entities or relations extracted from chunks' to accurately reflect the unified processing.

Suggested change:
-            logger.warning("No entities or relations extracted from text chunks")
+            logger.warning("No entities or relations extracted from chunks")

+            return

+        await self._insert_done()
+        return _add_entities_and_relations

     async def _insert_done(self):
         tasks = []
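
With the two insert paths merged, build_kg now receives every chunk together with read.anchor_type. A hedged sketch of the kind of routing this enables (the selection logic below is an assumption for illustration; the real dispatch lives inside graphgen/operators' build_kg):

from typing import Optional

from graphgen.models import LightRAGKGBuilder, MMKGBuilder, MOKGBuilder

def pick_builder(chunk_type: str, anchor_type: Optional[str]) -> type:
    # Assumed routing, for illustration only: anchor chunks (e.g. "protein")
    # go to the multi-omics builder, other non-text chunks to the multi-modal
    # builder, and plain text falls back to the LightRAG builder.
    if anchor_type is not None and chunk_type == anchor_type:
        return MOKGBuilder
    if chunk_type != "text":
        return MMKGBuilder
    return LightRAGKGBuilder
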
2 changes: 1 addition & 1 deletion graphgen/models/__init__.py
@@ -6,7 +6,7 @@
     MultiHopGenerator,
     VQAGenerator,
 )
-from .kg_builder import LightRAGKGBuilder, MMKGBuilder
+from .kg_builder import LightRAGKGBuilder, MMKGBuilder, MOKGBuilder
 from .llm import HTTPClient, OllamaClient, OpenAIClient
 from .partitioner import (
     AnchorBFSPartitioner,
1 change: 1 addition & 0 deletions graphgen/models/kg_builder/__init__.py
@@ -1,2 +1,3 @@
 from .light_rag_kg_builder import LightRAGKGBuilder
 from .mm_kg_builder import MMKGBuilder
+from .mo_kg_builder import MOKGBuilder
100 changes: 100 additions & 0 deletions graphgen/models/kg_builder/mo_kg_builder.py
@@ -0,0 +1,100 @@
import re
from collections import defaultdict
from typing import Dict, List, Tuple

from graphgen.bases import Chunk
from graphgen.templates import PROTEIN_KG_EXTRACTION_PROMPT
from graphgen.utils import (
    detect_main_language,
    handle_single_entity_extraction,
    handle_single_relationship_extraction,
    logger,
    split_string_by_multi_markers,
)

from .light_rag_kg_builder import LightRAGKGBuilder


class MOKGBuilder(LightRAGKGBuilder):
    @staticmethod
    async def scan_document_for_schema(
        chunk: Chunk, schema: Dict[str, List[str]]
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        """
        Scan the document chunk to extract entities and relationships based on the provided schema.
        :param chunk: The document chunk to be scanned.
        :param schema: A dictionary defining the entities and relationships to be extracted.
        :return: A tuple containing two dictionaries - one for entities and one for relationships.
        """
        # TODO: use hard-coded PROTEIN_KG_EXTRACTION_PROMPT for protein chunks,
        # support schema for other chunk types later
        print(chunk.id, schema)
        return {}, {}

    async def extract(
        self, chunk: Chunk
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        """
        Multi-Omics Knowledge Graph Builder.
        Step 1: Extract and output a JSON object containing protein information from the given chunk.
        Step 2: Get more details about the protein by querying external databases if necessary.
        Step 3: Construct entities and relationships for the protein knowledge graph.
        Step 4: Return the entities and relationships.
        :param chunk

Copilot AI (Oct 24, 2025):

Missing description for the chunk parameter in the docstring. Should document the expected type and purpose of this parameter.

Suggested change:
-        :param chunk
+        :param chunk: Chunk: The input data chunk containing information to extract protein entities and relationships from.

        :return: Tuple containing entities and relationships.
        """
        # TODO: Implement the multi-omics KG extraction logic here
        chunk_id = chunk.id
        chunk_type = chunk.type  # genome | protein | ...
        metadata = chunk.metadata

        # choose different extraction strategies based on chunk type
        if chunk_type == "protein":
            protein_caption = ""
            for key, value in metadata["protein_caption"].items():
                protein_caption += f"{key}: {value}\n"
            logger.debug("Protein chunk caption: %s", protein_caption)

            language = detect_main_language(protein_caption)
            prompt_template = PROTEIN_KG_EXTRACTION_PROMPT[language].format(
                **PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"],
                input_text=protein_caption,
            )
            result = await self.llm_client.generate_answer(prompt_template)
            logger.debug("Protein chunk extraction result: %s", result)

            # parse the result
            records = split_string_by_multi_markers(
                result,
                [
                    PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                    PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
                ],
            )

            nodes = defaultdict(list)
            edges = defaultdict(list)
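            # The records are assumed to follow a LightRAG-style tuple shape, e.g.
            #   ("entity"<tuple_delimiter>NAME<tuple_delimiter>TYPE<tuple_delimiter>DESCRIPTION)
            #   ("relationship"<tuple_delimiter>SRC<tuple_delimiter>TGT<tuple_delimiter>DESCRIPTION)
            # with the real delimiter strings supplied by PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"].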

            for record in records:
                match = re.search(r"\((.*)\)", record)
                if not match:
                    continue
                inner = match.group(1)

                attributes = split_string_by_multi_markers(
                    inner, [PROTEIN_KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
                )

                entity = await handle_single_entity_extraction(attributes, chunk_id)
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue

                relation = await handle_single_relationship_extraction(
                    attributes, chunk_id
                )
                if relation is not None:
                    key = (relation["src_id"], relation["tgt_id"])
                    edges[key].append(relation)

        return dict(nodes), dict(edges)
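
Finally, a hedged usage sketch for the new builder. MOKGBuilder inherits its constructor from LightRAGKGBuilder, so the llm_client keyword below is an assumption inferred from the build_kg call above, and the chunk contents are made up for illustration:

from graphgen.bases import Chunk
from graphgen.models import MOKGBuilder

async def demo(llm_client):
    # Hypothetical protein chunk: Chunk.from_dict routes every key except
    # "content" into metadata, so "protein_caption" is visible to extract().
    chunk = Chunk.from_dict(
        "chunk-p53",
        {
            "content": "",
            "type": "protein",
            "protein_caption": {"Entry": "P04637", "Protein": "Cellular tumor antigen p53"},
        },
    )
    builder = MOKGBuilder(llm_client=llm_client)  # constructor signature assumed
    nodes, edges = await builder.extract(chunk)  # run with asyncio.run(demo(client))
    print(len(nodes), "entities,", len(edges), "relations")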