Skip to content

Commit 0a00d11

Browse files
committed
use ratio parameter to add external data
1 parent 48725dd commit 0a00d11

File tree

1 file changed

+57
-3
lines changed

1 file changed

+57
-3
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,18 +702,24 @@ class ChEBIOverXPartial(ChEBIOverX):
702702
top_class_id (int): The ID of the top class from which to extract subclasses.
703703
"""
704704

705-
def __init__(self, top_class_id: int, **kwargs):
705+
def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs):
706706
"""
707707
Initializes the ChEBIOverXPartial dataset.
708708
709709
Args:
710710
top_class_id (int): The ID of the top class from which to extract subclasses.
711711
**kwargs: Additional keyword arguments passed to the superclass initializer.
712+
external_data_ratio (float): How much external data (i.e., samples where top_class_id
713+
is no positive label) to include in the dataset. 0 means no external data, 1 means
714+
the maximum amount (i.e., the complete ChEBI dataset).
712715
"""
713716
if "top_class_id" not in kwargs:
714717
kwargs["top_class_id"] = top_class_id
718+
if "external_data_ratio" not in kwargs:
719+
kwargs["external_data_ratio"] = external_data_ratio
715720

716721
self.top_class_id: int = top_class_id
722+
self.external_data_ratio: float = external_data_ratio
717723
super().__init__(**kwargs)
718724

719725
@property
@@ -727,7 +733,7 @@ def processed_dir_main(self) -> str:
727733
return os.path.join(
728734
self.base_dir,
729735
self._name,
730-
f"partial_{self.top_class_id}",
736+
f"partial_{self.top_class_id}_ext_ratio_{self.external_data_ratio:.2f}",
731737
"processed",
732738
)
733739

@@ -746,9 +752,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
746752
descendants of the top class ID.
747753
"""
748754
g = super()._extract_class_hierarchy(chebi_path)
749-
g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id])
755+
top_class_successors = list(g.successors(self.top_class_id)) + [
756+
self.top_class_id
757+
]
758+
external_nodes = list(set(n for n in g.nodes if n not in top_class_successors))
759+
if 0 < self.external_data_ratio < 1:
760+
n_external_nodes = int(
761+
len(top_class_successors)
762+
* self.external_data_ratio
763+
/ (1 - self.external_data_ratio)
764+
)
765+
print(
766+
f"Extracting {n_external_nodes} external nodes from the ChEBI dataset (ratio: {self.external_data_ratio:.2f})"
767+
)
768+
external_nodes = external_nodes[: int(n_external_nodes)]
769+
elif self.external_data_ratio == 0:
770+
external_nodes = []
771+
772+
g = g.subgraph(top_class_successors + external_nodes)
773+
print(
774+
f"Subgraph contains {len(g.nodes)} nodes, of which {len(top_class_successors)} are subclasses of the top class ID {self.top_class_id}."
775+
)
750776
return g
751777

778+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
779+
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
780+
smiles = nx.get_node_attributes(g, "smiles")
781+
nodes = list(
782+
sorted(
783+
{
784+
node
785+
for node in g.nodes
786+
if sum(
787+
1 if smiles[s] is not None else 0 for s in g.successors(node)
788+
)
789+
>= self.THRESHOLD
790+
and (
791+
self.top_class_id in g.predecessors(node)
792+
or node == self.top_class_id
793+
)
794+
}
795+
)
796+
)
797+
filename = "classes.txt"
798+
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
799+
fout.writelines(str(node) + "\n" for node in nodes)
800+
return nodes
801+
752802

753803
class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50):
754804
"""
@@ -1473,3 +1523,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
14731523
]
14741524

14751525
JCI_500_COLUMNS_INT = [int(n.split(":")[-1]) for n in JCI_500_COLUMNS]
1526+
1527+
if __name__ == "__main__":
1528+
# get arguments from command line
1529+
pass

0 commit comments

Comments
 (0)