Source code for mmodel.filter
"""Filters that are used to create subgraphs."""
import networkx as nx
[docs]
def subnodes_by_outputs(graph, outputs: list) -> list:
    """Obtain a list of subgraph nodes based on node outputs.

    For mmodel graphs, outputs from all internal nodes are unique.
    Therefore, the function only checks whether each node's ``output``
    attribute is in the target ``outputs`` list. Every matching node is
    included together with all of its ancestors (every node it depends on,
    directly or transitively). Nodes without an ``output`` attribute are
    never matched directly, but can still be included as ancestors.

    :param graph: networkx graph whose nodes may carry an ``output`` attribute
    :param list outputs: output names used to select the target nodes
    :return: list of node names; each node appears at most once
    """
    # A dict serves as an insertion-ordered set: several target nodes can
    # share ancestors (and a target may itself be an ancestor of another
    # target), so plain list appends would repeat those nodes.
    selected = {}
    for node, output in nx.get_node_attributes(graph, "output").items():
        if output in outputs:
            selected[node] = None
            selected.update(dict.fromkeys(nx.ancestors(graph, node)))
    return list(selected)