# Source code for mmodel.graph

import networkx as nx
from mmodel.visualizer import plain_visualizer
from copy import deepcopy
from mmodel.filter import subnodes_by_inputs, subnodes_by_outputs
from mmodel.utility import replace_subgraph

class Graph(nx.DiGraph):
    """Create model graphs.

    ``mmodel.Graph`` inherits from ``networkx.DiGraph``, so every ``DiGraph``
    method is available. The class adds a "type" entry to the graph attribute
    dictionary. ``graph_attr_dict_factory`` is the bound ``copy`` method of a
    ``{"type": "mmodel_graph"}`` literal. Each call returns a fresh shallow
    copy, equivalent to ``{"type": "mmodel_graph"}.copy()``, so graphs never
    share one attribute dictionary.

    Additional graph operations:

    - ``add_grouped_edge`` / ``add_grouped_edges_from``: either endpoint of an
      edge may be a list of nodes, but not both.
    - ``set_node_object`` / ``set_node_objects_from``: attach a node object to
      an existing node. The node object's ``signature`` and ``output`` are
      copied into the "signature" and "output" node attributes.
    - ``update_graph``: sets the edge "output" attribute. It runs after the
      node-object setters and the overridden ``add_edge`` / ``add_edges_from``.
    - ``subgraph``: returns a deep copy instead of a view.
    """

    graph_attr_dict_factory = {"type": "mmodel_graph"}.copy

    def set_node_object(self, node_object):
        """Add or update the node object of an existing node.

        The node object is stored under the "node_object" attribute. Its
        ``signature`` and ``output`` are copied into the "signature" and
        "output" node attributes, and the edge attributes are then refreshed
        with ``update_graph``.

        :param node_object: object exposing ``name``, ``signature`` and
            ``output`` attributes.
        :raises KeyError: if the node is not already in the graph.
        """
        # NOTE(review): assumes a node object carries its graph key as
        # ``name``. Confirm against the node class.
        node_attrs = self.nodes[node_object.name]
        node_attrs["node_object"] = node_object
        node_attrs["signature"] = node_object.signature
        node_attrs["output"] = node_object.output
        self.update_graph()

    def set_node_objects_from(self, node_objects: list):
        """Add or update the node objects of several existing nodes.

        Equivalent to calling ``set_node_object`` once per node object.
        """
        for node_object in node_objects:
            self.set_node_object(node_object)

    def add_edge(self, u_of_edge, v_of_edge, **attr):
        """Add an edge, then refresh the edge attributes with ``update_graph``."""
        super().add_edge(u_of_edge, v_of_edge, **attr)
        self.update_graph()

    def add_edges_from(self, ebunch_to_add, **attr):
        """Add edges, then refresh the edge attributes with ``update_graph``."""
        super().add_edges_from(ebunch_to_add, **attr)
        self.update_graph()

    def add_grouped_edge(self, u, v):
        """Add a grouped edge.

        In mmodel, a grouped edge ``(u, v)`` allows ``u`` or ``v``, but not
        both, to be a list of nodes:

        - a list on the ``u`` side means several nodes flow into ``v``;
        - a list on the ``v`` side means ``u`` flows into several nodes.

        :raises ValueError: if both ``u`` and ``v`` are lists.
        """
        if isinstance(u, list) and isinstance(v, list):
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError("only one edge node can be a list")

        # add_edges_from runs update_graph once per group instead of per edge.
        if isinstance(u, list):
            self.add_edges_from([(_u, v) for _u in u])
        elif isinstance(v, list):
            self.add_edges_from([(u, _v) for _v in v])
        else:
            self.add_edge(u, v)

    def add_grouped_edges_from(self, group_edges: list):
        """Add several grouped edges; see ``add_grouped_edge``."""
        for u, v in group_edges:
            self.add_grouped_edge(u, v)

    def update_graph(self):
        """Update edge attributes based on node objects and edges.

        For each edge ``(u, v)``, the edge "output" attribute is set to the
        "output" of ``u`` when that value names a parameter in the
        "signature" of ``v``. The edge is skipped when ``u`` has no "output"
        attribute or ``v`` has no "signature" attribute.
        """
        for u, v in self.edges:
            u_attrs = self.nodes[u]
            v_attrs = self.nodes[v]
            # Check for the specific keys rather than for non-empty dicts: a
            # node may hold unrelated attributes before its node object is set.
            if "output" not in u_attrs or "signature" not in v_attrs:
                continue
            output = u_attrs["output"]
            if output in v_attrs["signature"].parameters:
                self.edges[u, v]["output"] = output

    # graph operations

    def subgraph(self, nodes=None, inputs=None, outputs=None):
        """Extract a subgraph by nodes, inputs, and outputs.

        When several selectors are given, the result is the union of the
        selected nodes. The subgraph is a deep copy of the original graph.
        The parent class method instead returns a view of the original graph.

        :param nodes: iterable of nodes to include.
        :param inputs: passed to ``subnodes_by_inputs``.
        :param outputs: passed to ``subnodes_by_outputs``.
        """
        # ``nodes`` may be any iterable (the parent method accepts
        # generators), so build a set union rather than concatenating lists.
        # NOTE(review): this method shares the parent's name and first
        # parameter, which keeps ``G.subgraph(nodes)`` calls compatible.
        subgraph_nodes = set(nodes) if nodes is not None else set()
        subgraph_nodes.update(subnodes_by_inputs(self, inputs or []))
        subgraph_nodes.update(subnodes_by_outputs(self, outputs or []))
        return super().subgraph(subgraph_nodes).deepcopy()

    def replace_subgraph(self, subgraph, node_object):
        """Replace ``subgraph`` with a single node built from ``node_object``.

        Delegates to ``mmodel.utility.replace_subgraph`` and returns its result.
        """
        return replace_subgraph(self, subgraph, node_object)

    def get_node(self, node):
        """Return the attribute dictionary of ``node``."""
        return self.nodes[node]

    def get_node_object(self, node):
        """Return the node object stored on ``node``."""
        return self.nodes[node]["node_object"]

    def edit_node(self, node, **kwargs):
        """Edit a node's node object and return a new graph.

        The node object's ``edit`` method is called with ``kwargs``. The
        result is set on a deep copy of this graph, and that copy is returned.

        NOTE(review): this assumes ``edit`` returns a new object without
        mutating the original, and that the edited object keeps the same
        ``name``. Confirm against the node class.
        """
        node_object = self.nodes[node]["node_object"].edit(**kwargs)
        graph = self.deepcopy()
        graph.set_node_object(node_object)
        return graph

    def visualize(self, outfile=None):
        """Draw the graph with the default style.

        :param str outfile: filename to save the graph as; the file extension
            is required.
        """
        return plain_visualizer(self, str(self), outfile)

    def deepcopy(self):
        """Return a deep copy of the graph.

        ``networkx.Graph.copy`` is shallow: the graph, node and edge attribute
        dictionaries are shared with the original. ``copy.deepcopy`` is very
        inefficient for subgraph views, because a view keeps a reference to
        its original graph (``_graph``), which would be copied too.

        Instead, this method mirrors ``networkx.Graph.copy`` but deep-copies
        every attribute dictionary. Edges are added via the overridden
        ``add_edges_from``, so ``update_graph`` runs on the new graph.
        """
        G = self.__class__()
        G.graph.update(deepcopy(self.graph))
        G.add_nodes_from((n, deepcopy(d)) for n, d in self._node.items())
        G.add_edges_from(
            (u, v, deepcopy(datadict))
            for u, nbrs in self._adj.items()
            for v, datadict in nbrs.items()
        )
        return G