Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/demoGraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,25 @@ def delete_graph_force():


# delete_graph_force()


def retrieve_subgraph():
subgraph = litegraph.Graph.retrieve_subgraph(
graph_guid="b6eb533b-2f46-47e8-b732-2ec0ea09ae0a",
node_guid="2dfef492-601d-4c72-a17c-97b9edbe6b15",
)
print(subgraph)


retrieve_subgraph()


def retrieve_subgraph_statistics():
statistics = litegraph.Graph.retrieve_subgraph_statistics(
graph_guid="b6eb533b-2f46-47e8-b732-2ec0ea09ae0a",
node_guid="2dfef492-601d-4c72-a17c-97b9edbe6b15",
)
print(statistics)


retrieve_subgraph_statistics()
12 changes: 11 additions & 1 deletion src/litegraph/models/search_node_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..enums.enumeration_order_enum import EnumerationOrder_Enum
from ..models.edge import EdgeModel
from ..models.expression import ExprModel
from ..models.graphs import GraphModel
from ..models.node import NodeModel


Expand All @@ -17,10 +18,17 @@ class SearchRequest(BaseModel):
ordering: EnumerationOrder_Enum = Field(
EnumerationOrder_Enum.CreatedDescending, alias="Ordering"
)
max_results: int = Field(default=5, ge=1, le=1000, alias="MaxResults")
skip: int = Field(default=0, ge=0, alias="Skip")
include_data: Optional[bool] = Field(default=False, alias="IncludeData")
include_subordinates: Optional[bool] = Field(
default=False, alias="IncludeSubordinates"
)
expr: Optional[ExprModel] = Field(None, alias="Expr")
name: Optional[str] = Field(None, alias="Name")
labels: Optional[List] = Field(None, alias="Labels")
tags: Optional[Dict[str, str]] = Field(default_factory=dict, alias="Tags")
model_config = ConfigDict(populate_by_name=True)
model_config = ConfigDict(populate_by_name=True, populate_by_alias=True)


class SearchResult(BaseModel):
Expand All @@ -38,4 +46,6 @@ class SearchResultEdge(BaseModel):
"""

edges: Optional[List[EdgeModel]] = Field(None, alias="Edges")
graphs: Optional[List[GraphModel]] = Field(None, alias="Graphs")
nodes: Optional[List[NodeModel]] = Field(None, alias="Nodes")
model_config = ConfigDict(populate_by_name=True)
87 changes: 87 additions & 0 deletions src/litegraph/resources/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,90 @@ def retrieve_first(
"""
graph_id = graph_id or kwargs.get("graph_guid")
return super().retrieve_first(graph_id=graph_id, **kwargs)

@classmethod
def retrieve_subgraph_statistics(
cls,
graph_guid: str,
node_guid: str,
max_depth: int | None = None,
max_nodes: int | None = None,
max_edges: int | None = None,
) -> GraphStatisticsModel:
"""
Retrieve the statistics for a subgraph.

Args:
graph_guid: The GUID of the graph.
node_guid: The GUID of the node.
max_depth: Maximum depth for the subgraph traversal.
max_nodes: Maximum number of nodes to include.
max_edges: Maximum number of edges to include.

Returns:
GraphStatisticsModel: The statistics for the subgraph.
"""
client = get_client()
url = _get_url_v1(
cls,
client.tenant_guid,
graph_guid,
"nodes",
node_guid,
"subgraph",
"stats",
maxDepth=max_depth,
maxNodes=max_nodes,
maxEdges=max_edges,
)
response = client.request("GET", url)
return GraphStatisticsModel.model_validate(response)

@classmethod
def retrieve_subgraph(
cls,
graph_guid: str,
node_guid: str,
max_depth: int | None = None,
max_nodes: int | None = None,
max_edges: int | None = None,
include_data: bool = True,
include_sub: bool = True,
) -> GraphModel:
"""
Retrieve the subgraph.

Args:
graph_guid: The GUID of the graph.
node_guid: The GUID of the node.
max_depth: Maximum depth for the subgraph traversal.
max_nodes: Maximum number of nodes to include.
max_edges: Maximum number of edges to include.
include_data: Whether to include data in the response.
include_sub: Whether to include subgraphs in the response.

Returns:
GraphModel: The subgraph.
"""
client = get_client()
query_params = {
"maxDepth": max_depth,
"maxNodes": max_nodes,
"maxEdges": max_edges,
}
if include_data:
query_params["incldata"] = None
if include_sub:
query_params["inclsub"] = None

url = _get_url_v1(
cls,
client.tenant_guid,
graph_guid,
"nodes",
node_guid,
"subgraph",
**query_params,
)
response = client.request("GET", url)
return GraphModel.model_validate(response)
12 changes: 12 additions & 0 deletions tests/test_models/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,15 @@ def test_graph_retrieve_statistics_all(mock_client):
assert result["graph1"].nodes == 10
assert result["graph2"].nodes == 5
mock_client.request.assert_called_once()

def test_graph_retrieve_subgraph_statistics(mock_client):
"""Test retrieving statistics for a subgraph."""
mock_response = {
"Nodes": 10,
"Edges": 15,
"Labels": 3,
"Tags": 5,
"Vectors": 2
}
mock_client.request.return_value = mock_response
result = Graph.retrieve_subgraph_statistics(graph_guid="test-graph-guid", node_guid="test-node-guid")