import networkx as nx
from mmodel.visualizer import plain_visualizer
from copy import deepcopy
from mmodel.filter import subnodes_by_inputs, subnodes_by_outputs
from mmodel.utility import replace_subgraph
class Graph(nx.DiGraph):
    """Create model graphs.

    ``mmodel.Graph`` inherits from ``networkx.DiGraph`` and therefore has
    all ``DiGraph`` methods.

    The class adds the "type" attribute to the graph attributes.
    ``graph_attr_dict_factory`` is the bound ``copy`` method of
    ``{"type": "mmodel_graph"}``, so every call returns a new dictionary.
    This is equivalent to ``{"type": "mmodel_graph"}.copy()``.

    Additional graph operations:

    - ``add_grouped_edge`` / ``add_grouped_edges_from`` add edges where
      either endpoint (but not both) may be a list of nodes.
    - ``set_node_object`` / ``set_node_objects_from`` store a node object
      on an existing node, together with its "signature" and "output"
      node attributes.
    - Edge "output" attributes are recomputed by ``update_graph`` whenever
      edges or node objects are added.
    """

    graph_attr_dict_factory = {"type": "mmodel_graph"}.copy

    def set_node_object(self, node_object):
        """Add or update the node object of an existing node.

        The node is looked up by ``node_object.name`` and must already be in
        the graph. The "node_object", "signature", and "output" node
        attributes are set, and then the edge attributes are refreshed.
        """
        self._store_node_object(node_object)
        self.update_graph()

    def set_node_objects_from(self, node_objects: list):
        """Add or update the node objects of several existing nodes.

        Produces the same result as calling ``set_node_object`` for each
        item. However, the edge attributes are refreshed only once at the
        end, not after every node.
        """
        try:
            for node_object in node_objects:
                self._store_node_object(node_object)
        finally:
            # keep edge attributes consistent with whatever was stored,
            # even if a node lookup failed part-way through
            self.update_graph()

    def _store_node_object(self, node_object):
        """Write node-object attributes onto an existing node (no edge update)."""
        node_attrs = self.nodes[node_object.name]
        node_attrs["node_object"] = node_object
        node_attrs["signature"] = node_object.signature
        node_attrs["output"] = node_object.output

    def add_edge(self, u_of_edge, v_of_edge, **attr):
        """Add an edge, then refresh the edge attributes via ``update_graph``."""
        super().add_edge(u_of_edge, v_of_edge, **attr)
        self.update_graph()

    def add_edges_from(self, ebunch_to_add, **attr):
        """Add edges, then refresh the edge attributes once via ``update_graph``."""
        super().add_edges_from(ebunch_to_add, **attr)
        self.update_graph()

    def add_grouped_edge(self, u, v):
        """Add a grouped edge.

        For mmodel, a grouped edge ``(u, v)`` allows ``u`` or ``v`` to be a
        list of nodes. It represents one or several nodes flowing into one
        node (or one node flowing into several nodes).

        :raises ValueError: if both ``u`` and ``v`` are lists.
        """
        if isinstance(u, list) and isinstance(v, list):
            # ValueError subclasses Exception, so existing handlers still match
            raise ValueError("only one edge node can be a list")

        # add_edges_from runs update_graph once for the whole group
        if isinstance(u, list):
            self.add_edges_from([(_u, v) for _u in u])
        elif isinstance(v, list):
            self.add_edges_from([(u, _v) for _v in v])
        else:  # neither is a list
            self.add_edge(u, v)

    def add_grouped_edges_from(self, group_edges: list):
        """Add edges from an iterable of ``(u, v)`` grouped-edge pairs."""
        for u, v in group_edges:
            self.add_grouped_edge(u, v)

    def update_graph(self):
        """Update edge attributes based on node objects and edges.

        Consider an edge ``(u, v)`` where the parent has an "output" node
        attribute and the child has a "signature" node attribute. If the
        parent's output names a parameter of the child's signature, the
        edge "output" attribute is set to that output. Otherwise it is
        removed, so values left over from a replaced node object do not
        persist. Edges touching nodes without those attributes are left
        unchanged.
        """
        for u, v, edge_attrs in self.edges(data=True):
            parent_attrs = self.nodes[u]
            child_attrs = self.nodes[v]
            # skip edges whose endpoints have not been given node objects
            if "output" not in parent_attrs or "signature" not in child_attrs:
                continue

            output = parent_attrs["output"]
            if output in child_attrs["signature"].parameters:
                edge_attrs["output"] = output
            else:
                edge_attrs.pop("output", None)

    # graph operations

    def subgraph(self, nodes=None, inputs=None, outputs=None):
        """Extract a subgraph by nodes, inputs, and outputs.

        If multiple parameters are specified, the result is the union of the
        selections. The subgraph is a deep copy of the original graph. This
        differs from the parent class method, which returns a view of the
        original graph.

        Note that this method shares its name with the parent class method
        but has a different signature and return behavior.
        """
        # a set union accepts any iterable (list, set, generator) from the
        # caller and from the filter helpers, and removes duplicates
        selected = set(nodes or [])
        selected.update(subnodes_by_inputs(self, inputs or []))
        selected.update(subnodes_by_outputs(self, outputs or []))
        return super().subgraph(selected).deepcopy()

    def replace_subgraph(self, subgraph, node_object):
        """Replace a subgraph with a single node.

        Delegates to the module-level ``mmodel.utility.replace_subgraph`` and
        returns its result.
        """
        # inside a method body this name resolves to the imported module-level
        # function, not to this method
        return replace_subgraph(self, subgraph, node_object)

    def get_node(self, node):
        """Return the attribute dictionary of ``node``."""
        return self.nodes[node]

    def get_node_object(self, node):
        """Return the node object stored on ``node``."""
        return self.nodes[node]["node_object"]

    def edit_node(self, node, **kwargs):
        """Edit the node object of ``node`` and return a new graph.

        The node object's ``edit`` method is called with ``kwargs``. The
        result is stored on a deep copy of this graph; the original graph is
        not modified.
        """
        node_object = self.nodes[node]["node_object"].edit(**kwargs)
        graph = self.deepcopy()
        graph.set_node_object(node_object)
        return graph

    def visualize(self, outfile=None):
        """Draw the graph with the default style.

        :param str outfile: filename to save the graph as. The file extension
            is needed.
        """
        # str(self) is passed as the second argument -- presumably used as
        # the diagram label/title by plain_visualizer (not verified here)
        return plain_visualizer(self, str(self), outfile)

    def deepcopy(self):
        """Return a deep copy of the graph.

        ``Graph.copy`` makes a shallow copy: the attribute dictionaries are
        new, but container values are shared with the original. See
        `graph.copy <https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.copy.html>`_.

        ``copy.deepcopy(graph)`` is not used because it is very inefficient
        for subgraph views. A view contains ``_graph``, which stores the
        original graph, so that graph would be copied too. Instead, the logic
        of ``Graph.copy`` is reproduced here with ``deepcopy`` applied to
        every attribute dictionary.

        Edges are added with ``add_edges_from``, so the edge "output"
        attributes of the copy are recomputed by ``update_graph``.
        """
        # ``deepcopy`` below is ``copy.deepcopy`` (module scope), not this method
        G = self.__class__()
        G.graph.update(deepcopy(self.graph))
        G.add_nodes_from((n, deepcopy(d)) for n, d in self._node.items())
        G.add_edges_from(
            (u, v, deepcopy(datadict))
            for u, nbrs in self._adj.items()
            for v, datadict in nbrs.items()
        )
        return G