From 876448870d8c7804a3effd139d3ee55e0f9803ef Mon Sep 17 00:00:00 2001
From: ywh555hhh <1916647616@qq.com>
Date: Tue, 3 Mar 2026 13:07:05 +0800
Subject: [PATCH 1/2] add graph_net_sample_groups_util.py and graph_net_sample_util.py

- get_all_graph_net_sample_groups: get groups from DB with filters
- GraphNetSampleTypeGetter: get sample_type for a given sample_uid
- GraphNetSampleOpSeqGetter: get op_seq for a given sample_uid
- Both getters support bulk_get for batch operations
- Tested with real downstream: util.filter_valid_groups
---
 sqlite/graph_net_sample_groups_util.py |  59 ++++++++++
 sqlite/graph_net_sample_util.py        | 155 +++++++++++++++++++++++++
 2 files changed, 214 insertions(+)
 create mode 100644 sqlite/graph_net_sample_groups_util.py
 create mode 100644 sqlite/graph_net_sample_util.py

diff --git a/sqlite/graph_net_sample_groups_util.py b/sqlite/graph_net_sample_groups_util.py
new file mode 100644
index 000000000..1bcd5bf23
--- /dev/null
+++ b/sqlite/graph_net_sample_groups_util.py
@@ -0,0 +1,59 @@
+# graph_net_sample_groups_util.py
+from typing import List, Set, Dict
+from collections import defaultdict
+
+from orm_models import get_session, GraphNetSampleGroup
+
+
+def get_all_graph_net_sample_groups(
+    db_path: str,
+    group_types: List[str],
+    group_policies: List[str],
+    versions: List[str],
+) -> List[Set[str]]:
+    """
+    Get all graph_net sample groups from database.
+
+    Viba:
+        get_all_graph_net_sample_groups :=
+            list[set[$sample_uid str]]
+            <- $group_net_db_file_path str
+            <- $group_type list[str]
+            <- $group_policy list[str]
+            <- $version list[str]
+
+    Only rows whose ``deleted`` flag is false are considered, and a row must
+    match all three filters (type, policy and version) to be included.
+
+    Args:
+        db_path: Path to the SQLite database file.
+        group_types: List of group types to filter (e.g., ["shape_diversity", "dtype_diversity"]).
+        group_policies: List of group policies to filter (e.g., ["by_bucket"]).
+        versions: List of policy versions to filter (e.g., ["v0.1"]).
+
+    Returns:
+        List of sets, each set contains sample UIDs belonging to one group
+        (rows are grouped by ``group_uid``).
+    """
+    session = get_session(db_path)
+    # Close the session even if the query raises, so a failing lookup does
+    # not leave the session open.
+    try:
+        rows = (
+            session.query(GraphNetSampleGroup)
+            .filter(
+                GraphNetSampleGroup.deleted.is_(False),
+                GraphNetSampleGroup.group_type.in_(group_types),
+                GraphNetSampleGroup.group_policy.in_(group_policies),
+                GraphNetSampleGroup.policy_version.in_(versions),
+            )
+            .all()
+        )
+
+        groups_dict: Dict[str, Set[str]] = defaultdict(set)
+        for row in rows:
+            groups_dict[row.group_uid].add(row.sample_uid)
+    finally:
+        session.close()
+
+    return list(groups_dict.values())
diff --git a/sqlite/graph_net_sample_util.py b/sqlite/graph_net_sample_util.py
new file mode 100644
index 000000000..4edcb9ee6
--- /dev/null
+++ b/sqlite/graph_net_sample_util.py
@@ -0,0 +1,155 @@
+# graph_net_sample_util.py
+import json
+from typing import Dict, List
+
+from orm_models import get_session, GraphSample, SampleOpNameList
+
+
+class GraphNetSampleTypeGetter:
+    """
+    Get sample_type for a given sample_uid.
+
+    Viba:
+        GraphNetSampleTypeGetter :=
+            # __call__
+            $sample_type str
+            <- $sample_uid str
+            # __init__
+            <- $group_net_db_file_path str
+            <- $fetch_cache dict[$sample_uid str, $sample_type str]
+
+    Results are memoized per instance in ``_cache``; unknown sample UIDs
+    resolve to an empty string.
+    """
+
+    def __init__(self, db_path: str):
+        self.db_path = db_path
+        self._cache: Dict[str, str] = {}
+
+    def __call__(self, sample_uid: str) -> str:
+        """Get sample_type for the given sample_uid ("" if not found)."""
+        if sample_uid in self._cache:
+            return self._cache[sample_uid]
+
+        session = get_session(self.db_path)
+        # Always release the session, even when the query raises.
+        try:
+            sample = (
+                session.query(GraphSample)
+                .filter(GraphSample.uuid == sample_uid)
+                .first()
+            )
+            sample_type = sample.sample_type if sample else ""
+        finally:
+            session.close()
+
+        self._cache[sample_uid] = sample_type
+        return sample_type
+
+    def bulk_get(self, sample_uids: List[str]) -> Dict[str, str]:
+        """
+        Bulk get sample_types for multiple sample UIDs.
+
+        Every requested UID is present in the result; UIDs without a matching
+        row map to "". Rows that were found are also stored in the cache.
+        """
+        if not sample_uids:
+            # Nothing to look up: skip the session and the empty IN () query.
+            return {}
+
+        session = get_session(self.db_path)
+        try:
+            samples = (
+                session.query(GraphSample)
+                .filter(GraphSample.uuid.in_(sample_uids))
+                .all()
+            )
+            result: Dict[str, str] = {s.uuid: s.sample_type for s in samples}
+        finally:
+            session.close()
+
+        self._cache.update(result)
+        for uid in sample_uids:
+            result.setdefault(uid, "")
+        return result
+
+
+class GraphNetSampleOpSeqGetter:
+    """
+    Get op_seq for a given sample_uid.
+
+    Viba:
+        GraphNetSampleOpSeqGetter :=
+            # __call__
+            $sample_op_seq list[str]
+            <- $sample_uid str
+            # __init__
+            <- $group_net_db_file_path str
+            <- $fetch_cache dict[$sample_uid str, $sample_op_seq list[str]]
+
+    Results are memoized per instance in ``_cache``; unknown sample UIDs
+    resolve to an empty list.
+    """
+
+    def __init__(self, db_path: str):
+        self.db_path = db_path
+        self._cache: Dict[str, List[str]] = {}
+
+    @staticmethod
+    def _parse_op_seq(op_names_json) -> List[str]:
+        """Decode an op_names_json payload into the ordered list of op names."""
+        if not op_names_json:
+            return []
+        # Expects a JSON array of objects, each with an "op_name" key.
+        return [op["op_name"] for op in json.loads(op_names_json)]
+
+    def __call__(self, sample_uid: str) -> List[str]:
+        """Get op_seq for the given sample_uid ([] if not found)."""
+        if sample_uid in self._cache:
+            return self._cache[sample_uid]
+
+        session = get_session(self.db_path)
+        # Always release the session, even when the query raises.
+        try:
+            op_list = (
+                session.query(SampleOpNameList)
+                .filter(SampleOpNameList.sample_uuid == sample_uid)
+                .first()
+            )
+            op_names_json = op_list.op_names_json if op_list else None
+        finally:
+            session.close()
+
+        op_seq = self._parse_op_seq(op_names_json)
+        self._cache[sample_uid] = op_seq
+        return op_seq
+
+    def bulk_get(self, sample_uids: List[str]) -> Dict[str, List[str]]:
+        """
+        Bulk get op_seqs for multiple sample UIDs.
+
+        Every requested UID is present in the result; UIDs without a matching
+        row map to []. Rows that were found are also stored in the cache.
+        """
+        if not sample_uids:
+            # Nothing to look up: skip the session and the empty IN () query.
+            return {}
+
+        session = get_session(self.db_path)
+        try:
+            op_lists = (
+                session.query(SampleOpNameList)
+                .filter(SampleOpNameList.sample_uuid.in_(sample_uids))
+                .all()
+            )
+            result: Dict[str, List[str]] = {
+                op_list.sample_uuid: self._parse_op_seq(op_list.op_names_json)
+                for op_list in op_lists
+            }
+        finally:
+            session.close()
+
+        self._cache.update(result)
+        for uid in sample_uids:
+            result.setdefault(uid, [])
+        return result

From f44b7aea5cb570476f8716e859e28b98dbe9b464 Mon Sep 17 00:00:00 2001
From: ywh555hhh <1916647616@qq.com>
Date: Tue, 3 Mar 2026 13:52:17 +0800
Subject: [PATCH 2/2] move util files to sqlite/util/ and update imports

- Move graph_net_sample_groups_util.py to sqlite/util/
- Move graph_net_sample_util.py to sqlite/util/
- Update imports to use sqlite.orm_models
---
 sqlite/{ => util}/graph_net_sample_groups_util.py | 2 +-
 sqlite/{ => util}/graph_net_sample_util.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)
 rename sqlite/{ => util}/graph_net_sample_groups_util.py (96%)
 rename sqlite/{ => util}/graph_net_sample_util.py (98%)

diff --git a/sqlite/graph_net_sample_groups_util.py b/sqlite/util/graph_net_sample_groups_util.py
similarity index 96%
rename from sqlite/graph_net_sample_groups_util.py
rename to sqlite/util/graph_net_sample_groups_util.py
index 1bcd5bf23..39297fe48 100644
--- a/sqlite/graph_net_sample_groups_util.py
+++ b/sqlite/util/graph_net_sample_groups_util.py
@@ -2,7 +2,7 @@
 from typing import List, Set, Dict
 from collections import defaultdict
 
-from orm_models import get_session, GraphNetSampleGroup
+from sqlite.orm_models import get_session, GraphNetSampleGroup
 
 
 def get_all_graph_net_sample_groups(
diff --git a/sqlite/graph_net_sample_util.py b/sqlite/util/graph_net_sample_util.py
similarity index 98%
rename from sqlite/graph_net_sample_util.py
rename to sqlite/util/graph_net_sample_util.py
index 4edcb9ee6..e2d2c30ff 100644
--- a/sqlite/graph_net_sample_util.py
+++ b/sqlite/util/graph_net_sample_util.py
@@ -2,7 +2,7 @@
 import json
 from typing import Dict, List
 
-from orm_models import get_session, GraphSample, SampleOpNameList
+from sqlite.orm_models import get_session, GraphSample, SampleOpNameList
 
 
 class GraphNetSampleTypeGetter: