# Source code for nodeworks.tools.general

# -*- coding: utf-8 -*-
"""
This file is part of the nodeworks core library

Licence
-------
As a work of the United States Government, this project is in the public domain
within the United States. As such, this code is licensed under
CC0 1.0 Universal public domain.

Please see the LICENSE.md for more information.
"""

import csv
import io
import json
import os
import re
import threading
import urllib.parse
import urllib.request
from collections import deque
from functools import reduce
from typing import Any, Dict, Tuple, Optional

try:
    import numpy as np
except ImportError:
    np = None

try:
    import pandas as pd
except ImportError:
    pd = None

try:
    import torch
except ImportError:
    torch = None

from qtpy import QtGui, QtCore, QtWidgets, uic
from nodeworks import IMAGEPATH, SCRIPT_DIRECTORY, UIPATHS

# Named colors accepted by color_func; its input string is lower-cased before
# the membership test against this list.
QTCOLORLIST = ['white', 'red', 'green', 'blue', 'black', 'darkred',
               'darkgreen', 'darkblue', 'cyan', 'magenta', 'yellow', 'gray',
               'darkcyan', 'darkmagenta', 'darkyellow', 'darkgray', 'lightgray',
               ]

# The "Tableau 20" color palette as (r, g, b) tuples with components in 0-255.
TABLEAU20 = [(31, 119, 180),
             (174, 199, 232),
             (255, 127, 14),
             (255, 187, 120),
             (44, 160, 44),
             (152, 223, 138),
             (214, 39, 40),
             (255, 152, 150),
             (148, 103, 189),
             (197, 176, 213),
             (140, 86, 75),
             (196, 156, 148),
             (227, 119, 194),
             (247, 182, 210),
             (127, 127, 127),
             (199, 199, 199),
             (188, 189, 34),
             (219, 219, 141),
             (23, 190, 207),
             (158, 218, 229),
]

class DictList(list):
    '''
    Marker type used as the dtype of the list(dict()) table model.

    It behaves exactly like ``list``; only its type differs, so code can
    tell such tables apart from plain lists.
    '''

# Build type map to convert type-name strings back to type objects. Every
# type is registered under its bare name (e.g. 'float', 'numpy.float64'), its
# Python 2 repr ("<type '...'>") and its Python 3 repr ("<class '...'>").
TYPEMAP: Dict[str, Optional[type]] = {'None': None, 'none': None, 'NoneType': None,
           "<type 'NoneType'>": None, "<class 'NoneType'>": None}
# np.float / np.int (removed in NumPy 1.24) and np.int0 (removed in NumPy 2.0)
# are looked up with getattr so that importing this module does not fail on
# newer NumPy; a missing alias becomes None, which the loop below skips.
for dtype in [int, float, str, tuple, list, set, dict, bool, None,
              np.ndarray if np else None,
              getattr(np, 'float', None) if np else None,
              np.float16 if np else None,
              np.float32 if np else None,
              np.float64 if np else None,
              getattr(np, 'int', None) if np else None,
              getattr(np, 'int0', None) if np else None,
              np.int8 if np else None,
              np.int16 if np else None,
              np.int32 if np else None,
              np.int64 if np else None,
              pd.DataFrame if pd else None,
              pd.Series if pd else None,
              DictList]:

    # str(None) is 'None', which matches neither pattern, so None entries
    # (unavailable optional types) are skipped.
    dname = re.findall("<type '(.*)'>", str(dtype))
    if not dname:
        dname = re.findall("<class '(.*)'>", str(dtype))

    if dname:
        # in python 2 these are types
        TYPEMAP["<type '{}'>".format(dname[0])] = dtype
        # in python 3 these are classes
        TYPEMAP["<class '{}'>".format(dname[0])] = dtype

        TYPEMAP[dname[0]] = dtype

# Matches names containing '.<digits>' (e.g. 'node.3'); the digits are captured.
NAME_REGEX = re.compile(r'.*\.([0-9]+)', re.DOTALL)


def fuzzy_finder(inpt, search_list):
    """
    Fuzzy-match ``inpt`` against ``search_list``.

    An item matches when it contains the characters of ``inpt`` in order
    (case-insensitive, anything allowed in between). Matches are ranked by
    the length of the matched span, then its start position, then the item.
    Inspired by: https://github.com/amjith/fuzzyfinder

    Returns
    -------
    list: the matching items, best match first.
    """
    pattern = re.compile('.*?'.join(re.escape(ch) for ch in inpt),
                         re.IGNORECASE)
    ranked = []
    for candidate in search_list:
        match = pattern.search(candidate)
        if match is None:
            continue
        ranked.append((len(match.group()), match.start(), candidate))
    ranked.sort()
    return [entry[-1] for entry in ranked]


def get_separator(vertical=True, parent=None):
    """
    Return a sunken QFrame line usable as a separator.

    Parameters
    ----------
    vertical (bool): VLine shape if True, HLine shape otherwise (default True)
    parent (QWidget): parent passed to the QFrame constructor (default None)
    """
    frame_cls = QtWidgets.QFrame
    separator = frame_cls(parent)
    shape = frame_cls.VLine if vertical else frame_cls.HLine
    separator.setFrameShape(shape)
    separator.setFrameShadow(frame_cls.Sunken)
    return separator


def get_unique_name(name, namelist, lower=True):
    '''
    Given a name and a name list, return a unique name.

    While the name is in ``namelist``, a numeric suffix is added or
    incremented: ``'node'`` -> ``'node.0'``, ``'node.0'`` -> ``'node.1'``, ...
    Only a trailing ``.<digits>`` suffix counts as a counter, so names with
    other dots (``'a.b.1'`` -> ``'a.b.2'``, ``'a.1.b'`` -> ``'a.1.b.0'``)
    are handled too.

    Parameters
    ----------
    name (str):
        the requested name
    namelist (container of str):
        names already in use (only ``in`` is applied to it)
    lower (bool):
        lower-case ``name`` before checking (default True)

    Returns
    -------
    name (str):
        a name that is not in ``namelist``
    '''
    if lower:
        name = name.lower()

    # Iterate rather than recurse so long runs of taken names cannot hit the
    # recursion limit.
    while name in namelist:
        # Greedy (.*) keeps every earlier dot in the base; fullmatch requires
        # the name to *end* with the digits.
        match = re.fullmatch(r'(.*)\.([0-9]+)', name, re.DOTALL)
        if match:
            name = '{}.{}'.format(match.group(1), int(match.group(2)) + 1)
        else:
            name = name + '.0'

    return name


class Cycle:
    """Step endlessly through an indexable sequence, wrapping at the end."""
    def __init__(self, iterable):
        # must support len() and integer indexing
        self.iterable = iterable
        # position of the element the next call to next() returns
        self.index = 0

    def reset(self):
        """Restart from the first element."""
        self.index = 0

    def next(self):
        """Return the current element, then advance (back to 0 after the last)."""
        current = self.iterable[self.index]
        following = self.index + 1
        self.index = following if following < len(self.iterable) else 0
        return current


def widget_iter(widget, recursive=True):
    """
    Yield the children of ``widget`` (anything with a ``children()`` method).

    With ``recursive=True`` the descendants of each child are yielded
    before the child itself (post-order); otherwise only direct children.
    """
    for child in widget.children():
        if recursive and child.children():
            yield from widget_iter(child)
        yield child


def path2url(path):
    """Return a ``file:`` URL for the filesystem path ``path``."""
    quoted = urllib.request.pathname2url(path)
    return urllib.parse.urljoin('file:', quoted)


def color_func(color, alpha=1.0):
    '''
    Convert a color description into a QtGui.QColor.

    Anything that is not recognized below becomes black.

    Parameters
    ----------
    color:
        one of
        RGB[A] - a list or tuple (r, g, b [, a]) passed to QColor(*color)
        HEX - a string starting with '#', e.g. '#F0F8FF'
        string - a name from QTCOLORLIST (compared case-insensitively):
                 'white', 'red', 'green', 'blue', 'black', 'darkred',
                 'darkgreen', 'darkblue', 'cyan', 'magenta', 'yellow', 'gray',
                 'darkcyan', 'darkmagenta', 'darkyellow', 'darkgray',
                 'lightgray'
    alpha (float or int):
        a float (exactly type float) is applied with setAlphaF (0.0-1.0);
        an int (exactly type int) with setAlpha (0-255); any other type
        leaves the alpha untouched (default 1.0). Note that the default also
        overrides an alpha given inside an RGBA tuple.

    Returns
    -------
    color (QColor)
    '''
    is_text = isinstance(color, str)
    if is_text and (color.lower() in QTCOLORLIST or color.startswith('#')):
        qcolor = QtGui.QColor(color)
    elif type(color) in (list, tuple):
        qcolor = QtGui.QColor(*color)
    else:
        qcolor = QtGui.QColor('black')

    alpha_type = type(alpha)
    if alpha_type is float:
        qcolor.setAlphaF(alpha)
    elif alpha_type is int:
        qcolor.setAlpha(alpha)

    return qcolor


def connection_type_func(dtype, style, term=False, default='default'):
    '''
    Return the color and line style for a connection based on the dtype.

    Parameters
    ----------
    dtype (type):
        the type to look up. The name inside ``<class '...'>`` (or Python 2's
        ``<type '...'>``) is parsed from ``str(dtype)``, so ``float`` gives
        ``'float'``. ``str(dtype)`` is used, not ``str(type(dtype))``, so
        pass a type object rather than an instance.
    style (dict):
        maps type-name strings (e.g. 'float', 'int') to indexable pairs
        ``(color, linestyle)``; color is a ``QtGui.QColor``.
    term (bool):
        if True return ``color.lighter()``, used for terminals (default False)
    default (str):
        key in style used when the dtype's name is not in style
        (default 'default')

    Returns
    -------
    [color (QColor), linestyle (PenStyle)]:
        a two-element list; ``[QtCore.Qt.black, QtCore.Qt.SolidLine]`` when
        neither the dtype's name nor ``default`` is in style.
    '''
    text = str(dtype)
    found = (re.findall("<type '(.*)'>", text)
             or re.findall("<class '(.*)'>", text))

    if found and found[0] in style:
        key = found[0]
    elif default in style:
        key = default
    else:
        # If all else fails
        return [QtCore.Qt.black, QtCore.Qt.SolidLine]

    entry = style[key]
    color, linestyle = entry[0], entry[1]
    if term:
        color = color.lighter()
    return [color, linestyle]


def get_image_path(name):
    """
    Return the path of image ``name`` inside IMAGEPATH.

    Names that end in neither '.png' nor '.svg' get '.svg' appended.
    """
    filename = name if name.endswith(('.png', '.svg')) else name + '.svg'
    return os.path.join(IMAGEPATH, filename)


# Scaled pixmaps already produced by get_pixmap, keyed by (name, width, height).
pixmap_cache: Dict[Tuple[str, int, int], Any] = {}


def get_pixmap(name, width, height):
    """
    Return the image ``name`` as a QPixmap scaled to fit width x height.

    Scaling keeps the aspect ratio and uses smooth transformation. Results
    are cached per (name, width, height).

    Raises
    ------
    Warning: if the image file resolved by get_image_path does not exist.
    """
    key = (name, width, height)
    cached = pixmap_cache.get(key)
    if cached is not None:
        return cached

    img_path = get_image_path(name)
    if not os.path.exists(img_path):
        raise Warning('Icon: {} not found'.format(img_path))
    scaled = QtGui.QPixmap(img_path).scaled(
        width, height, QtCore.Qt.KeepAspectRatio,
        QtCore.Qt.SmoothTransformation)
    pixmap_cache[key] = scaled
    return scaled


def get_icon(path):
    '''
    Convert an image name/path (resolved with get_image_path) to a QIcon.

    Raises Warning if the resolved file does not exist.
    '''
    full_path = get_image_path(path)
    if not os.path.exists(full_path):
        raise Warning('Icon: {} not found'.format(full_path))
    return QtGui.QIcon(full_path)


def get_ui(fname, parent=None):
    """
    Load a Qt Designer ``.ui`` file with ``uic.loadUi``.

    Each directory in UIPATHS is searched in order and the first existing
    ``os.path.join(path, fname)`` is loaded.

    Parameters
    ----------
    fname (str): file name of the .ui file
    parent (QWidget): passed to ``uic.loadUi`` (default None)

    Raises
    ------
    ValueError: if ``fname`` exists in none of the UIPATHS directories.
    """
    for path in UIPATHS:
        ui_file = os.path.join(path, fname)
        if os.path.exists(ui_file):
            break
    else:
        # The loop ended without ``break``: no candidate file exists. (Checking
        # ``ui_file is None`` cannot detect this, because it is reassigned on
        # every iteration.)
        raise ValueError('Can not find uifile: {}'.format(fname))
    return uic.loadUi(ui_file, parent)


def isiterable(obj):
    """
    Return True if ``iter(obj)`` succeeds, False if it raises TypeError.

    Only TypeError (what ``iter`` raises for non-iterables) is caught, so
    KeyboardInterrupt, SystemExit and genuine bugs are no longer swallowed.
    """
    try:
        iter(obj)
    except TypeError:
        return False
    return True


def triangle_path(rect, direction='right'):
    '''
    Given a rectangle and a direction, generate a QPainterPath of a triangle.

    The apex sits at the middle of the edge opposite the pointing side's
    base; the base is the full edge on the other side, e.g. 'right' has its
    base on the right edge and its apex at the middle of the left edge.

    Parameters
    ----------
    rect (QRect):
        the rectangle to draw the triangle in
    direction (str):
        'left', 'right', 'up', or 'down' (case-insensitive)

    Returns
    -------
    path (QPainterPath):
        a closed path describing the triangle

    Raises
    ------
    ValueError: for any other direction
    '''
    key = direction.lower()
    mid = rect.center()
    if key == 'left':
        apex = (rect.right(), mid.y())
        base = (rect.topLeft(), rect.bottomLeft())
    elif key == 'right':
        apex = (rect.left(), mid.y())
        base = (rect.topRight(), rect.bottomRight())
    elif key == 'up':
        apex = (mid.x(), rect.top())
        base = (rect.bottomLeft(), rect.bottomRight())
    elif key == 'down':
        apex = (mid.x(), rect.bottom())
        base = (rect.topLeft(), rect.topRight())
    else:
        raise ValueError('{} is not a valid direction, needs to be in "left",'
                         '"right", "up", or "down"'.format(key))

    path = QtGui.QPainterPath()
    path.moveTo(*apex)
    for corner in base:
        path.lineTo(corner)
    path.lineTo(*apex)  # close the triangle back at the apex
    return path


def get_from_dict(dataDict, mapList):
    """
    Return ``dataDict[k0][k1]...[kn]`` for the keys in ``mapList``.

    An empty ``mapList`` returns ``dataDict`` itself; a missing key raises
    whatever the indexed container raises (e.g. KeyError).
    """
    node = dataDict
    for key in mapList:
        node = node[key]
    return node


def set_in_dict(dataDict, mapList, value):
    """
    Set ``dataDict[k0]...[kn] = value`` for the keys in ``mapList``.

    All keys but the last must already exist; an empty ``mapList`` raises
    IndexError.
    """
    container = get_from_dict(dataDict, mapList[:-1])
    container[mapList[-1]] = value


def recurse_dict(d, keys=()):
    '''
    Recursively walk a dictionary of dictionaries, yielding its leaves.

    Only objects whose type is exactly ``dict`` are descended into; dict
    subclasses and every other value are leaves.

    Parameters
    ----------
    d (dict):
        a dictionary to loop over
    keys (tuple):
        key path already walked to reach ``d`` (default ``()``)

    Yields
    ------
    (keys (tuple), value):
        the key path to each leaf and the leaf value.
    '''
    if type(d) is not dict:
        yield keys, d
        return
    for key, value in d.items():
        yield from recurse_dict(value, keys + (key,))


def string_to_slice(string, array=False):
    """
    Parse a subscript string such as '[1:3]' or '1:3, ::2' into index objects.

    Square brackets are removed and the text is split on ',' (one entry per
    dimension). Each entry is split on ':' and every piece is converted with
    is_int; an entry without ':' stays a single value, otherwise it becomes
    ``slice(*pieces)``.

    Returns
    -------
    the first entry when ``array`` is False (lists, tuples, ...), otherwise
    the list of all entries (arrays).
    """
    # TODO: this can cause trouble with [1,3,4,5,6], 2:
    cleaned = string.strip().replace('[', '').replace(']', '')

    entries = []
    for chunk in cleaned.split(','):
        bounds = [is_int(piece) for piece in chunk.split(':')]
        entries.append(bounds[0] if len(bounds) == 1 else slice(*bounds))

    return entries if array else entries[0]


def flatten(lst):
    """
    Recursively flatten nested iterables, yielding the leaf items in order.

    Strings are leaves (not iterated into), as is anything for which
    ``iter()`` raises TypeError. Other iterables are descended into: lists,
    tuples, sets, generators, dicts (their keys), bytes (their ints), ...
    """
    for item in lst:
        if isinstance(item, str):
            yield item
            continue
        try:
            iter(item)
        except TypeError:
            # not iterable: a leaf value
            yield item
        else:
            # The recursive yield stays outside ``try`` so that GeneratorExit
            # (thrown at a suspended yield when the generator is closed) and
            # errors raised by nested items are not swallowed.
            yield from flatten(item)


def is_int(i):
    """
    Convert ``i`` with ``int()`` when possible.

    Returns the int on success. If ``int()`` raises ValueError, returns ``i``
    unchanged when it is non-empty and None when it is empty. Other errors
    (e.g. TypeError for None) propagate.
    """
    try:
        return int(i)
    except ValueError:
        return i if len(i) > 0 else None


def index_list_with_float(f, l):
    '''
    Index a list with a float.

    ``f`` is mapped linearly onto the positions: index ``int(f * len(l))``,
    clamped to the valid range so that ``f >= 1.0`` selects the last element
    and ``f <= 0.0`` selects the first.

    Parameters
    ----------
    f : float
        a float between 0.0 and 1.0 to index the list with
    l : list
        list (or any sequence supporting len() and indexing) to be indexed

    Returns
    -------
    l[i] : unknown
        the indexed value

    Raises
    ------
    IndexError
        if ``l`` is empty
    '''
    n = len(l)  # list length
    if f >= 1.0:
        i = n - 1
    elif f <= 0.0:
        # without this, negative f would wrap around via negative indexing
        i = 0
    else:
        # min() guards against f * n rounding up to n
        i = min(int(f * n), n - 1)
    return l[i]


def doe_to_iter(doe):
    '''
    Turn the DOE test matrix into an iterable.

    Parameters
    ----------
    doe (list(list)):
        the test matrix produced by buildDoe: row 0 holds the keys, row 1
        the args, and every following row one run's values.

    Return
    ------
    run (list):
        one entry per value row, each a list of ``[key, arg, value]``
        triples (zipped, so the shortest of the three rows decides length).
    '''
    keys, args = doe[0], doe[1]
    return [[[key, arg, value] for key, arg, value in zip(keys, args, values)]
            for values in doe[2:]]


def check_on_main_thread(usingQT=True):
    """
    Return True when called from the main thread.

    Parameters
    ----------
    usingQT (bool):
        if True (default), compare the current QThread with the thread of the
        QApplication instance. If no QApplication exists yet (``instance()``
        returns None), fall back to the Python check instead of failing with
        AttributeError. If False, compare ``threading.current_thread()`` with
        ``threading.main_thread()``.
    """
    if usingQT:
        app = QtWidgets.QApplication.instance()
        if app is not None:
            return app.thread() == QtCore.QThread.currentThread()
        # NOTE(review): no Qt application yet; the Python main thread is the
        # closest proxy for the GUI thread.
    return threading.current_thread() is threading.main_thread()


def build_menu(menuList, parent=None):
    """
    Build a QMenu from a nested list description.

    Parameters
    ----------
    menuList (list):
        entries of the form ``[text, callback_or_submenu, checkable, checked]``.
        If the second item is a list, it is built recursively as a sub-menu
        titled ``text``. ``['separator', None]`` adds a separator. The
        trailing ``checkable`` and ``checked`` items are optional (default
        False); a callback of None leaves the action unconnected.
    parent (QWidget):
        parent of the created menu (default None)

    Returns
    -------
    menu (QMenu), menuActions (dict):
        the menu, and a dict mapping each entry's text to its QAction (or to
        the nested actions dict for a sub-menu). Separators are not recorded.
    """
    menu = QtWidgets.QMenu(parent)
    menuActions = {}
    for ap in menuList:
        if isinstance(ap[1], list):
            subMenu, subMenuActions = build_menu(ap[1], menu)
            subMenu.setTitle(ap[0])
            menuActions[ap[0]] = subMenuActions
            menu.addMenu(subMenu)
            continue

        action = QtWidgets.QAction(menu)
        if ap[0] == 'separator' and ap[1] is None:
            action.setSeparator(True)
        else:
            action.setText(ap[0])
            callBack = ap[1]
            if callBack is not None:
                action.triggered.connect(callBack)
            # Optional trailing items, checked explicitly instead of a bare
            # ``except`` that would also hide unrelated errors.
            checkable = ap[2] if len(ap) > 2 else False
            checked = ap[3] if len(ap) > 3 else False
            if checkable:
                action.setCheckable(checkable)
                action.setChecked(checked)
            menuActions[ap[0]] = action
        menu.addAction(action)
    return menu, menuActions


'''
The tail function was obtained from https://gist.github.com/volker48/3437288
and contains the following copyright notice and license:

Copyright (c) 2012 Marcus McCurdy
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of the Software,
and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
Created 8/22/12
@author: Marcus McCurdy <marcus.mccurdy@gmail.com>
'''


def tail(path, lines_to_print=5):
    """
    Return the last lines of a text file.

    Parameters
    ----------
    path (str):
        file to read, opened in text mode with the default encoding
    lines_to_print (int):
        number of trailing lines to return (default 5)

    Returns
    -------
    lines (list of str or None):
        up to ``lines_to_print`` trailing lines, each keeping its line ending
        as ``readlines`` returns them (all lines if the file is shorter);
        None if ``lines_to_print < 1``.
    """
    if lines_to_print < 1:
        return None
    with open(path, 'r') as fh:
        # Stream the file through a bounded deque so only the newest lines
        # are kept. Seeking backwards character by character is avoided:
        # text-mode files only support seeking to offsets returned by tell(),
        # and the previous backwards scan skipped every other character.
        return list(deque(fh, maxlen=lines_to_print))


def export_to_csv(table, fname, delim=',', header=None):
    """
    Export data to a delimited text file (UTF-8, overwritten).

    Parameters
    ----------
    table (list): rows (iterables); each value is converted with ``str``,
        joined with ``delim`` and written as one line.
    fname (str): output path.
    delim (str): separator placed between values (default ',').
    header (iterable or None): if given, written first as the comment line
        ``'# ' + delim.join(header)``; items are converted with ``str`` just
        like row values.

    Values are neither quoted nor escaped, so a value containing ``delim``
    or a newline will not read back as a single field.
    """
    with open(fname, 'w', encoding="utf-8") as f:
        if header is not None:
            f.write('# ' + delim.join(str(h) for h in header) + '\n')
        for row in table:
            f.write(delim.join(str(v) for v in row) + '\n')


def safe_float(v, default=0.0):
    """
    Try converting ``v`` to a float; return ``default`` if that fails.

    Failure covers ValueError (e.g. ``'abc'``), TypeError (e.g. ``None``)
    and OverflowError (ints too large for a float).
    """
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        return default


def safe_int(v, default=0):
    """
    Try converting ``v`` to an int; return ``default`` if that fails.

    Failure covers ValueError (e.g. ``'3.5'``, ``'abc'``, NaN), TypeError
    (e.g. ``None``) and OverflowError (infinite floats).
    """
    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        return default


def import_csv(fname, delim=',', header=0, comment='#', reader='csv'):
    """
    Read a delimited text file into a table.

    parameters
    ----------
    fname (str): path to the file
    delim (str): delimiter, default ','
    header (int or None): row number (counting only non-comment rows) that
        contains the header; None means every row is data
    comment (str): lines that start with comment are ignored
    reader (str): backend reader ['csv', 'pandas'] (case-insensitive).
        Unknown values select 'pandas'; 'pandas' falls back to 'csv' when
        pandas is not installed.

    returns
    -------
    table: a ``pandas.DataFrame`` for the pandas backend; for the csv
        backend a list of rows where the header row is kept as strings, data
        rows are converted with safe_float, and rows before the header are
        dropped.
    """
    backend = reader.lower()
    if backend not in ('csv', 'pandas'):
        backend = 'pandas'

    # use pandas (the ``pd`` check comes second so the csv path never
    # touches it)
    if backend == 'pandas' and pd is not None:
        return pd.read_csv(fname, sep=delim, header=header, comment=comment)

    # csv backend, also used when pandas was requested but is unavailable
    table = []
    with open(fname, newline='') as csvfile:
        if comment:
            # Drop whole comment lines before parsing. The comment string is
            # not a quote character, so it is not passed as ``quotechar``.
            lines = (line for line in csvfile if not line.startswith(comment))
        else:
            lines = csvfile
        for i, row in enumerate(csv.reader(lines, delimiter=delim)):
            if header is None or i > header:
                table.append([safe_float(v) for v in row])
            elif i == header:
                table.append(row)

    return table


class CustomEncoder(json.JSONEncoder):
    """
    JSON encoder that also handles numpy, pandas and torch objects.

    numpy integers/floats/bools become Python scalars and ndarrays nested
    lists. pandas Series/DataFrame become
    ``{'__nw_type__': 'pd_series' | 'pd_frame', 'data': obj.to_json()}`` and
    torch tensors ``{'__nw_type__': 'torch_tensor', 'data': nested lists}``.
    Each library is only consulted when it is installed; anything else is
    passed to ``json.JSONEncoder.default`` (which raises TypeError).
    """
    def default(self, obj):  # pylint: disable=method-hidden
        if np is not None:
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.bool_):
                return bool(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
        if pd is not None:
            if isinstance(obj, pd.Series):
                return {'__nw_type__': 'pd_series', 'data': obj.to_json()}
            if isinstance(obj, pd.DataFrame):
                return {'__nw_type__': 'pd_frame', 'data': obj.to_json()}
        if torch is not None and isinstance(obj, torch.Tensor):
            # detach() lets tensors that require grad be exported as well
            # (``.numpy()`` refuses them); tolist() yields nested lists.
            return {'__nw_type__': 'torch_tensor',
                    'data': obj.detach().cpu().tolist()}
        return super().default(obj)


class CustomDecoder(json.JSONDecoder):
    """
    JSON decoder that rebuilds the objects tagged by CustomEncoder.

    Dicts whose ``'__nw_type__'`` is ``'pd_series'``/``'pd_frame'`` become
    pandas objects and ``'torch_tensor'`` a torch tensor, when the library
    is installed. All other dicts (including tagged ones whose library is
    missing) are returned unchanged.
    """
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, dct):  # pylint: disable=method-hidden
        """Convert one decoded JSON object as described in the class docstring."""
        nw_type = dct.get('__nw_type__')
        # Compare the tag first so that unrelated dicts never touch the
        # optional libraries.
        if nw_type in ('pd_series', 'pd_frame') and pd is not None:
            typ = 'series' if nw_type == 'pd_series' else 'frame'
            # Passing literal JSON text to read_json is deprecated since
            # pandas 2.1; wrap it in a file-like object instead.
            return pd.read_json(io.StringIO(dct.get('data', '')), typ=typ)
        if nw_type == 'torch_tensor' and torch is not None:
            # The encoder stores nested lists; a missing 'data' key yields an
            # empty tensor (the default must be a list, not the string '[]').
            return torch.from_numpy(np.asarray(dct.get('data', [])))
        return dct


class WorkerThread(QtCore.QThread):
    """
    QThread that calls ``function`` (with no arguments) from ``run()``.

    Parameters
    ----------
    function (callable):
        the work to execute in the thread
    parent (QObject):
        passed to the QThread constructor (default None)
    """
    def __init__(self, function, parent=None):
        QtCore.QThread.__init__(self, parent)
        self.function = function
        # NOTE(review): this attribute shadows QObject.parent(); kept for
        # backward compatibility with code that reads ``.parent``.
        self.parent = parent

    def run(self):
        """Call the stored function, then call ``self.exit()``."""
        self.function()
        self.exit()