@@ -702,18 +702,24 @@ class ChEBIOverXPartial(ChEBIOverX):
702702 top_class_id (int): The ID of the top class from which to extract subclasses.
703703 """
704704
705- def __init__ (self , top_class_id : int , ** kwargs ):
705+ def __init__ (self , top_class_id : int , external_data_ratio : float , ** kwargs ):
706706 """
707707 Initializes the ChEBIOverXPartial dataset.
708708
709709 Args:
710710 top_class_id (int): The ID of the top class from which to extract subclasses.
711711 **kwargs: Additional keyword arguments passed to the superclass initializer.
712+ external_data_ratio (float): How much external data (i.e., samples where top_class_id
713+ is no positive label) to include in the dataset. 0 means no external data, 1 means
714+ the maximum amount (i.e., the complete ChEBI dataset).
712715 """
713716 if "top_class_id" not in kwargs :
714717 kwargs ["top_class_id" ] = top_class_id
718+ if "external_data_ratio" not in kwargs :
719+ kwargs ["external_data_ratio" ] = external_data_ratio
715720
716721 self .top_class_id : int = top_class_id
722+ self .external_data_ratio : float = external_data_ratio
717723 super ().__init__ (** kwargs )
718724
719725 @property
@@ -727,7 +733,7 @@ def processed_dir_main(self) -> str:
727733 return os .path .join (
728734 self .base_dir ,
729735 self ._name ,
730- f"partial_{ self .top_class_id } " ,
736+ f"partial_{ self .top_class_id } _ext_ratio_ { self . external_data_ratio :.2f } " ,
731737 "processed" ,
732738 )
733739
@@ -746,9 +752,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
746752 descendants of the top class ID.
747753 """
748754 g = super ()._extract_class_hierarchy (chebi_path )
749- g = g .subgraph (list (g .successors (self .top_class_id )) + [self .top_class_id ])
755+ top_class_successors = list (g .successors (self .top_class_id )) + [
756+ self .top_class_id
757+ ]
758+ external_nodes = list (set (n for n in g .nodes if n not in top_class_successors ))
759+ if 0 < self .external_data_ratio < 1 :
760+ n_external_nodes = int (
761+ len (top_class_successors )
762+ * self .external_data_ratio
763+ / (1 - self .external_data_ratio )
764+ )
765+ print (
766+ f"Extracting { n_external_nodes } external nodes from the ChEBI dataset (ratio: { self .external_data_ratio :.2f} )"
767+ )
768+ external_nodes = external_nodes [: int (n_external_nodes )]
769+ elif self .external_data_ratio == 0 :
770+ external_nodes = []
771+
772+ g = g .subgraph (top_class_successors + external_nodes )
773+ print (
774+ f"Subgraph contains { len (g .nodes )} nodes, of which { len (top_class_successors )} are subclasses of the top class ID { self .top_class_id } ."
775+ )
750776 return g
751777
778+ def select_classes (self , g : nx .DiGraph , * args , ** kwargs ) -> List :
779+ """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
780+ smiles = nx .get_node_attributes (g , "smiles" )
781+ nodes = list (
782+ sorted (
783+ {
784+ node
785+ for node in g .nodes
786+ if sum (
787+ 1 if smiles [s ] is not None else 0 for s in g .successors (node )
788+ )
789+ >= self .THRESHOLD
790+ and (
791+ self .top_class_id in g .predecessors (node )
792+ or node == self .top_class_id
793+ )
794+ }
795+ )
796+ )
797+ filename = "classes.txt"
798+ with open (os .path .join (self .processed_dir_main , filename ), "wt" ) as fout :
799+ fout .writelines (str (node ) + "\n " for node in nodes )
800+ return nodes
801+
752802
753803class ChEBIOver50Partial (ChEBIOverXPartial , ChEBIOver50 ):
754804 """
@@ -1473,3 +1523,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
14731523]
14741524
14751525JCI_500_COLUMNS_INT = [int (n .split (":" )[- 1 ]) for n in JCI_500_COLUMNS ]
1526+
1527+ if __name__ == "__main__" :
1528+ # get arguments from command line
1529+ pass
0 commit comments