# -*- coding: utf-8 -*-
"""
This file is part of the nodeworks core library
Licence
-------
As a work of the United States Government, this project is in the public domain
within the United States. As such, this code is licensed under
CC0 1.0 Universal public domain.
Please see the LICENSE.md for more information.
"""
import importlib
import inspect
import logging
import os
import pkgutil
from typing import Any, Dict, Mapping

from nodeworks.node import Node
class NodeLibrary(object):
    ''' Collection of Nodes

    Nodes are kept in ``nodetree``, a nested dict keyed by path element with
    Node subclasses stored at the leaves under their ``name``. Nodes that a
    module reports as failed are kept in ``failed_nodetree`` with the same
    nesting, each leaf being ``{'error': ..., 'traceback': ...}``.
    '''

    def __init__(self):
        # Dict rather than Mapping: both trees are mutated in place below.
        self.nodetree: Dict[str, Any] = {}
        self.failed_nodetree: Dict[str, Any] = {}

    @staticmethod
    def _get_branch(tree, path):
        ''' Return the sub-tree of ``tree`` at ``path``, creating empty dicts
        for missing levels.

        Raises ValueError if an element of ``path`` already names a leaf
        (anything that is not a mapping) instead of a sub-tree.
        '''
        branch = tree
        for p in path:
            if p not in branch:
                branch[p] = {}
            branch = branch[p]
            if not isinstance(branch, Mapping):
                raise ValueError(
                    "Path element {!r} refers to an existing leaf, "
                    "not a sub-tree of the node library".format(p))
        return branch

    def addNode(self, node, path):
        ''' Add a Node to the library

        node: a class derived from Node; it is stored under ``node.name``.
        path: sequence of keys locating the sub-tree in ``nodetree``;
            missing levels are created as empty dicts.

        Raises ValueError if ``node`` is not a Node subclass, or if an
        element of ``path`` collides with an existing leaf.
        '''
        # Check to make sure that the node is derived from Node
        if not inspect.isclass(node) or not issubclass(node, Node):
            raise ValueError("Object {} is not a Node subclass".format(str(node)))
        self._get_branch(self.nodetree, path)[node.name] = node

    def add_failed(self, node, path):
        ''' Record a node that could not be loaded

        node: an indexable ``(name, error, traceback)`` record; it is stored
            in ``failed_nodetree`` under ``name`` as
            ``{'error': error, 'traceback': traceback}``.
        path: sequence of keys locating the sub-tree, as for addNode.
        '''
        self._get_branch(self.failed_nodetree, path)[node[0]] = {
            'error': node[1],
            'traceback': node[2],
        }

    def addNodeLibrary(self, name, library):
        ''' Add another library's node tree under ``name``

        The tree is inserted by reference (not copied), so later changes to
        ``library.nodetree`` are visible through this library.
        '''
        self.nodetree[name] = library.nodetree

    def buildDefualtLibrary(self):
        ''' Misspelled, deprecated alias: always raises DeprecationWarning.

        Use buildDefaultLibrary instead.
        '''
        raise DeprecationWarning('please spell default correctly')

    def buildDefaultLibrary(self):
        ''' Build the default node library

        Imports every top-level module of ``nodeworks.defaultnodes``. A module
        that defines ``returnNodes()`` contributes those nodes under its
        ``NAME``; a module that defines ``failed`` contributes those records
        to ``failed_nodetree``. Modules that raise ImportError are skipped
        and logged.
        '''
        # This import has to be here to avoid cyclical imports
        import nodeworks.defaultnodes as defaultnodes
        log = logging.getLogger(__name__)
        modnames = [modinfo.name
                    for modinfo in pkgutil.iter_modules(defaultnodes.__path__)]  # type: ignore
        for modname in modnames:
            try:
                nodemodule = importlib.import_module(
                    defaultnodes.__name__ + '.' + modname)
                if hasattr(nodemodule, 'returnNodes'):
                    for node in nodemodule.returnNodes():  # type: ignore
                        self.addNode(node, [nodemodule.NAME])  # type: ignore
                if hasattr(nodemodule, 'failed'):
                    for node in nodemodule.failed:  # type: ignore
                        self.add_failed(node, [nodemodule.NAME])  # type: ignore
            except ImportError as err:
                # Best effort: one broken module must not prevent the rest of
                # the library from loading, but it should not vanish silently.
                log.warning('Skipping node module %s: %s', modname, err)

    def getNode(self, path):
        ''' Return the node (or sub-tree) at ``path``

        An empty path returns the whole ``nodetree``. Raises ValueError if the
        path does not exist, including when it continues past a leaf node.
        '''
        temp = self.nodetree
        for p in path:
            # A leaf (e.g. a Node class) cannot be descended into; without the
            # Mapping check, ``p not in temp`` may raise TypeError on a class.
            if not isinstance(temp, Mapping) or p not in temp:
                raise ValueError('.'.join(path)+' does not exist in the node library')
            temp = temp[p]
        return temp

    @staticmethod
    def _reload_modules(package_prefix, search_path):
        ''' Import and reload each module found directly on ``search_path``.

        package_prefix: dotted package name (ending in '.') that the modules
            on ``search_path`` belong to.
        Modules that raise ImportError are skipped and logged at debug level.
        '''
        log = logging.getLogger(__name__)
        modnames = [modinfo.name for modinfo in pkgutil.iter_modules(search_path)]
        for modname in modnames:
            fullname = package_prefix + modname
            try:
                importlib.reload(importlib.import_module(fullname))
            except ImportError:
                log.debug('Could not reload node module %s', fullname, exc_info=True)

    def reloadNodes(self):
        ''' Reload every module under ``nodeworks.defaultnodes``

        Modules in subdirectories are reloaded first, walking the tree
        top-down, and the top-level modules last. This only reloads modules;
        it does not modify ``nodetree`` (presumably the caller rebuilds the
        library afterwards -- confirm).
        '''
        import nodeworks.defaultnodes as defaultnodes
        root = defaultnodes.__path__[0]
        prefix = defaultnodes.__name__ + '.'
        for directory, dirnames, _filenames in os.walk(root):
            # Prune in place so os.walk does not descend into bytecode caches.
            dirnames[:] = [d for d in dirnames if d != '__pycache__']
            for dirname in dirnames:
                subdir = os.path.join(directory, dirname)
                # Build the dotted name from the path relative to the package
                # root so modules nested more than one level deep are named
                # correctly. A non-recursive listing per directory suffices
                # because os.walk already visits every directory; walk_packages
                # without a prefix would import sub-packages by bare name.
                relative = os.path.relpath(subdir, root).replace(os.sep, '.')
                self._reload_modules(prefix + relative + '.', [subdir])
        self._reload_modules(prefix, defaultnodes.__path__)