Source code for mmodel.handler

from collections import UserDict
from mmodel.utility import graph_topological_sort, param_counter, modelgraph_signature
from datetime import datetime
import h5py
import string
import random
from textwrap import dedent
import sys


[docs]class TopologicalHandler: """Base class for executing graph nodes in topological order. "Returns" specifies the output order. If there is only one return the value is outputted, otherwise a tuple is outputted. This behavior is similar to the Python function. The topological handler assumes each node has exactly one output. :param str name: name of the handler (same as the model instance) :param networkx.digraph graph: graph :param list returns: handler returns order The list should have the same or more elements than the graph returns. See Model constructor definition. """ DataClass: type = callable def __init__(self, graph, returns: list, **datacls_kwargs): # self.__name__ = name # __signature__ allows the inspect module to properly generate the signature self.__signature__ = modelgraph_signature(graph) self.returns = returns self.order = graph_topological_sort(graph) self.graph = graph self.datacls_kwargs = datacls_kwargs def __call__(self, **kwargs): """Execute graph model by layer. The data object is not stored as an attribute to avoid repeated use and reduce memory usage. """ data = self.DataClass(kwargs, **self.datacls_kwargs) for node, node_attr in self.order: self.run_node(data, node, node_attr) result = self.finish(data, self.returns) return result
[docs] def node_exception(self, data, node_data, node, node_attr): """Exception handler for individual nodes. Overwrite this function for different exception formatting. """ exception_format = dedent( """\ An exception occurred when executing node '{node}': --- exception info --- {exc_str} --- node info --- {node_str} --- input info --- {input_str} """ ) node_object = node_attr["node_object"] # format the error message input_str = "\n".join( [f"{key} = {repr(value)}" for key, value in node_data.items()] ) exc_type, exc_value, _ = sys.exc_info() exc_str = f"{exc_type.__name__}: {exc_value}" msg = exception_format.format( node=node, exc_str=exc_str, node_str=str(node_object), input_str=input_str ) raise Exception(msg)
[docs] def run_node(self, data, node, node_attr): """Run the individual node.""" kwargs = {key: data[key] for key in node_attr["signature"].parameters} node_object = node_attr["node_object"] try: # execute func_result = node_object.node_func(**kwargs) output = node_attr["output"] if output: # skip the None data[output] = func_result except: # exception occurred while running the node if hasattr(data, "close"): data.close() self.node_exception(data, kwargs, node, node_attr)
[docs] def finish(self, data, returns): """Finish execution.""" if len(returns) == 0: result = None elif len(returns) == 1: result = data[returns[0]] else: result = tuple(data[rt] for rt in returns) # if the data class needs to be closed if hasattr(data, "close"): data.close() return result
class MemData(UserDict): """Modified dictionary that checks the counter every time a value is accessed.""" def __init__(self, data, counter): """Counter is a copy of the counter dictionary.""" self.counter = counter.copy() super().__init__(data) def __getitem__(self, key): """When a key is accessed, reduce the counter. If the counter has reached zero, pop the value (key is deleted) else wise return the key. """ self.counter[key] -= 1 if self.counter[key] == 0: # return the value and delete the key in the dictionary value = super().__getitem__(key) del self[key] return value else: return super().__getitem__(key) class H5Data: """Data class to interact with underlying h5 file. The "timestamp-uuid" is used to ensure unique entries to the H5 group. The randomly generated short uuid has 36^5, which is roughly 2e9 possibilities (picoseconds range). """ def __init__(self, data, fname, gname): self.fname = fname self.f = h5py.File(self.fname, "a") alphabet = string.ascii_lowercase + string.digits shortuuid = "".join(random.choices(alphabet, k=6)) self.gname = f"{gname} {datetime.now().strftime('%y%m%d-%H%M%S')}-{shortuuid}" self.group = self.f.create_group(self.gname) self.update(data) def update(self, data): """Write key values in bulk.""" for key, value in data.items(): self[key] = value def __getitem__(self, key): """Read dataset/attribute by the group. :param str key: value name :param h5py.group group: open h5py group object """ return self.group[key][()] def __setitem__(self, key, value): """Write h5 dataset/attribute by the group. If the object type cannot be recognized by HDF5, the string representation of the object is written as an attribute :param dict value_dict: the dictionary of values to write :param h5py.group group: open h5py group object """ try: self.group.create_dataset(key, data=value) except TypeError: # TypeError: Object dtype dtype('O') has no native HDF5 equivalent self.group.attrs[key] = str(value) def close(self): """Close the data object.""" self.f.close()
[docs]class BasicHandler(TopologicalHandler): """Basic handler, use the dictionary as a data class.""" DataClass = dict
[docs]class MemHandler(TopologicalHandler): """Memory optimized handler, delete intermediate values when necessary. The process works by keeping a record of the parameter counter. See MemData class for more details. """ DataClass = MemData def __init__(self, graph, returns: list): """Add counter to the object.""" counter = param_counter(graph, returns) super().__init__(graph, returns, counter=counter)
[docs]class H5Handler(TopologicalHandler): """H5 Handler, saves all calculation values to an h5 file. :param str fname: h5 file name :param str gname: group name for the data entry """ DataClass = H5Data def __init__(self, graph, returns, fname: str, gname: str = ""): super().__init__(graph, returns, fname=fname, gname=gname)