diff --git a/examples/demoGraphs.py b/examples/demoGraphs.py index 388215f..dc495a7 100644 --- a/examples/demoGraphs.py +++ b/examples/demoGraphs.py @@ -157,3 +157,25 @@ def delete_graph_force(): # delete_graph_force() + + +def retrieve_subgraph(): + subgraph = litegraph.Graph.retrieve_subgraph( + graph_guid="b6eb533b-2f46-47e8-b732-2ec0ea09ae0a", + node_guid="2dfef492-601d-4c72-a17c-97b9edbe6b15", + ) + print(subgraph) + + +retrieve_subgraph() + + +def retrieve_subgraph_statistics(): + statistics = litegraph.Graph.retrieve_subgraph_statistics( + graph_guid="b6eb533b-2f46-47e8-b732-2ec0ea09ae0a", + node_guid="2dfef492-601d-4c72-a17c-97b9edbe6b15", + ) + print(statistics) + + +retrieve_subgraph_statistics() diff --git a/src/litegraph/models/search_node_edge.py b/src/litegraph/models/search_node_edge.py index 5fc9c17..7dd00b9 100755 --- a/src/litegraph/models/search_node_edge.py +++ b/src/litegraph/models/search_node_edge.py @@ -5,6 +5,7 @@ from ..enums.enumeration_order_enum import EnumerationOrder_Enum from ..models.edge import EdgeModel from ..models.expression import ExprModel +from ..models.graphs import GraphModel from ..models.node import NodeModel @@ -17,10 +18,17 @@ class SearchRequest(BaseModel): ordering: EnumerationOrder_Enum = Field( EnumerationOrder_Enum.CreatedDescending, alias="Ordering" ) + max_results: int = Field(default=5, ge=1, le=1000, alias="MaxResults") + skip: int = Field(default=0, ge=0, alias="Skip") + include_data: Optional[bool] = Field(default=False, alias="IncludeData") + include_subordinates: Optional[bool] = Field( + default=False, alias="IncludeSubordinates" + ) expr: Optional[ExprModel] = Field(None, alias="Expr") + name: Optional[str] = Field(None, alias="Name") labels: Optional[List] = Field(None, alias="Labels") tags: Optional[Dict[str, str]] = Field(default_factory=dict, alias="Tags") - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, populate_by_alias=True) class SearchResult(BaseModel): @@ -38,4 +46,6 @@ class SearchResultEdge(BaseModel): """ edges: Optional[List[EdgeModel]] = Field(None, alias="Edges") + graphs: Optional[List[GraphModel]] = Field(None, alias="Graphs") + nodes: Optional[List[NodeModel]] = Field(None, alias="Nodes") model_config = ConfigDict(populate_by_name=True) diff --git a/src/litegraph/resources/graphs.py b/src/litegraph/resources/graphs.py index 2b9e468..ea1c4de 100755 --- a/src/litegraph/resources/graphs.py +++ b/src/litegraph/resources/graphs.py @@ -137,3 +137,90 @@ def retrieve_first( """ graph_id = graph_id or kwargs.get("graph_guid") return super().retrieve_first(graph_id=graph_id, **kwargs) + + @classmethod + def retrieve_subgraph_statistics( + cls, + graph_guid: str, + node_guid: str, + max_depth: int | None = None, + max_nodes: int | None = None, + max_edges: int | None = None, + ) -> GraphStatisticsModel: + """ + Retrieve the statistics for a subgraph. + + Args: + graph_guid: The GUID of the graph. + node_guid: The GUID of the node. + max_depth: Maximum depth for the subgraph traversal. + max_nodes: Maximum number of nodes to include. + max_edges: Maximum number of edges to include. + + Returns: + GraphStatisticsModel: The statistics for the subgraph. + """ + client = get_client() + url = _get_url_v1( + cls, + client.tenant_guid, + graph_guid, + "nodes", + node_guid, + "subgraph", + "stats", + maxDepth=max_depth, + maxNodes=max_nodes, + maxEdges=max_edges, + ) + response = client.request("GET", url) + return GraphStatisticsModel.model_validate(response) + + @classmethod + def retrieve_subgraph( + cls, + graph_guid: str, + node_guid: str, + max_depth: int | None = None, + max_nodes: int | None = None, + max_edges: int | None = None, + include_data: bool = True, + include_sub: bool = True, + ) -> GraphModel: + """ + Retrieve the subgraph. + + Args: + graph_guid: The GUID of the graph. + node_guid: The GUID of the node. + max_depth: Maximum depth for the subgraph traversal. + max_nodes: Maximum number of nodes to include. + max_edges: Maximum number of edges to include. + include_data: Whether to include data in the response. + include_sub: Whether to include subgraphs in the response. + + Returns: + GraphModel: The subgraph. + """ + client = get_client() + query_params = { + "maxDepth": max_depth, + "maxNodes": max_nodes, + "maxEdges": max_edges, + } + if include_data: + query_params["incldata"] = None + if include_sub: + query_params["inclsub"] = None + + url = _get_url_v1( + cls, + client.tenant_guid, + graph_guid, + "nodes", + node_guid, + "subgraph", + **query_params, + ) + response = client.request("GET", url) + return GraphModel.model_validate(response) diff --git a/tests/test_models/test_graphs.py b/tests/test_models/test_graphs.py index a7fc9d3..c02fe1f 100755 --- a/tests/test_models/test_graphs.py +++ b/tests/test_models/test_graphs.py @@ -304,3 +304,15 @@ def test_graph_retrieve_statistics_all(mock_client): assert result["graph1"].nodes == 10 assert result["graph2"].nodes == 5 mock_client.request.assert_called_once() + +def test_graph_retrieve_subgraph_statistics(mock_client): + """Test retrieving statistics for a subgraph.""" + mock_response = { + "Nodes": 10, + "Edges": 15, + "Labels": 3, + "Tags": 5, + "Vectors": 2 + } + mock_client.request.return_value = mock_response + result = Graph.retrieve_subgraph_statistics(graph_guid="test-graph-guid", node_guid="test-node-guid") \ No newline at end of file