Source code for CHAP.utils.general

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""A collection of generic functions for use in any CHAP
Processor, Reader or Writer."""

# FIX write function that returns a list of peak indices for a given plot
# FIX use raise_error concept on more functions

# System modules
from ast import literal_eval
from copy import deepcopy
import collections.abc
from logging import getLogger
import os
import re
import sys
from typing import Union

# Third party modules
import numpy as np
try:
    import matplotlib.pyplot as plt
except ImportError:
    pass

logger = getLogger(__name__)

# pylint: disable=no-member
tiny = np.finfo(np.float64).resolution
# pylint: enable=no-member

Int = Union[int, np.integer]
Float = Union[float, np.floating]
Num = Union[int, np.integer, float, np.floating]

def gformat(value, length=11):
    """Format a number with '%g'-like format, while:

      - the length of the output string will be of the requested length
      - positive numbers will have a leading blank
      - the precision will be as high as possible
      - trailing zeros will not be trimmed

    :param value: Input value. `None`, booleans and values that do
        not support `abs()` are returned as their right-aligned
        `repr`.
    :type value: float
    :param length: Length of the number field, defaults to `11`.
        Values smaller than 7 are treated as 7.
    :type length: int, optional
    :return: Formatted number.
    :rtype: str
    """
    # Code taken from lmfit library
    if value is None or isinstance(value, bool):
        return f'{repr(value):>{length}s}'
    try:
        expon = int(np.log10(abs(value)))
    except (OverflowError, ValueError):
        # log10 of zero, inf or nan cannot be converted to an int
        expon = 0
    except TypeError:
        return f'{repr(value):>{length}s}'

    length = max(length, 7)
    form = 'e'
    prec = length - 7
    if abs(expon) > 99:
        # A three digit exponent needs one more character
        prec -= 1
    elif 0 < expon < prec+4 or (expon <= 0 and -expon < prec-1):
        # Fixed-point notation fits in the field: use it for moderately
        # large values and for values with a small negative exponent
        # (the second test must check the exponent sign, as in lmfit)
        form = 'f'
        prec += 4
        if expon > 0:
            prec -= expon
    return f'{value:{length}.{prec}{form}}'
def getfloat_attr(obj, attr, length=11):
    """Return a printable representation of an object's attribute.

    :param obj: Object that the attribute belongs to.
    :type obj: object
    :param attr: Attribute name.
    :type attr: str
    :param length: Length of the number field used for floating point
        values, defaults to `11`.
    :type length: int, optional
    :return: `'unknown'` if the attribute is missing or `None`, the
        plain string for integers, the stripped `gformat` output for
        floats, or `repr` of the value otherwise.
    :rtype: str
    """
    # Code taken from lmfit library
    attr_value = getattr(obj, attr, None)
    if attr_value is None:
        return 'unknown'
    if isinstance(attr_value, Int):
        return f'{attr_value}'
    if not isinstance(attr_value, Float):
        return repr(attr_value)
    return gformat(attr_value, length=length).strip()
def depth_list(l):
    """Return the nesting depth of a list.

    A non-list has depth `0`, a flat (or empty) list has depth `1`,
    and every level of nested lists adds one.

    :param l: Input list.
    :type l: list
    :return: Depth of the list.
    :rtype: int
    """
    if not isinstance(l, list):
        return 0
    # default=0 keeps an empty list from raising in max()
    return 1 + max(map(depth_list, l), default=0)
def depth_tuple(t):
    """Return the nesting depth of a tuple.

    A non-tuple has depth `0`, a flat (or empty) tuple has depth `1`,
    and every level of nested tuples adds one.

    :param t: Input tuple.
    :type t: tuple
    :return: Depth of the tuple.
    :rtype: int
    """
    if not isinstance(t, tuple):
        return 0
    # default=0 keeps an empty tuple from raising in max()
    return 1 + max(map(depth_tuple, t), default=0)
def unwrap_tuple(t):
    """Strip redundant single-element outer tuples.

    As long as the input is a nested tuple (depth larger than one)
    holding exactly one element, it is replaced by that element.

    :param t: Input tuple.
    :type t: tuple
    :return: Unwrapped tuple.
    :rtype: tuple
    """
    while depth_tuple(t) > 1 and len(t) == 1:
        (t,) = t
    return t
def all_any(l, key):
    """Check for a common key in a list of dictionaries, looping at
    maximum only once over the entire list.

    :param l: Input list.
    :type l: list[dict]
    :param key: Common dictionary key.
    :type key: Any
    :return: `1` if every dictionary has the key, `0` if none has it,
        or `-1` for a mix of both. Return `None` for a zero length
        input list.
    :rtype: `None` or int
    """
    state = None
    for entry in l:
        present = 1 if key in entry else 0
        if state is None:
            # First entry decides what all the others must match
            state = present
        elif state != present:
            # Mixed result, no need to look any further
            return -1
    return state
def illegal_value(value, name, location=None, raise_error=False, log=True):
    """Report an illegal value by logging it and/or raising an error.

    :param value: Input value.
    :param name: Value name, left out of the message when it is not
        a string.
    :type name: str
    :param location: Input location, left out of the message when it
        is not a string.
    :type location: str, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :raise: ValueError when `raise_error` is set to `True`.
    """
    location = f'in {location} ' if isinstance(location, str) else ''
    if not isinstance(name, str):
        error_msg = f'Illegal value {location}({value}, {type(value)})'
    else:
        error_msg = \
            f'Illegal value for {name} {location}({value}, {type(value)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
def illegal_combination(
        value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report an illegal combination of two values by logging it
    and/or raising an error.

    :param value1: First input value.
    :param name1: Name of the first value; both names are left out of
        the message when this is not a string.
    :type name1: str
    :param value2: Second input value.
    :param name2: Name of the second value.
    :type name2: str
    :param location: Input location, left out of the message when it
        is not a string.
    :type location: str, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :raise: ValueError when `raise_error` is set to `True`.
    """
    location = f'in {location} ' if isinstance(location, str) else ''
    if not isinstance(name1, str):
        error_msg = f'Illegal combination {location}' \
            f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    else:
        error_msg = f'Illegal combination for {name1} and {name2} {location}' \
            f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
def not_zero(value):
    """Return the input value with its absolute size raised to at
    least `tiny` (numpy.finfo(numpy.float64).resolution), preserving
    the sign.

    :param value: Input value.
    :type value: float
    :return: Input value, or `tiny` carrying the sign of the input
        when the input is smaller in size than `tiny`.
    :rtype: float
    """
    magnitude = max(tiny, abs(value))
    signed = np.copysign(magnitude, value)
    return float(signed)
def test_ge_gt_le_lt(
        ge, gt, le, lt, func, location=None, raise_error=False, log=True):
    """Check individual and mutual validity of the ge, gt, le, lt
    qualifiers.

    :param ge: Greater or equal to.
    :type ge: int or float
    :param gt: Greater than.
    :type gt: int or float
    :param le: Smaller or equal to.
    :type le: int or float
    :param lt: Smaller than.
    :type lt: int or float
    :param func: Test for integers or numbers.
    :type func: callable: is_int, is_num
    :param location: Input location.
    :type location: str, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` upon success or `False` when a qualifier is
        invalid or the qualifiers are mutually exclusive.
    :rtype: bool
    """
    def reject_value(bound, label):
        """Report a single invalid qualifier."""
        illegal_value(bound, label, location, raise_error, log)
        return False

    def reject_pair(bound1, label1, bound2, label2):
        """Report two qualifiers that cannot be combined."""
        illegal_combination(
            bound1, label1, bound2, label2, location, raise_error, log)
        return False

    if all(bound is None for bound in (ge, gt, le, lt)):
        return True

    # Lower bound: at most one of ge and gt, and it must pass func
    if ge is None:
        if gt is not None and not func(gt):
            return reject_value(gt, 'gt')
    else:
        if not func(ge):
            return reject_value(ge, 'ge')
        if gt is not None:
            return reject_pair(ge, 'ge', gt, 'gt')

    # Upper bound: at most one of le and lt, and it must pass func
    if le is None:
        if lt is not None and not func(lt):
            return reject_value(lt, 'lt')
    else:
        if not func(le):
            return reject_value(le, 'le')
        if lt is not None:
            return reject_pair(le, 'le', lt, 'lt')

    # The lower bound must leave room below the upper bound
    if ge is not None:
        if le is not None and ge > le:
            return reject_pair(ge, 'ge', le, 'le')
        if lt is not None and ge >= lt:
            return reject_pair(ge, 'ge', lt, 'lt')
    elif gt is not None:
        if le is not None and gt >= le:
            return reject_pair(gt, 'gt', le, 'le')
        if lt is not None and gt >= lt:
            return reject_pair(gt, 'gt', lt, 'lt')
    return True
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a range string representation matching the ge, gt, le,
    lt qualifiers. Does not validate the inputs, do that as needed
    before calling.

    A one-sided range is written as a comparison (e.g. `'>= 1'`), a
    two-sided range in interval notation (e.g. `'[1, 5)'`).

    :param ge: Greater or equal to.
    :type ge: int or float, optional
    :param gt: Greater than.
    :type gt: int or float, optional
    :param le: Smaller or equal to.
    :type le: int or float, optional
    :param lt: Smaller than.
    :type lt: int or float, optional
    :return: Range string representation.
    :rtype: str
    """
    has_lower = ge is not None or gt is not None
    has_upper = le is not None or lt is not None

    if ge is not None:
        lower = f'[{ge}, ' if has_upper else f'>= {ge}'
    elif gt is not None:
        lower = f'({gt}, ' if has_upper else f'> {gt}'
    else:
        lower = ''

    if le is not None:
        upper = f'{le}]' if has_lower else f'<= {le}'
    elif lt is not None:
        upper = f'{lt})' if has_lower else f'< {lt}'
    else:
        upper = ''

    return lower + upper
def is_int(
        value, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that a value is an integer in range ge <= value <= le
    or gt < value < lt or some combination.

    :param value: Input value.
    :type value: int
    :param ge: Greater or equal to.
    :type ge: int, optional
    :param gt: Greater than.
    :type gt: int, optional
    :param le: Smaller or equal to.
    :type le: int, optional
    :param lt: Smaller than.
    :type lt: int, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is in valid range or `False` if
        not.
    :rtype: bool
    """
    return _is_int_or_num(
        value, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
def is_num(
        value, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that a value is a number in range ge <= value <= le or
    gt < value < lt or some combination.

    :param value: Input value.
    :type value: int or float
    :param ge: Greater or equal to.
    :type ge: int or float, optional
    :param gt: Greater than.
    :type gt: int or float, optional
    :param le: Smaller or equal to.
    :type le: int or float, optional
    :param lt: Smaller than.
    :type lt: int or float, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is in valid range or `False` if
        not.
    :rtype: bool
    """
    return _is_int_or_num(
        value, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
def _is_int_or_num(
        value, type_str, ge=None, gt=None, le=None, lt=None,
        raise_error=False, log=True):
    """Check that a value is an integer (`type_str='int'`) or a
    number (`type_str='num'`) that satisfies the optional ge, gt, le,
    lt qualifiers.

    :param value: Input value.
    :param type_str: Type to check for, `'int'` or `'num'`.
    :type type_str: str
    :param ge: Greater or equal to.
    :param gt: Greater than.
    :param le: Smaller or equal to.
    :param lt: Smaller than.
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :raise: ValueError on any failure when `raise_error` is set to
        `True`.
    :return: `True` if input value is of the right type and in the
        valid range or `False` if not.
    :rtype: bool
    """
    if type_str == 'int':
        expected_type, bound_check = Int, is_int
    elif type_str == 'num':
        expected_type, bound_check = Num, is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log)
        return False

    if not isinstance(value, expected_type):
        illegal_value(value, 'value', '_is_int_or_num', raise_error, log)
        return False
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, bound_check, '_is_int_or_num', raise_error,
            log):
        return False

    # Only the first violated qualifier is reported
    if ge is not None and value < ge:
        error_msg = f'Value {value} out of range: {value} !>= {ge}'
    elif gt is not None and value <= gt:
        error_msg = f'Value {value} out of range: {value} !> {gt}'
    elif le is not None and value > le:
        error_msg = f'Value {value} out of range: {value} !<= {le}'
    elif lt is not None and value >= lt:
        error_msg = f'Value {value} out of range: {value} !< {lt}'
    else:
        return True

    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
def is_int_pair(
        values, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that a value is an integer pair, each in range
    ge <= values[i] <= le or gt < values[i] < lt or
    ge[i] <= values[i] <= le[i] or gt[i] < values[i] < lt[i] or
    some combination.

    :param values: Input values.
    :type values: list[int, int]
    :param ge: Greater or equal to, a single bound or one per
        element.
    :type ge: int, optional
    :param gt: Greater than, a single bound or one per element.
    :type gt: int, optional
    :param le: Smaller or equal to, a single bound or one per
        element.
    :type le: int, optional
    :param lt: Smaller than, a single bound or one per element.
    :type lt: int, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is a valid pair in the valid
        range or `False` if not.
    :rtype: bool
    """
    return _is_int_or_num_pair(
        values, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
def is_num_pair(
        values, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that a value is a number pair, each in range
    ge <= values[i] <= le or gt < values[i] < lt or
    ge[i] <= values[i] <= le[i] or gt[i] < values[i] < lt[i] or
    some combination.

    :param values: Input values.
    :type values: list[int, int] or list[float, float]
    :param ge: Greater or equal to, a single bound or one per
        element.
    :type ge: int or float, optional
    :param gt: Greater than, a single bound or one per element.
    :type gt: int or float, optional
    :param le: Smaller or equal to, a single bound or one per
        element.
    :type le: int or float, optional
    :param lt: Smaller than, a single bound or one per element.
    :type lt: int or float, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is a valid pair in the valid
        range or `False` if not.
    :rtype: bool
    """
    return _is_int_or_num_pair(
        values, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
def _is_int_or_num_pair(
        values, type_str, ge=None, gt=None, le=None, lt=None,
        raise_error=False, log=True):
    """Check that a value is a pair of integers (`type_str='int'`)
    or numbers (`type_str='num'`) that satisfy the optional ge, gt,
    le, lt qualifiers.

    Each qualifier is either a single bound applied to both elements
    or a pair of bounds, one for each element.

    :param values: Input values.
    :type values: list or tuple of length two
    :param type_str: Type to check for, `'int'` or `'num'`.
    :type type_str: str
    :param ge: Greater or equal to.
    :param gt: Greater than.
    :param le: Smaller or equal to.
    :param lt: Smaller than.
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :raise: ValueError on any failure when `raise_error` is set to
        `True`.
    :return: `True` if input value is a valid pair in the valid
        range or `False` if not.
    :rtype: bool
    """
    if type_str == 'int':
        elem_type, func = Int, is_int
    elif type_str == 'num':
        elem_type, func = Num, is_num
    else:
        illegal_value(
            type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return False

    if not (isinstance(values, (tuple, list)) and len(values) == 2
            and isinstance(values[0], elem_type)
            and isinstance(values[1], elem_type)):
        illegal_value(
            values, 'values', '_is_int_or_num_pair', raise_error, log)
        return False
    if ge is None and gt is None and le is None and lt is None:
        return True

    # Expand every qualifier to one bound per element. The scalar
    # probe must stay silent: a pair-valued bound is a legal input and
    # should not produce an error message in the log.
    bounds = []
    for bound in (ge, gt, le, lt):
        if bound is None or func(bound, log=False):
            bounds.append(2*[bound])
        elif _is_int_or_num_pair(
                bound, type_str, raise_error=raise_error, log=log):
            bounds.append(bound)
        else:
            return False
    ge, gt, le, lt = bounds

    return all(
        func(values[i], ge[i], gt[i], le[i], lt[i], raise_error, log)
        for i in range(2))
def is_int_series(
        t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that a value is a tuple or list of integers, each in
    range ge <= l[i] <= le or gt < l[i] < lt or some combination.

    :param t_or_l: Input values.
    :type t_or_l: tuple[int] or list[int]
    :param ge: Greater or equal to.
    :type ge: int, optional
    :param gt: Greater than.
    :type gt: int, optional
    :param le: Smaller or equal to.
    :type le: int, optional
    :param lt: Smaller than.
    :type lt: int, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is a valid list or `False` if
        not.
    :rtype: bool
    """
    bounds_ok = test_ge_gt_le_lt(
        ge, gt, le, lt, is_int, 'is_int_series', raise_error, log)
    if not bounds_ok:
        return False
    if not isinstance(t_or_l, (tuple, list)):
        illegal_value(t_or_l, 't_or_l', 'is_int_series', raise_error, log)
        return False
    return all(is_int(v, ge, gt, le, lt, raise_error, log) for v in t_or_l)
def is_num_series(
        t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that a value is a tuple or list of numbers, each in
    range ge <= l[i] <= le or gt < l[i] < lt or some combination.

    :param t_or_l: Input values.
    :type t_or_l: tuple[int or float] or list[int or float]
    :param ge: Greater or equal to.
    :type ge: int or float, optional
    :param gt: Greater than.
    :type gt: int or float, optional
    :param le: Smaller or equal to.
    :type le: int or float, optional
    :param lt: Smaller than.
    :type lt: int or float, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is a valid list or `False` if
        not.
    :rtype: bool
    """
    # The qualifiers are numbers here, so validate them with is_num
    # (validating with is_int would reject legal float bounds)
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return False
    if not isinstance(t_or_l, (tuple, list)):
        illegal_value(t_or_l, 't_or_l', 'is_num_series', raise_error, log)
        return False
    return all(is_num(v, ge, gt, le, lt, raise_error, log) for v in t_or_l)
def is_str_series(t_or_l, raise_error=False, log=True):
    """Check that a value is a tuple or list of strings.

    :param t_or_l: Input values.
    :type t_or_l: tuple[str] or list[str]
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is valid or `False` if not.
    :rtype: bool
    """
    valid = (isinstance(t_or_l, (tuple, list))
             and all(isinstance(item, str) for item in t_or_l))
    if valid:
        return True
    illegal_value(t_or_l, 't_or_l', 'is_str_series', raise_error, log)
    return False
def is_str_or_str_series(t_or_l, raise_error=False, log=True):
    """Check that a value is a string or a tuple or list of strings.

    :param t_or_l: Input values.
    :type t_or_l: str or tuple[str] or list[str]
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is valid or `False` if not.
    :rtype: bool
    """
    if isinstance(t_or_l, str):
        return True
    valid = (isinstance(t_or_l, (tuple, list))
             and all(isinstance(item, str) for item in t_or_l))
    if valid:
        return True
    illegal_value(t_or_l, 't_or_l', 'is_str_or_str_series', raise_error, log)
    return False
def is_dict_series(t_or_l, raise_error=False, log=True):
    """Check that a value is a tuple or list of dictionaries.

    :param t_or_l: Input values.
    :type t_or_l: tuple[dict] or list[dict]
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is valid or `False` if not.
    :rtype: bool
    """
    valid = (isinstance(t_or_l, (tuple, list))
             and all(isinstance(item, dict) for item in t_or_l))
    if valid:
        return True
    illegal_value(t_or_l, 't_or_l', 'is_dict_series', raise_error, log)
    return False
def is_dict_nums(d, raise_error=False, log=True):
    """Check that a value is a dictionary with single number values.

    :param d: Input dictionary.
    :type d: dict[str, int or float]
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is valid or `False` if not.
    :rtype: bool
    """
    # The per-value test stays silent, only the dictionary as a whole
    # is reported
    valid = (isinstance(d, dict)
             and all(is_num(v, log=False) for v in d.values()))
    if valid:
        return True
    illegal_value(d, 'd', 'is_dict_nums', raise_error, log)
    return False
def is_dict_strings(d, raise_error=False, log=True):
    """Check that a value is a dictionary with single string values.

    :param d: Input dictionary.
    :type d: dict[str, str]
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is valid or `False` if not.
    :rtype: bool
    """
    valid = (isinstance(d, dict)
             and all(isinstance(v, str) for v in d.values()))
    if valid:
        return True
    illegal_value(d, 'd', 'is_dict_strings', raise_error, log)
    return False
def is_index(value, ge=0, lt=None, raise_error=False, log=True):
    """Check that a value is an array index in range
    ge <= value < lt.

    .. note:: The value for `lt` IS NOT included!

    :param value: Input value.
    :type value: int
    :param ge: Greater or equal to, defaults to `0`.
    :type ge: int, optional
    :param lt: Smaller than.
    :type lt: int, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is a valid array index or
        `False` if not.
    :rtype: bool
    """
    # An integer upper limit must leave at least one valid index
    if isinstance(lt, Int) and lt <= ge:
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return False
    return is_int(value, ge=ge, lt=lt, raise_error=raise_error, log=log)
def is_index_range(
        value, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Check that a value is an array index range in range
    ge <= value[0] <= value[1] <= le or
    ge <= value[0] <= value[1] < lt.

    .. note:: The value for `le` IS included!

    :param value: Input value.
    :type value: list[int, int]
    :param ge: Greater or equal to, defaults to `0`.
    :type ge: int, optional
    :param le: Smaller or equal to.
    :type le: int, optional
    :param lt: Smaller than.
    :type lt: int, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: `True` if input value is a valid array index range or
        `False` if not.
    :rtype: bool
    """
    if not is_int_pair(value, raise_error=raise_error, log=log):
        return False
    bounds_ok = test_ge_gt_le_lt(
        ge, None, le, lt, is_int, 'is_index_range', raise_error, log)
    if not bounds_ok:
        return False

    in_range = (ge <= value[0] <= value[1]
                and (le is None or not value[1] > le)
                and (lt is None or not value[1] >= lt))
    if in_range:
        return True

    if le is None:
        error_msg = f'Value {value} out of range: ' \
            f'!({ge} <= {value[0]} <= {value[1]} < {lt})'
    else:
        error_msg = f'Value {value} out of range: ' \
            f'!({ge} <= {value[0]} <= {value[1]} <= {le})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
def index_nearest(a, value):
    """Return the index of the array value nearest to a given value.

    :param a: Input array, at most one-dimensional.
    :type a: array-like
    :param value: Input value.
    :type value: int or float
    :raise: ValueError when `a` has more than one dimension.
    :return: Index of the array value nearest to the input value.
    :rtype: int
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    # Round up for .5: scale the value by the smallest possible amount
    # so that a value halfway between two entries resolves to one side
    value *= 1.0 + sys.float_info.epsilon
    distance = np.abs(a-value)
    return int(np.argmin(distance))
def index_nearest_down(a, value):
    """Return the index of the array value nearest to a given value,
    rounded down.

    When the nearest array value is larger than the input value, the
    preceding index is returned instead (unless already at the first
    index).

    :param a: Input array, at most one-dimensional.
    :type a: array-like
    :param value: Input value.
    :type value: int or float
    :raise: ValueError when `a` has more than one dimension.
    :return: Index of the array value nearest to the input value,
        rounded down.
    :rtype: int
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    nearest = int(np.argmin(np.abs(a-value)))
    if nearest > 0 and value < a[nearest]:
        return nearest - 1
    return nearest
def index_nearest_up(a, value):
    """Return the index of the array value nearest to a given value,
    rounded up.

    When the nearest array value is smaller than the input value, the
    following index is returned instead (unless already at the last
    index).

    :param a: Input array, at most one-dimensional.
    :type a: array-like
    :param value: Input value.
    :type value: int or float
    :raise: ValueError when `a` has more than one dimension.
    :return: Index of the array value nearest to the input value,
        rounded up.
    :rtype: int
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    nearest = int(np.argmin(np.abs(a-value)))
    last = a.size-1
    if value > a[nearest] and nearest < last:
        return nearest + 1
    return nearest
def get_consecutive_int_range(a):
    """Return a list of pairs of integers marking consecutive ranges
    of integers.

    e.g: [5, 1, 2, 3, 7] -> [[1, 3], [5, 5], [7, 7]]

    The input is sorted on a copy, so the caller's sequence is left
    untouched and any iterable of integers (list, tuple, ...) is
    accepted.

    :param a: Input array.
    :type a: array-like
    :return: Pairs of integers marking consecutive ranges of
        integers, in ascending order.
    :rtype: list
    """
    int_ranges = []
    for v in sorted(a):
        if int_ranges and v <= 1 + int_ranges[-1][1]:
            # Adjacent to (or a duplicate within) the current range
            int_ranges[-1][1] = v
        else:
            int_ranges.append([v, v])
    return int_ranges
def round_to_n(value, n=1):
    """Round to a specific number of sig figs.

    :param value: Input value.
    :type value: float
    :param n: Number of sig figs, defaults to `1`.
    :type n: int, optional
    :return: Input value rounded to `n` sig figs, converted back to
        the type of the input (a plain `0` for a zero input).
    :rtype: float
    """
    if value == 0.0:
        return 0
    # Position of the leading digit relative to the decimal point
    magnitude = int(np.floor(np.log10(abs(value))))
    rounded = round(value, n-1-magnitude)
    return type(value)(rounded)
def round_up_to_n(value, n=1):
    """Round up (away from zero) to a specific number of sig figs.

    :param value: Input value.
    :type value: float
    :param n: Number of sig figs, defaults to `1`.
    :type n: int, optional
    :return: Input value rounded up in size to `n` sig figs,
        converted back to the type of the input.
    :rtype: float
    """
    if value == 0:
        # Zero has no leading digit; returning it as is also avoids a
        # division by zero below
        return type(value)(0)
    value_round = round_to_n(value, n)
    if abs(value/value_round) > 1.0:
        # Regular rounding went down in size: add one unit in the
        # last significant digit
        value_round += \
            np.sign(value) * 10**(np.floor(np.log10(abs(value)))+1-n)
    return type(value)(value_round)
def trunc_to_n(value, n=1):
    """Truncate (towards zero) to a specific number of sig figs.

    :param value: Input value.
    :type value: float
    :param n: Number of sig figs, defaults to `1`.
    :type n: int, optional
    :return: Input value truncated to `n` sig figs, converted back
        to the type of the input.
    :rtype: float
    """
    if value == 0:
        # Zero has no leading digit; returning it as is also avoids a
        # division by zero below
        return type(value)(0)
    value_round = round_to_n(value, n)
    if abs(value_round/value) > 1.0:
        # Regular rounding went up in size: remove one unit in the
        # last significant digit
        value_round -= \
            np.sign(value) * 10**(np.floor(np.log10(abs(value)))+1-n)
    return type(value)(value_round)
def almost_equal(value1, value2, sig_figs):
    """Check if two numbers are equal to within a certain number of
    significant digits.

    The difference of the inputs, rounded to `sig_figs` sig figs, is
    compared in size against `10**(1-sig_figs)`.

    :param value1: Input value.
    :type value1: float
    :param value2: Input value.
    :type value2: float
    :param sig_figs: Number of sig figs.
    :type sig_figs: int
    :raise: ValueError when either input is not a number.
    :return: `True` if inputs are equal within `sig_figs` sig figs,
        `False` if not.
    :rtype: bool
    """
    if not (is_num(value1) and is_num(value2)):
        raise ValueError(
            f'Invalid value for value1 or value2 in almost_equal '
            f'(value1: {value1}, {type(value1)}, '
            f'value2: {value2}, {type(value2)})')
    difference = abs(round_to_n(value1-value2, sig_figs))
    tolerance = pow(10, 1-sig_figs)
    return difference < tolerance
def string_to_list(
        s, split_on_dash=True, remove_duplicates=True, sort=True,
        raise_error=False):
    """Return a list of numbers by splitting/expanding a string on
    any combination of commas, whitespaces, or dashes (when
    split_on_dash=True).

    e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    :param s: Input string.
    :type s: str
    :param split_on_dash: Allow dashes in input string, defaults to
        `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change
        the order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :raise: ValueError, TypeError, SyntaxError, MemoryError or
        RecursionError upon an illegal input when `raise_error` is
        set to `True`.
    :return: List of numbers, or `None` upon an illegal input (when
        `raise_error` is `False`).
    :rtype: list
    """
    if not isinstance(s, str):
        # Honor raise_error for a non-string input as well
        illegal_value(s, 's', 'string_to_list', raise_error)
        return None
    if not s:
        return []
    try:
        list1 = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
        if split_on_dash:
            l_of_i = []
            for v in list1:
                list2 = [
                    literal_eval(x)
                    for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', v)]
                if len(list2) == 1:
                    l_of_i += list2
                elif len(list2) == 2 and list2[1] > list2[0]:
                    # Expand 'start-end' to the inclusive integer range
                    l_of_i += list(range(list2[0], 1+list2[1]))
                else:
                    raise ValueError(f'Invalid range specification ({v})')
        else:
            # Parse errors are handled the same way as in the dash
            # branch: return None unless raise_error is set
            l_of_i = [literal_eval(x) for x in list1]
    except (ValueError, TypeError, SyntaxError, MemoryError,
            RecursionError):
        if raise_error:
            raise
        return None
    if remove_duplicates:
        l_of_i = list(dict.fromkeys(l_of_i))
    if sort:
        l_of_i = sorted(l_of_i)
    return l_of_i
def list_to_string(a):
    """Return an array index list of integers in string notation.

    e.g: [1, 3, 5, 6, 7, 8, 12] -> '1, 3, 5-8, 12'

    :param a: Input array.
    :type a: array-like
    :return: Index list in string notation, an empty string for an
        empty input.
    :rtype: str
    """
    pieces = [
        f'{first}' if first == last else f'{first}-{last}'
        for first, last in get_consecutive_int_range(a)]
    return ', '.join(pieces)
def get_trailing_int(s):
    """Get the trailing integer in a string.

    :param s: Input value.
    :type s: str
    :return: Trailing integer in the string, or `None` if the string
        does not end in a digit.
    :rtype: int or None
    """
    found = re.search(r'\d+$', s)
    return None if found is None else int(found.group())
def input_int(
        s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Interactively prompt the user to enter an integer.

    :param s: Interactive user prompt, defaults to
        "Enter an integer" with an optional range.
    :type s: str, optional
    :param ge: Greater or equal to.
    :type ge: int, optional
    :param gt: Greater than.
    :type gt: int, optional
    :param le: Smaller or equal to.
    :type le: int, optional
    :param lt: Smaller than.
    :type lt: int, optional
    :param default: Default value to return.
    :type default: int, optional
    :param inset: Valid values to choose from.
    :type inset: list[int], optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: Entered value.
    :rtype: int
    """
    return _input_int_or_num(
        'int', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=inset, raise_error=raise_error, log=log)
def input_num(
        s=None, ge=None, gt=None, le=None, lt=None, default=None,
        raise_error=False, log=True):
    """Interactively prompt the user to enter a number.

    :param s: Interactive user prompt, defaults to
        "Enter a number" with an optional range.
    :type s: str, optional
    :param ge: Greater or equal to.
    :type ge: int or float, optional
    :param gt: Greater than.
    :type gt: int or float, optional
    :param le: Smaller or equal to.
    :type le: int or float, optional
    :param lt: Smaller than.
    :type lt: int or float, optional
    :param default: Default value to return.
    :type default: int or float, optional
    :param raise_error: Raise an error, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error message to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: Entered value.
    :rtype: int or float
    """
    return _input_int_or_num(
        'num', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=None, raise_error=raise_error, log=log)
def _input_int_or_num(
        type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Interactively prompt the user to enter an integer or number.

    The arguments are validated once, after which the prompt is
    repeated until the user enters an acceptable value.

    :param type_str: Kind of value to read, `'int'` or `'num'`.
    :type type_str: Literal['int', 'num']
    :param s: Interactive user prompt, defaults to a generic prompt
        for the requested type.
    :type s: str, optional
    :param ge: Inclusive lower bound.
    :type ge: int or float, optional
    :param gt: Exclusive lower bound.
    :type gt: int or float, optional
    :param le: Inclusive upper bound.
    :type le: int or float, optional
    :param lt: Exclusive upper bound.
    :type lt: int or float, optional
    :param default: Value used when the user enters an empty line.
    :type default: int or float, optional
    :param inset: Valid integer values to choose from (an empty
        sequence is treated as "no restriction").
    :type inset: list[int] or tuple[int], optional
    :param raise_error: Raise an exception on invalid arguments,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Write error messages to the logger, defaults to
        `True`.
    :type log: bool, optional
    :return: Entered value, or `None` if the arguments are invalid.
    :rtype: int or float or None
    """
    func_name = '_input_int_or_num'
    if type_str == 'int':
        type_test = is_int
        prompt = 'Enter an integer'
    elif type_str == 'num':
        type_test = is_num
        prompt = 'Enter a number'
    else:
        illegal_value(type_str, 'type_str', func_name, raise_error, log)
        return None
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, type_test, func_name, raise_error, log):
        return None

    default_string = ''
    if default is not None:
        if not _is_int_or_num(
                default, type_str, raise_error=raise_error, log=log):
            return None
        bound_violations = (
            (ge, 'ge', ge is not None and default < ge),
            (gt, 'gt', gt is not None and default <= gt),
            (le, 'le', le is not None and default > le),
            (lt, 'lt', lt is not None and default >= lt),
        )
        for bound, name, violated in bound_violations:
            if violated:
                illegal_combination(
                    bound, name, default, 'default', func_name,
                    raise_error, log)
                return None
        default_string = f' [{default}]'

    if inset is not None:
        if (not isinstance(inset, (tuple, list))
                or any(not isinstance(i, Int) for i in inset)):
            illegal_value(inset, 'inset', func_name, raise_error, log)
            return None
        # A default outside the allowed set can never be a valid answer,
        # so reject the combination up front instead of at prompt time.
        if default is not None and inset and default not in inset:
            illegal_combination(
                inset, 'inset', default, 'default', func_name,
                raise_error, log)
            return None

    v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}'
    if v_range:
        v_range = f' {v_range}'
    if s is not None:
        prompt = s
    prompt = f'{prompt}{v_range}{default_string}: '

    # Loop (rather than recurse) so that repeated invalid entries cannot
    # exhaust the interpreter's recursion limit.
    while True:
        print(prompt)
        try:
            i = input()
            if isinstance(i, str) and not i:
                # Empty line: fall back on the default (may be None)
                v = default
                print(f'{v}')
            else:
                v = literal_eval(i)
                if inset and v not in inset:
                    raise ValueError(f'{v} not part of the set {inset}')
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            v = None
        if _is_int_or_num(v, type_str, ge, gt, le, lt):
            return v
def input_int_list(
        s=None, num_max=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Ask the user for a list of integers on standard input.

    The entered string is split on any combination of commas,
    whitespaces, or dashes (when split_on_dash is True).

    e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    :param s: Interactive user prompt.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: int, optional
    :param le: Maximum value of inputs in list.
    :type le: int, optional
    :param split_on_dash: Allow dashes in input string, defaults to
        `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error, defaults to
        `True`.
    :type log: bool, optional
    :return: Input list or `None` upon an illegal input.
    :rtype: list
    """
    return _input_int_or_num_list(
        'int', s=s, num_max=num_max, ge=ge, le=le,
        split_on_dash=split_on_dash, remove_duplicates=remove_duplicates,
        sort=sort, raise_error=raise_error, log=log)
def input_num_list(
        s=None, num_max=None, ge=None, le=None, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Ask the user for a list of numbers on standard input.

    The entered string is split on any combination of commas or
    whitespaces (dashes are not treated as range separators).

    e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0]

    :param s: Interactive user prompt.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: float, optional
    :param le: Maximum value of inputs in list.
    :type le: float, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error, defaults to
        `True`.
    :type log: bool, optional
    :return: Input list or `None` upon an illegal input.
    :rtype: list
    """
    return _input_int_or_num_list(
        'num', s=s, num_max=num_max, ge=ge, le=le, split_on_dash=False,
        remove_duplicates=remove_duplicates, sort=sort,
        raise_error=raise_error, log=log)
def _input_int_or_num_list(
        type_str, s=None, num_max=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Interactively prompt the user to enter a list of integers or
    numbers.

    The arguments are validated once, after which the prompt is
    repeated until the user enters an acceptable list.

    :param type_str: Kind of values to read, `'int'` or `'num'`.
    :type type_str: Literal['int', 'num']
    :param s: Interactive user prompt.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: int or float, optional
    :param le: Maximum value of inputs in list.
    :type le: int or float, optional
    :param split_on_dash: Allow dashes in input string, defaults to
        `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error, defaults to
        `True`.
    :type log: bool, optional
    :raises ValueError: If parsing the input fails unexpectedly.
    :return: Input list or `None` upon illegal arguments.
    :rtype: list or None
    """
    # FIX do we want a limit on max dimension?
    func_name = '_input_int_or_num_list'
    if type_str == 'int':
        type_test = is_int
        prompt = 'Enter a series of integers'
    elif type_str == 'num':
        type_test = is_num
        prompt = 'Enter a series of numbers'
    else:
        illegal_value(type_str, 'type_str', func_name, raise_error, log)
        return None
    if not test_ge_gt_le_lt(
            ge, None, le, None, type_test, func_name, raise_error, log):
        return None
    if (num_max is not None
            and not is_int(num_max, gt=0, raise_error=raise_error, log=log)):
        return None

    v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
    if v_range:
        v_range = f' (each value in {v_range})'
    if s is not None:
        prompt = s
    prompt = f'{prompt}{v_range}: '

    num = '' if num_max is None else f'up to {num_max} '
    if split_on_dash:
        invalid_msg = (
            f'Invalid input: enter a valid set of {num}dash/comma/'
            'whitespace separated numbers e.g. 1 3,5-8 , 12')
    else:
        invalid_msg = (
            f'Invalid input: enter a valid set of {num}comma/whitespace '
            'separated numbers e.g. 1 3,5 8 , 12')

    # Loop instead of recursing: the former recursive retry dropped
    # `num_max` and thereby shifted every following positional argument.
    while True:
        print(prompt)
        try:
            l = string_to_list(
                input(), split_on_dash, remove_duplicates, sort)
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            l = None
        except Exception as exc:
            raise ValueError('Unexpected error') from exc
        if (isinstance(l, list)
                and (num_max is None or len(l) <= num_max)
                and all(_is_int_or_num(v, type_str, ge=ge, le=le)
                        for v in l)):
            return l
        print(invalid_msg)
def input_yesno(s=None, default=None):
    """Interactively prompt the user to enter a y/n question.

    Any non-empty, case-insensitive prefix of "yes" or "no" is
    accepted (e.g. "y", "Ye", "NO"). The prompt is repeated until a
    valid answer is entered.

    :param s: Interactive user prompt, defaults to "Enter yes or no".
    :type s: str, optional
    :param default: Answer used when the user enters an empty line,
        a prefix of "yes" or "no".
    :type default: str, optional
    :return: `True` if "yes", `False` if "no", or `None` if `default`
        is invalid.
    :rtype: bool or None
    """
    def _parse(answer):
        """Map a prefix of 'yes'/'no' to True/False, anything else to
        None."""
        if not isinstance(answer, str):
            return None
        answer = answer.strip().lower()
        # `answer in 'yes'` would be a substring test that also accepts
        # '', 'e', 's' and 'es'; require a non-empty prefix instead.
        if not answer:
            return None
        if 'yes'.startswith(answer):
            return True
        if 'no'.startswith(answer):
            return False
        return None

    if default is not None:
        default_value = _parse(default)
        if default_value is None:
            illegal_value(default, 'default', 'input_yesno')
            return None
        default = 'y' if default_value else 'n'
        default_string = f' [{default}]'
    else:
        default_string = ''

    prompt = 'Enter yes or no' if s is None else s
    while True:
        print(f'{prompt}{default_string}: ')
        i = input()
        if isinstance(i, str) and not i:
            # Empty line: fall back on the default (may be None)
            i = default
            print(f'{i}')
        v = _parse(i)
        if v is not None:
            return v
        print('Invalid input, enter yes or no')
def input_menu(items, default=None, header=None):
    """Interactively prompt the user to select from a menu.

    The menu is printed with one-based item numbers and is repeated
    (with the same header) until a valid number is entered.

    :param items: Menu items.
    :type items: list[str] or tuple[str]
    :param default: Item selected when the user enters an empty line,
        must be one of `items`.
    :type default: str, optional
    :param header: Text printed above the menu, defaults to
        "Choose one of the following items".
    :type header: str, optional
    :raises Exception: If reading or parsing the input fails
        unexpectedly.
    :return: Zero-based index of the chosen entry in `items`, or
        `None` if the arguments are invalid.
    :rtype: int or None
    """
    # An empty menu can never be answered, so reject it up front
    if (not isinstance(items, (tuple, list)) or not items
            or any(not isinstance(i, str) for i in items)):
        illegal_value(items, 'items', 'input_menu')
        return None
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(
                f'Invalid value for default ({default}), must be in {items}')
            return None
        default_string = f' [{1+items.index(default)}]'
    else:
        default_string = ''
    if header is None:
        header = 'Choose one of the following items'

    # Loop instead of recursing so that the header survives a retry
    while True:
        print(f'{header} (1, {len(items)}){default_string}:')
        for i, item in enumerate(items):
            print(f' {i+1}: {item}')
        try:
            entry = input()
            if isinstance(entry, str) and not entry:
                if default is None:
                    raise ValueError
                choice = items.index(default)
                print(f'{1+choice}')
            else:
                choice = literal_eval(entry)
                # bool is a subclass of int: do not accept "True" as 1
                if (isinstance(choice, int) and not isinstance(choice, bool)
                        and 1 <= choice <= len(items)):
                    choice -= 1
                else:
                    raise ValueError
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            choice = None
        except Exception as exc:
            raise Exception('Unexpected error') from exc
        if choice is not None:
            return choice
        print(f'Invalid choice, enter a number between 1 and {len(items)}')
def assert_no_duplicates_in_list_of_dicts(l, raise_error=False):
    """Check that a list of dictionaries contains no two equal
    dictionaries.

    :param l: Input list.
    :type l: list[dict]
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If duplicates are found and `raise_error` is
        `True`.
    :return: The input list if it is free of duplicates, `None`
        otherwise.
    :rtype: list[dict] or None
    """
    func_name = 'assert_no_duplicates_in_list_of_dicts'
    if not isinstance(l, list):
        illegal_value(l, 'l', func_name, raise_error)
        return None
    if not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', func_name, raise_error)
        return None
    # Equal dictionaries collapse to a single sorted item tuple
    unique_items = {tuple(sorted(d.items())) for d in l}
    if len(unique_items) == len(l):
        return l
    if raise_error:
        raise ValueError(f'Duplicate items found in {l}')
    logger.error(f'Duplicate items found in {l}')
    return None
def assert_no_duplicate_key_in_list_of_dicts(l, key, raise_error=False):
    """Assert that every dictionary in a list has the given key and
    that no two dictionaries share the same value for it.

    :param l: Input list.
    :type l: list[dict]
    :param key: Dictionary key.
    :type key: str
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If the key is missing or duplicated and
        `raise_error` is `True`.
    :return: The input list if the check passes, `None` otherwise.
    :rtype: list[dict] or None
    """
    func_name = 'assert_no_duplicate_key_in_list_of_dicts'
    if not isinstance(key, str):
        illegal_value(key, 'key', func_name, raise_error)
        return None
    if not isinstance(l, list):
        illegal_value(l, 'l', func_name, raise_error)
        return None
    # Reject lists holding anything other than dictionaries (the check
    # was previously inverted and rejected every valid list of dicts)
    if any(not isinstance(d, dict) for d in l):
        illegal_value(l, 'l', func_name, raise_error)
        return None
    values = [d.get(key, None) for d in l]
    # A value of None is indistinguishable from a missing key here
    if None in values or len(set(values)) != len(l):
        if raise_error:
            raise ValueError(
                f'Duplicate or missing key ({key}) found in {l}')
        logger.error(f'Duplicate or missing key ({key}) found in {l}')
        return None
    return l
def assert_no_duplicate_attr_in_list_of_objs(l, attr, raise_error=False):
    """Assert that every object in a list has the given attribute and
    that no two objects share the same value for it.

    :param l: Input list.
    :type l: list
    :param attr: Attribute name.
    :type attr: str
    :param raise_error: Raise an exception upon any error, defaults
        to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If the attribute is missing or duplicated and
        `raise_error` is `True`.
    :return: The input list if the check passes, `None` otherwise.
    :rtype: list or None
    """
    func_name = 'assert_no_duplicate_attr_in_list_of_objs'
    if not isinstance(attr, str):
        illegal_value(attr, 'attr', func_name, raise_error)
        return None
    if not isinstance(l, list):
        # Report this function's own name (previously a non-existent
        # "..._key_in_list_of_objs" was reported)
        illegal_value(l, 'l', func_name, raise_error)
        return None
    attrs = [getattr(obj, attr, None) for obj in l]
    # An attribute value of None is indistinguishable from a missing
    # attribute here
    if None in attrs or len(set(attrs)) != len(l):
        if raise_error:
            raise ValueError(
                f'Duplicate or missing attr ({attr}) found in {l}')
        logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
        return None
    return l
def file_exists_and_readable(filename):
    """Verify that a path refers to an existing, readable file.

    :param filename: File name.
    :type filename: str
    :raise ValueError: Invalid file or inaccessible for reading.
    :return: File name.
    :rtype: str
    """
    if os.path.isfile(filename):
        if os.access(filename, os.R_OK):
            return filename
        raise ValueError(f'{filename} is not accessible for reading')
    raise ValueError(f'{filename} is not a valid file')
def rolling_average(
        y, x=None, dtype=None, start=0, end=None, width=None, stride=None,
        num=None, average=True, mode='valid', use_convolve=None):
    """Returns the rolling sum or average of an array over the last
    dimension.

    :param y: Input data.
    :type y: array-like
    :param x: Independent dimension.
    :type x: array-like, optional
    :param dtype: Data type the result is cast to, defaults to the
        dtype of `y` if `average` is `True` or `numpy.float32` if not.
    :type dtype: numpy.dtype
    :param start: First array index, defaults to `0`.
    :type start: int, optional
    :param end: Last array index (exclusive), defaults to the size of
        the last dimension of `y`.
    :type end: int, optional
    :param width: Number of elements in rolling sum or average.
    :type width: int, optional
    :param stride: Stride in rolling sum or average.
    :type stride: int, optional
    :param num: Number of outputs of rolling sum or average.
    :type num: int, optional
    :param average: Compute the rolling average if `True` or the
        rolling sum otherwise.
    :type average: bool
    :param mode: Only return results for full sized windows if
        `"valid"`, include partial windows if `"full"`, defaults to
        `"valid"`.
    :type mode: Literal['valid', 'full'], optional
    :param use_convolve: Use numpy.convolve if `True`, defaults to
        `True` for one-dimensional `y` and `False` otherwise.
    :type use_convolve: bool, optional
    :raise ValueError: Invalid input parameters.
    :return: Rolling sum or average of the input array, optionally
        with their independent coordinates.
    :rtype: numpy.ndarray or (numpy.ndarray, numpy.ndarray)

    .. note:: Specify only one or two of `width`, `stride`, and `num`.
    """
    y = np.asarray(y)
    # Remember the caller's shape; internally y is handled as a 2D
    # array with all leading dimensions flattened into the first axis
    y_shape = y.shape
    if y.ndim == 1:
        y = np.expand_dims(y, 0)
    else:
        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError('Parameter "x" must be a 1D array-like')
        if x.size != y.shape[1]:
            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
                             f'match ({x.size} vs {y.shape[1]})')
    if dtype is None:
        # NOTE(review): an integer `y` with average=True yields an
        # integer result dtype here -- confirm this is intended
        if average:
            dtype = y.dtype
        else:
            dtype = np.float32
    if width is None and stride is None and num is None:
        raise ValueError('Invalid input parameters, specify at least one of '
                         '"width", "stride" or "num"')
    if width is not None and not is_int(width, ge=1):
        raise ValueError(f'Invalid "width" parameter ({width})')
    if stride is not None and not is_int(stride, ge=1):
        raise ValueError(f'Invalid "stride" parameter ({stride})')
    if num is not None and not is_int(num, ge=1):
        raise ValueError(f'Invalid "num" parameter ({num})')
    if not isinstance(average, bool):
        raise ValueError(f'Invalid "average" parameter ({average})')
    if mode not in ('valid', 'full'):
        raise ValueError(f'Invalid "mode" parameter ({mode})')
    size = y.shape[1]
    if size < 2:
        raise ValueError(f'Invalid y[1] dimension ({size})')
    if not is_int(start, ge=0, lt=size):
        raise ValueError(f'Invalid "start" parameter ({start})')
    if end is None:
        end = size
    elif not is_int(end, gt=start, le=size):
        raise ValueError(f'Invalid "end" parameter ({end})')
    if use_convolve is None:
        use_convolve = bool(len(y_shape) == 1)
    # The convolve branch works on a cropped copy of the data; the
    # explicit-loop branch keeps the full array and offsets by `start`
    if use_convolve and (start or end < size):
        y = np.take(y, range(start, end), axis=1)
        if x is not None:
            x = x[start:end]
        size = y.shape[1]
    else:
        size = end-start

    # Derive whichever of width/stride were not supplied
    if stride is None:
        if width is None:
            width = max(1, int(size/num))
            stride = width
        else:
            width = min(width, size)
            if num is None:
                stride = width
            else:
                # NOTE(review): num == 1 divides by zero here -- confirm
                stride = max(1, int((size-width) / (num-1)))
    else:
        stride = min(stride, size-stride)
        if width is None:
            width = stride

    # The number of outputs is always recomputed from width and stride
    if mode == 'valid':
        num = 1 + max(0, int((size-width) / stride))
    else:
        num = int(size/stride)
        if num*stride < size:
            num += 1

    if use_convolve:
        # Number of elements contributing to each output window
        # (smaller than width for trailing partial windows)
        n_start = 0
        n_end = width
        weight = np.empty((num))
        for n in range(num):
            n_num = n_end-n_start
            weight[n] = n_num
            n_start += stride
            n_end = min(size, n_end+stride)
        window = np.ones((width))
        # NOTE(review): for width == 1 the 'valid' slices below become
        # [0:0:stride], i.e. empty -- confirm
        if x is not None:
            if mode == 'valid':
                rx = np.convolve(x, window)[width-1:1-width:stride]
            else:
                rx = np.convolve(x, window)[width-1::stride]
            rx /= weight
        ry = []
        if mode == 'valid':
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1:1-width:stride])
        else:
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1::stride])
        ry = np.reshape(ry, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
        elif mode != 'valid':
            # Scale the sums of partial windows up to a full window
            weight = np.where(weight < width, width/weight, 1.0)
            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
    else:
        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
        if x is not None:
            rx = np.zeros(num, dtype=x.dtype)
        n_start = start
        n_end = n_start+width
        for n in range(num):
            y_sum = np.sum(y[:,n_start:n_end], 1)
            n_num = n_end-n_start
            if n_num < width:
                # Scale the sum of a partial window up to a full window
                y_sum *= width/n_num
            ry[n] = y_sum
            if x is not None:
                rx[n] = np.sum(x[n_start:n_end])/n_num
            n_start += stride
            n_end = min(start+size, n_end+stride)
        # Restore the caller's leading dimensions
        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (ry.astype(np.float32)/width).astype(dtype)
    if x is None:
        return ry
    return ry, rx
def baseline_arPLS(
        y, mask=None, w=None, tol=1.e-8, lam=1.e6, max_iter=20,
        full_output=False):
    """Returns the smoothed baseline estimate of a spectrum.

    Based on S.-J. Baek, A. Park, Y.-J Ahn, and J. Choo,
    "Baseline correction using asymmetrically reweighted penalized
    least squares smoothing", Analyst, 2015,140, 250-257

    :param y: Spectrum.
    :type y: array-like
    :param mask: Mask to apply to the spectrum before baseline
        construction (only elements where the mask is `True` are
        used).
    :type mask: array-like, optional
    :param w: Weights (allows restart for additional iterations).
    :type w: numpy.array, optional
    :param tol: Convergence tolerance, defaults to `1.e-8`.
    :type tol: float, optional
    :param lam: Lambda (smoothness) parameter (the balance between the
        residual of the data and the baseline and the smoothness of
        the baseline). The suggested range is between 100 and 10^8,
        defaults to `10^6`.
    :type lam: float, optional
    :param max_iter: Maximum number of iterations, defaults to `20`.
    :type max_iter: int, optional
    :param full_output: Whether or not to also output the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result, defaults to `False`.
    :type full_output: bool, optional
    :raises ValueError: Invalid `tol`, `lam`, `max_iter` or
        `full_output` parameter.
    :return: Smoothed baseline, with optionally the baseline corrected
        spectrum, the weights, the number of iterations and the error
        in the returned result.
    :rtype: numpy.array [, numpy.array, numpy.array, int, float]
    """
    # With credit to: Daniel Casas-Orozco
    # https://stackoverflow.com/questions/29156532/python-baseline-correction-library
    # Third party modules
    from scipy.sparse import (
        spdiags,
        linalg,
    )

    if not is_num(tol, gt=0):
        raise ValueError(f'Invalid tol parameter ({tol})')
    if not is_num(lam, gt=0):
        raise ValueError(f'Invalid lam parameter ({lam})')
    if not is_int(max_iter, gt=0):
        raise ValueError(f'Invalid max_iter parameter ({max_iter})')
    if not isinstance(full_output, bool):
        raise ValueError(f'Invalid full_output parameter ({full_output})')
    y = np.asarray(y)
    if mask is not None:
        # Accept any array-like mask (lists have no `astype`)
        mask = np.asarray(mask).astype(bool)
        y_org = y
        y = y[mask]
    num = y.size

    # Second order difference matrix and the smoothness penalty term
    diag = np.ones((num-2))
    D = spdiags([diag, -2*diag, diag], [0, -1, -2], num, num-2)
    H = lam * D.dot(D.T)

    if w is None:
        w = np.ones(num)
    W = spdiags(w, 0, num, num)

    error = 1
    num_iter = 0
    # Largest exponent that np.exp can evaluate without overflowing
    exp_max = int(np.log(sys.float_info.max))
    while error > tol and num_iter < max_iter:
        z = linalg.spsolve(W + H, W * y)
        d = y - z
        # Reweight using the statistics of the negative residuals only
        dn = d[d < 0]
        m = np.mean(dn)
        s = np.std(dn)
        w_new = 1.0 / (1.0 + np.exp(
            np.clip(2.0 * (d - (2.0*s - m))/s, None, exp_max)))
        error = np.linalg.norm(w_new - w) / np.linalg.norm(w)
        num_iter += 1
        w = w_new
        W.setdiag(w)

    if mask is not None:
        # Expand the baseline back to the size of the unmasked
        # spectrum (zero where the mask is False)
        zz = np.zeros(y_org.size)
        zz[mask] = z
        z = zz
        if full_output:
            d = y_org - z
    if full_output:
        return z, d, w, num_iter, float(error)
    return z
def fig_to_iobuf(fig, fileformat=None):
    """Return an in-memory object as a byte stream representation of
    a Matplotlib figure.

    :param fig: Matplotlib figure object.
    :type fig: matplotlib.figure.Figure
    :param fileformat: Valid Matplotlib saved figure file format,
        defaults to `'png'`. A format that is not supported by the
        figure's canvas is silently replaced by `'png'`.
    :type fileformat: str, optional
    :return: Byte stream representation of the Matplotlib figure and
        the associated file format.
    :rtype: tuple[_io.BytesIO, str]
    """
    # System modules
    from io import BytesIO

    # Query the canvas of the figure being saved: `plt.gcf()` may refer
    # to a different figure, or create a new one if none is open.
    if (fileformat is None
            or fileformat not in fig.canvas.get_supported_filetypes()):
        fileformat = 'png'
    buf = BytesIO()
    fig.savefig(buf, format=fileformat)
    return buf, fileformat
def save_iobuf_fig(buf, filename, force_overwrite=False):
    """Save a byte stream representation of a Matplotlib figure to
    file.

    :param buf: Byte stream representation of the Matplotlib figure.
    :type buf: _io.BytesIO
    :param filename: Filename (with a valid extension).
    :type filename: str
    :param force_overwrite: Flag to allow `filename` to be overwritten
        if it already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises ValueError: If the extension of `filename` is missing or
        is not an image format that PIL can save.
    :raises FileExistsError: If a file already exists and
        `force_overwrite` is `False`.
    """
    # Third party modules
    from PIL import Image

    # Extensions (with leading dot) of the formats PIL is able to write
    exts = Image.registered_extensions()
    exts = {ex for ex, f in exts.items() if f in Image.SAVE}

    # Validate filename and extension (registered extensions are
    # lowercase, so compare case-insensitively)
    _, ext = os.path.splitext(filename)
    if not ext or ext.lower() not in exts:
        raise ValueError(f'Invalid file format {ext[1:]}')
    if os.path.isfile(filename) and not force_overwrite:
        raise FileExistsError(f'{filename} already exists')
    # A bare filename has an empty directory part, which must not be
    # passed to os.makedirs
    filedir = os.path.dirname(filename)
    if filedir:
        os.makedirs(filedir, exist_ok=True)

    # Write image to file
    buf.seek(0)
    with Image.open(buf) as img:
        img.save(filename)
def select_mask_1d(
        y, x=None, preselected_index_ranges=None, preselected_mask=None,
        title=None, xlabel=None, ylabel=None, min_num_index_ranges=None,
        max_num_index_ranges=None, interactive=True, filename=None,
        return_buf=False):
    """Display a lineplot and have the user select a mask.

    Without interaction and without a requested figure (`interactive`
    is `False`, `filename` is `None` and `return_buf` is `False`) no
    figure is created and the mask is built from the preselected
    ranges only.

    :param y: One-dimensional data array for which a mask will be
        constructed.
    :type y: array-like
    :param x: x-coordinates of the reference data, must be strictly
        increasing, defaults to the bin centers `0.5, 1.5, ...`.
    :type x: array-like, optional
    :param preselected_index_ranges: List of preselected index ranges
        to mask (bounds are inclusive).
    :type preselected_index_ranges: list[tuple(int, int)] or
        list[list[int]] or list[tuple(float, float)] or
        list[list[float]]), optional
    :param preselected_mask: Preselected boolean mask array.
    :type preselected_mask: array-like, optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param xlabel: Label for the x-axis of the displayed figure.
    :type xlabel: str, optional
    :param ylabel: Label for the y-axis of the displayed figure.
    :type ylabel: str, optional
    :param min_num_index_ranges: Minimum number of selected index
        ranges.
    :type min_num_index_ranges: int, optional
    :param max_num_index_ranges: Maximum number of selected index
        ranges.
    :type max_num_index_ranges: int, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        representation of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: Invalid `y`, `x` or
        `preselected_index_ranges`.
    :return: Byte stream representation of the Matplotlib figure if
        return_buf is `True` (`None` otherwise), a boolean mask array,
        and the list of selected index ranges.
    :rtype: tuple[_io.BytesIO, str] or `None`, numpy.ndarray,
        list[list[int, int]]
    """
    # Third party modules
    # pylint: disable=possibly-used-before-assignment
    if interactive or filename is not None or return_buf:
        from matplotlib.patches import Patch
        from matplotlib.widgets import Button, SpanSelector

    # The helpers below are closures over state that is defined further
    # down (spans, fig_title, error_texts, ax, x, num_data, ...)

    def _change_fig_title(title):
        """Replace the figure title text."""
        if fig_title:
            fig_title[0].remove()
            fig_title.pop()
        fig_title.append(plt.figtext(*title_pos, title, **title_props))

    def _change_error_text(error):
        """Replace the message text shown below the figure title."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        error_texts.append(plt.figtext(*error_pos, error, **error_props))

    def _get_selected_index_ranges(change_fnc=None, title=''):
        """Return the sorted index ranges of all current spans and
        optionally report them through `change_fnc`."""
        selected_index_ranges = sorted(
            [[index_nearest(x, span.extents[0]),
              index_nearest(x, span.extents[1])+1] for span in spans])
        if change_fnc is not None:
            if len(selected_index_ranges) > 1:
                change_fnc(
                    f'{title}Selected ROIs: {selected_index_ranges}')
            elif selected_index_ranges:
                change_fnc(
                    f'{title}Selected ROI: {tuple(selected_index_ranges[0])}')
            else:
                change_fnc(f'{title}Selected ROI: None')
        return selected_index_ranges

    def add_span(event, xrange_init=None):
        """Callback function for the "Add span" button."""
        if (max_num_index_ranges is not None
                and len(spans) >= max_num_index_ranges):
            _change_error_text(
                'Exceeding max number of ranges, adjust an existing '
                'range or click "Reset"/"Confirm"')
        else:
            spans.append(
                SpanSelector(
                    ax, select_span, 'horizontal', props=included_props,
                    useblit=True, interactive=interactive,
                    drag_from_anywhere=True, ignore_event_outside=True,
                    grab_range=5))
            if xrange_init is None:
                # NOTE(review): the upper bound is a width
                # (5% of the x-range), not min(x) plus that width --
                # confirm this is intended
                xmin_init, xmax_init = min(x), 0.05*(max(x)-min(x))
            else:
                xmin_init, xmax_init = xrange_init
            # Mark the new selector as completed and trigger its
            # callback so that it is drawn and merged like a user-made
            # selection
            spans[-1]._selection_completed = True
            spans[-1].extents = (xmin_init, xmax_init)
            spans[-1].onselect(xmin_init, xmax_init)
        plt.draw()

    def select_span(xmin, xmax):
        """Callback function for the SpanSelector widget."""
        # Merge overlapping spans pairwise; restart the scan after each
        # merge because `spans` was modified
        combined_spans = True
        while combined_spans:
            combined_spans = False
            for i, span1 in enumerate(spans):
                for span2 in spans[i+1:]:
                    if (span1.extents[1] >= span2.extents[0]
                            and span1.extents[0] <= span2.extents[1]):
                        _change_error_text(
                            'Combined overlapping spans in currently '
                            'selected mask')
                        span2.extents = (
                            min(span1.extents[0], span2.extents[0]),
                            max(span1.extents[1], span2.extents[1]))
                        span1.set_visible(False)
                        spans.remove(span1)
                        combined_spans = True
                        break
                if combined_spans:
                    break
        _get_selected_index_ranges(_change_error_text)
        plt.draw()

    def reset(event):
        """Callback function for the "Reset" button."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        # Iterate over a reversed view because spans are removed from
        # the list inside the loop
        for span in reversed(spans):
            span.set_visible(False)
            spans.remove(span)
        _get_selected_index_ranges(_change_error_text)
        plt.draw()

    def confirm(event):
        """Callback function for the "Confirm" button."""
        if (min_num_index_ranges is not None
                and len(spans) < min_num_index_ranges):
            _change_error_text(
                f'Select at least {min_num_index_ranges} unique index ranges')
            plt.draw()
        else:
            if error_texts:
                error_texts[0].remove()
                error_texts.pop()
            _get_selected_index_ranges(_change_fig_title, title)
            plt.close()

    def update_mask(mask, selected_index_ranges):
        """Update the mask with the selected index ranges."""
        for min_, max_ in selected_index_ranges:
            mask = np.logical_or(
                mask,
                np.logical_and(x >= x[min_], x <= x[min(max_, num_data-1)]))
        return mask

    def update_index_ranges(mask):
        """Update the selected index ranges (where mask = True)."""
        # An open range is stored as its bare start index and replaced
        # by a (start, end) tuple once the range ends
        selected_index_ranges = []
        for i, m in enumerate(mask):
            if m:
                if (not selected_index_ranges
                        or isinstance(selected_index_ranges[-1], tuple)):
                    selected_index_ranges.append(i)
            else:
                if (selected_index_ranges
                        and isinstance(selected_index_ranges[-1], Int)):
                    selected_index_ranges[-1] = \
                        (selected_index_ranges[-1], i-1)
        if (selected_index_ranges
                and isinstance(selected_index_ranges[-1], Int)):
            selected_index_ranges[-1] = (selected_index_ranges[-1], num_data-1)
        return selected_index_ranges

    # Check inputs
    y = np.asarray(y)
    if y.ndim > 1:
        raise ValueError(f'Invalid y dimension ({y.ndim})')
    num_data = y.size
    if x is None:
        x = np.arange(num_data)+0.5
    else:
        x = np.asarray(x, dtype=np.float64)
        if x.ndim > 1 or x.size != num_data:
            raise ValueError(f'Invalid x shape ({x.shape})')
        if not np.all(x[:-1] < x[1:]):
            raise ValueError('Invalid x: must be monotonically increasing')
    if title is None:
        title = ''
    else:
        title = f'{title}: '
    if preselected_index_ranges is None:
        preselected_index_ranges = []
    else:
        if not isinstance(preselected_index_ranges, list):
            raise ValueError('Invalid parameter preselected_index_ranges '
                             f'({preselected_index_ranges})')
        # NOTE(review): the ranges are only validated and converted when
        # a figure is needed; in the figure-less case they are used as
        # passed in -- confirm this is intended
        if interactive or filename is not None or return_buf:
            index_ranges = []
            for v in preselected_index_ranges:
                if not is_num_pair(v):
                    raise ValueError(
                        'Invalid parameter preselected_index_ranges '
                        f'({preselected_index_ranges})')
                index_ranges.append(
                    (max(0, int(v[0])), min(num_data, int(v[1])-1)))
            preselected_index_ranges = index_ranges

    # Setup the preselected mask and index ranges if provided
    if preselected_mask is not None:
        preselected_index_ranges = update_index_ranges(
            update_mask(
                np.copy(np.asarray(preselected_mask, dtype=bool)),
                preselected_index_ranges))

    if not interactive and filename is None and not return_buf:
        # Update the mask with the preselected index ranges
        selected_mask = update_mask(len(x)*[False], preselected_index_ranges)
        return None, selected_mask, preselected_index_ranges

    # Single-element lists act as mutable holders that the nested
    # callbacks can update
    spans = []
    fig_title = []
    error_texts = []

    # Setup the Matplotlib figure
    title_pos = (0.5, 0.95)
    title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    error_pos = (0.5, 0.90)
    error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    excluded_props = {
        'facecolor': 'white', 'edgecolor': 'gray', 'linestyle': ':'}
    included_props = {
        'alpha': 0.5, 'facecolor': 'tab:blue', 'edgecolor': 'blue'}

    fig, ax = plt.subplots(figsize=(11, 8.5))
    handles = ax.plot(x, y, color='k', label='Reference Data')
    handles.append(Patch(
        label='Excluded / unselected ranges', **excluded_props))
    handles.append(Patch(
        label='Included / selected ranges', **included_props))
    ax.legend(handles=handles)
    ax.set_xlabel(xlabel, fontsize='x-large')
    ax.set_ylabel(ylabel, fontsize='x-large')
    ax.set_xlim(x[0], x[-1])
    fig.subplots_adjust(bottom=0.0, top=0.85)

    # Add the preselected index ranges
    for min_, max_ in preselected_index_ranges:
        add_span(None, xrange_init=(x[min_], x[min(max_, num_data-1)]))

    if not interactive:
        _get_selected_index_ranges(_change_fig_title, title)
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
    else:
        _change_fig_title(f'{title}Click and drag to select ranges')
        _get_selected_index_ranges(_change_error_text)
        fig.subplots_adjust(bottom=0.2)

        # Setup "Add span" button
        add_span_btn = Button(
            plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
        add_span_cid = add_span_btn.on_clicked(add_span)

        # Setup "Reset" button
        reset_btn = Button(plt.axes([0.45, 0.05, 0.15, 0.075]), 'Reset')
        reset_cid = reset_btn.on_clicked(reset)

        # Setup "Confirm" button
        confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        confirm_cid = confirm_btn.on_clicked(confirm)

        # Show figure for user interaction
        plt.show()

        # Disconnect all widget callbacks when figure is closed
        add_span_btn.disconnect(add_span_cid)
        reset_btn.disconnect(reset_cid)
        confirm_btn.disconnect(confirm_cid)

        # ...and remove the buttons before returning the figure
        add_span_btn.ax.remove()
        reset_btn.ax.remove()
        confirm_btn.ax.remove()
        plt.subplots_adjust(bottom=0.0)

    selected_index_ranges = _get_selected_index_ranges()

    # Update the mask with the currently selected index ranges
    selected_mask = update_mask(len(x)*[False], selected_index_ranges)

    buf = None
    if filename is not None or return_buf:
        if interactive:
            if len(selected_index_ranges) > 1:
                title += f'Selected ROIs: {selected_index_ranges}'
            else:
                # NOTE(review): raises IndexError when no range is
                # selected (e.g. the window was closed after a reset)
                # -- confirm
                title += f'Selected ROI: {tuple(selected_index_ranges[0])}'
            fig_title[0]._text = title
        fig_title[0].set_in_layout(True)
        fig.tight_layout(rect=(0, 0, 1, 0.95))
        if filename is not None:
            fig.savefig(filename)
        if return_buf:
            buf = fig_to_iobuf(fig)
    plt.close()
    return buf, selected_mask, selected_index_ranges
def select_roi_1d(
        y, x=None, preselected_roi=None, title=None, xlabel=None, ylabel=None,
        interactive=True, filename=None, return_buf=False):
    """Display a lineplot and have the user select a single region of
    interest.

    :param y: One-dimensional data array for which a region of
        interest will be selected.
    :type y: array-like
    :param x: x-coordinates of the data
    :type x: array-like, optional
    :param preselected_roi: Preselected region of interest.
    :type preselected_roi: tuple(int, int), optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param xlabel: Label for the x-axis of the displayed figure.
    :type xlabel: str, optional
    :param ylabel: Label for the y-axis of the displayed figure.
    :type ylabel: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        representation of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: Invalid `y` dimension or `preselected_roi`.
    :return: Byte stream representation of the Matplotlib figure if
        return_buf is `True` (`None` otherwise), and the selected
        region of interest (`None` if no region was selected).
    :rtype: io.BytesIO or `None`, tuple(int, int) or `None`
    """
    # Check inputs
    y = np.asarray(y)
    if y.ndim != 1:
        raise ValueError(f'Invalid image dimension ({y.ndim})')
    if preselected_roi is not None:
        if not is_int_pair(preselected_roi, ge=0, le=y.size, log=False):
            raise ValueError('Invalid parameter preselected_roi '
                             f'({preselected_roi})')
        preselected_roi = [preselected_roi]

    # A region of interest is a mask with exactly one index range
    buf, _, roi = select_mask_1d(
        y, x=x, preselected_index_ranges=preselected_roi, title=title,
        xlabel=xlabel, ylabel=ylabel, min_num_index_ranges=1,
        max_num_index_ranges=1, interactive=interactive, filename=filename,
        return_buf=return_buf)

    # Nothing selected (e.g. non-interactive call without a preselected
    # ROI): report that instead of failing on roi[0]
    if not roi:
        return buf, None
    return buf, tuple(roi[0])
def select_roi_2d(
        a, preselected_roi=None, title=None, title_a=None,
        row_label='row index', column_label='column index',
        interactive=True, filename=None, return_buf=False):
    """Display a 2D image and have the user select a single
    rectangular region of interest.

    When `interactive` is `False` and neither `filename` nor
    `return_buf` is set, no figure is created and `preselected_roi` is
    returned unchanged.

    :param a: Two-dimensional image data array for which a region of
        interest will be selected.
    :type a: array-like
    :param preselected_roi: Preselected region of interest.
    :type preselected_roi: tuple(int, int, int, int), optional
    :param title: Title for the displayed figure, defaults to a
        "click and drag" instruction.
    :type title: str, optional
    :param title_a: Title for the image of a.
    :type title_a: str, optional
    :param row_label: Label for the y-axis of the displayed figure,
        defaults to `row index`.
    :type row_label: str, optional
    :param column_label: Label for the x-axis of the displayed figure,
        defaults to `column index`.
    :type column_label: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        representation of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: If `a` is not two-dimensional or if
        `preselected_roi` is not a series of four non-negative
        integers.
    :return: The result of `fig_to_iobuf` for the figure if
        `return_buf` is `True` (`None` otherwise), and the selected
        region of interest (`None` if the selection spans less than
        one index in either direction).
    :rtype: io.BytesIO or `None`, tuple(int, int, int, int) or `None`
    """
    # Third party modules
    # pylint: disable=possibly-used-before-assignment
    if interactive or filename is not None or return_buf:
        from matplotlib.widgets import Button, RectangleSelector

    # NOTE: fig_title, subfig_title and rects are one-element lists
    # defined further down; the nested functions read them at call
    # time and replace their single entry in place.

    def _change_fig_title(title):
        """Replace the figure title text."""
        if fig_title:
            fig_title[0].remove()
            fig_title.pop()
        fig_title.append(plt.figtext(*title_pos, title, **title_props))

    def _change_subfig_title(error):
        """Replace the status/error text shown below the title."""
        if subfig_title:
            subfig_title[0].remove()
            subfig_title.pop()
        subfig_title.append(plt.figtext(*error_pos, error, **error_props))

    def _clear_selection():
        """Hide the current selector and install a fresh one."""
        rects[0].set_visible(False)
        rects.pop()
        rects.append(
            RectangleSelector(
                ax, on_rect_select, props=rect_props,
                useblit=True, interactive=interactive,
                drag_from_anywhere=True, ignore_event_outside=False))

    def on_rect_select(eclick, erelease):
        """Callback function for the RectangleSelector widget."""
        # A selection that truncates to zero width or height is
        # discarded
        if (not int(rects[0].extents[1]) - int(rects[0].extents[0])
                or not int(rects[0].extents[3]) - int(rects[0].extents[2])):
            _clear_selection()
            _change_subfig_title(
                'Selected ROI too small, try again')
        else:
            _change_subfig_title(
                f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
        plt.draw()

    def reset(event):
        """Callback function for the "Reset" button."""
        if subfig_title:
            subfig_title[0].remove()
            subfig_title.pop()
        _clear_selection()
        plt.draw()

    def confirm(event):
        """Callback function for the "Confirm" button."""
        if subfig_title:
            subfig_title[0].remove()
            subfig_title.pop()
        roi = tuple(int(v) for v in rects[0].extents)
        if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
            roi = None
        _change_fig_title(f'Selected ROI: {roi}')
        # Closing the figure lets plt.show() below return
        plt.close()

    # Check inputs
    a = np.asarray(a)
    if a.ndim != 2:
        raise ValueError(f'Invalid image dimension ({a.ndim})')
    if preselected_roi is not None:
        if (not is_int_series(preselected_roi, ge=0, log=False)
                or len(preselected_roi) != 4):
            raise ValueError('Invalid parameter preselected_roi '
                             f'({preselected_roi})')
    if title is None:
        title = 'Click and drag to select or adjust a region of interest (ROI)'

    # Nothing to show, save or return: skip the figure entirely
    if not interactive and filename is None and not return_buf:
        return None, preselected_roi

    fig_title = []
    subfig_title = []

    title_pos = (0.5, 0.95)
    title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    error_pos = (0.5, 0.90)
    error_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    rect_props = {
        'alpha': 0.5, 'facecolor': 'tab:blue', 'edgecolor': 'blue'}

    fig, ax = plt.subplots(figsize=(11, 8.5))
    ax.imshow(a)
    ax.set_title(title_a, fontsize='xx-large')
    ax.set_xlabel(column_label, fontsize='x-large')
    ax.set_ylabel(row_label, fontsize='x-large')
    ax.set_xlim(0, a.shape[1])
    # Inverted y-limits: row 0 is drawn at the top
    ax.set_ylim(a.shape[0], 0)
    fig.subplots_adjust(bottom=0.0, top=0.85)

    # Setup the preselected range of interest if provided
    rects = [RectangleSelector(
        ax, on_rect_select, props=rect_props, useblit=True,
        interactive=interactive, drag_from_anywhere=True,
        ignore_event_outside=True)]
    if preselected_roi is not None:
        # NOTE(review): extents is presumably ordered
        # (xmin, xmax, ymin, ymax), i.e. columns first - confirm
        # against the matplotlib RectangleSelector documentation
        rects[0].extents = preselected_roi

    if not interactive:

        if preselected_roi is not None:
            _change_fig_title(
                f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')

    else:

        _change_fig_title(title)
        if preselected_roi is not None:
            _change_subfig_title(
                f'Preselected ROI: {tuple(int(v) for v in preselected_roi)}')
        # Make room for the buttons below the image
        fig.subplots_adjust(bottom=0.2)

        # Setup "Reset" button
        reset_btn = Button(plt.axes([0.125, 0.05, 0.15, 0.075]), 'Reset')
        reset_cid = reset_btn.on_clicked(reset)

        # Setup "Confirm" button
        confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        confirm_cid = confirm_btn.on_clicked(confirm)

        # Show figure for user interaction
        plt.show()

        # Disconnect all widget callbacks when figure is closed
        reset_btn.disconnect(reset_cid)
        confirm_btn.disconnect(confirm_cid)

        # ... and remove the buttons before returning the figure
        reset_btn.ax.remove()
        confirm_btn.ax.remove()

    buf = None
    if filename is not None or return_buf:
        if fig_title:
            fig_title[0].set_in_layout(True)
            fig.tight_layout(rect=(0, 0, 1, 0.95))
        else:
            fig.tight_layout(rect=(0, 0, 1, 1))

        # Remove the handles
        # NOTE(review): these are private RectangleSelector
        # attributes; they may change between matplotlib versions
        if interactive:
            rects[0]._center_handle.set_visible(False)
            rects[0]._corner_handles.set_visible(False)
            rects[0]._edge_handles.set_visible(False)

        if filename is not None:
            fig.savefig(filename)
        if return_buf:
            buf = fig_to_iobuf(fig)
    plt.close()

    # Read the final selection back from the selector widget
    roi = tuple(int(v) for v in rects[0].extents)
    if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
        roi = None

    return buf, roi
def select_image_indices(
        a, axis, b=None, preselected_indices=None, axis_index_offset=0,
        min_range=None, min_num_indices=2, max_num_indices=2, title=None,
        title_a=None, title_b=None, row_label='row index',
        column_label='column index', interactive=True, return_buf=False):
    """Display a 2D image and have the user select a set of image
    indices in either row or column direction.

    :param a: Two-dimensional image data array for which a set of
        indices will be selected.
    :type a: array-like
    :param axis: Selection direction (0: row, 1: column).
    :type axis: int
    :param b: Secondary two-dimensional image data array, displayed
        next to `a`, that shares the selected indices. Must have the
        same number of rows as `a`.
    :type b: array-like, optional
    :param preselected_indices: Preselected image indices.
    :type preselected_indices: tuple(int), list(int), optional
    :param axis_index_offset: Offset in axis index range and
        preselected indices, defaults to `0`.
    :type axis_index_offset: int, optional
    :param min_range: Minimal range spanned by the selected
        indices.
    :type min_range: int, optional
    :param min_num_indices: Minimum number of selected indices,
        defaults to `2`.
    :type min_num_indices: int, optional
    :param max_num_indices: Maximum number of selected indices,
        defaults to `2`.
    :type max_num_indices: int, optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param title_a: Title for the image of a.
    :type title_a: str, optional
    :param title_b: Title for the image of b.
    :type title_b: str, optional
    :param row_label: Label for the y-axis of the displayed figure,
        defaults to `row index`.
    :type row_label: str, optional
    :param column_label: Label for the x-axis of the displayed figure,
        defaults to `column index`.
    :type column_label: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param return_buf: Return an in-memory object as a byte stream
        representation of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: For invalid input data or parameters. Invalid
        `preselected_indices` only raise when `interactive` is
        `False`; otherwise they are ignored with a warning.
    :return: The result of `fig_to_iobuf` for the figure if
        `return_buf` is `True` (`None` otherwise), and the sorted
        selected indices (`None` if nothing was selected).
    :rtype: io.BytesIO or `None`, tuple(int) or `None`
    """
    # Third party modules
    from matplotlib.widgets import TextBox, Button

    index_input = None

    # NOTE: indices, lines, fig_title, error_texts and axs are defined
    # further down; the nested functions read them at call time.

    def _change_fig_title(title):
        """Replace the figure title text."""
        if fig_title:
            fig_title[0].remove()
            fig_title.pop()
        fig_title.append(plt.figtext(*title_pos, title, **title_props))

    def _change_error_text(error):
        """Replace the status/error text shown below the title."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        error_texts.append(plt.figtext(*error_pos, error, **error_props))

    def get_selected_indices(change_fnc=None):
        """Get the selected indices and, if `change_fnc` is given,
        pass it a text summarizing the current selection."""
        selected_indices = tuple(sorted(indices))
        if change_fnc is not None:
            num_indices = len(indices)
            if len(selected_indices) > 1:
                text = f'Selected {row_column} indices: {selected_indices}'
            elif selected_indices:
                text = f'Selected {row_column} index: {selected_indices[0]}'
            else:
                text = f'Selected {row_column} indices: None'
            if min_num_indices is not None and num_indices < min_num_indices:
                # The number still missing is relative to the minimum
                num_missing = min_num_indices-num_indices
                if min_num_indices == max_num_indices:
                    text += f', select another {num_missing}'
                else:
                    text += f', select at least {num_missing} more'
            change_fnc(text)
        return selected_indices

    def add_index(index):
        """Add an index and draw its marker line on every image."""
        if index in indices:
            raise ValueError(f'Ignoring duplicate of selected {row_column}s')
        if max_num_indices is not None and len(indices) >= max_num_indices:
            raise ValueError(
                f'Exceeding maximum number of selected {row_column}s, click '
                'either "Reset" or "Confirm"')
        if (indices and min_range is not None
                and abs(max(index, *indices) - min(index, *indices))
                < min_range):
            raise ValueError(
                f'Selected {row_column} range is smaller than required '
                f'minimal range of {min_range}: ignoring last selection')
        indices.append(index)
        if not axis:
            for ax in axs:
                lines.append(ax.axhline(indices[-1], c='r', lw=2))
        else:
            for ax in axs:
                lines.append(ax.axvline(indices[-1], c='r', lw=2))

    def select_index(expression):
        """Callback function for the "Select row/column index"
        TextBox.
        """
        # Clearing the TextBox below re-triggers this callback with an
        # empty string
        if not expression:
            return
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        try:
            index = int(expression)
            if (index < axis_index_offset
                    or index > axis_index_offset+a.shape[axis]):
                raise ValueError
        except ValueError:
            _change_error_text(
                f'Invalid {row_column} index ({expression}), enter an integer '
                f'between {axis_index_offset} and '
                f'{axis_index_offset+a.shape[axis]-1}')
        else:
            try:
                add_index(index)
                get_selected_indices(_change_error_text)
            except ValueError as exc:
                _change_error_text(str(exc))
            index_input.set_val('')
        for ax in axs:
            ax.get_figure().canvas.draw()

    def reset(event):
        """Callback function for the "Reset" button."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        for line in reversed(lines):
            line.remove()
        indices.clear()
        lines.clear()
        get_selected_indices(_change_error_text)
        for ax in axs:
            ax.get_figure().canvas.draw()

    def confirm(event):
        """Callback function for the "Confirm" button."""
        if min_num_indices is not None and len(indices) < min_num_indices:
            _change_error_text(
                f'Select at least {min_num_indices} unique {row_column}s')
            for ax in axs:
                ax.get_figure().canvas.draw()
        else:
            # Remove error texts and add selected indices if set
            if error_texts:
                error_texts[0].remove()
                error_texts.pop()
            get_selected_indices(_change_fig_title)
            # Closing the figure lets plt.show() below return
            plt.close()

    # Check inputs
    a = np.asarray(a)
    if a.ndim != 2:
        raise ValueError(f'Invalid image dimension ({a.ndim})')
    if axis < 0 or axis >= a.ndim:
        raise ValueError(f'Invalid parameter axis ({axis})')
    if not axis:
        row_column = 'row'
    else:
        row_column = 'column'
    if not is_int(axis_index_offset, ge=0, log=False):
        raise ValueError(
            f'Invalid parameter axis_index_offset ({axis_index_offset})')
    if preselected_indices is not None:
        if not is_int_series(
                preselected_indices, ge=axis_index_offset,
                le=axis_index_offset+a.shape[axis], log=False):
            if interactive:
                logger.warning(
                    'Invalid parameter preselected_indices '
                    f'({preselected_indices}), ignoring preselected_indices')
                preselected_indices = None
            else:
                raise ValueError('Invalid parameter preselected_indices '
                                 f'({preselected_indices})')
    if min_range is not None and not 2 <= min_range <= a.shape[axis]:
        raise ValueError(f'Invalid parameter min_range ({min_range})')
    if title is None:
        title = f'Select or adjust image {row_column} indices'
    if b is not None:
        b = np.asarray(b)
        if b.ndim != 2:
            raise ValueError(f'Invalid image dimension ({b.ndim})')
        if a.shape[0] != b.shape[0]:
            raise ValueError(f'Inconsistent image shapes({a.shape} vs '
                             f'{b.shape})')

    indices = []
    lines = []
    fig_title = []
    error_texts = []

    title_pos = (0.5, 0.95)
    title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    error_pos = (0.5, 0.90)
    error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}

    if b is None:
        fig, axs = plt.subplots(figsize=(11, 8.5))
        axs = [axs]
    else:
        # Place the two images side by side when they are tall, on
        # top of each other when they are wide
        if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
            fig, axs = plt.subplots(1, 2, figsize=(11, 8.5))
        else:
            fig, axs = plt.subplots(2, 1, figsize=(11, 8.5))
    extent = (0, a.shape[1], axis_index_offset+a.shape[0], axis_index_offset)
    axs[0].imshow(a, extent=extent)
    axs[0].set_title(title_a, fontsize='xx-large')
    if b is not None:
        axs[1].imshow(b, extent=extent)
        axs[1].set_title(title_b, fontsize='xx-large')
        if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
            axs[0].set_xlabel(column_label, fontsize='x-large')
            axs[0].set_ylabel(row_label, fontsize='x-large')
            axs[1].set_xlabel(column_label, fontsize='x-large')
        else:
            axs[0].set_ylabel(row_label, fontsize='x-large')
            axs[1].set_xlabel(column_label, fontsize='x-large')
            axs[1].set_ylabel(row_label, fontsize='x-large')
    for ax in axs:
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
    fig.subplots_adjust(bottom=0.0, top=0.85)

    # Setup the preselected indices if provided
    if preselected_indices is not None:
        preselected_indices = sorted(preselected_indices)
        for index in preselected_indices:
            add_index(index)

    if not interactive:

        get_selected_indices(_change_fig_title)

    else:

        _change_fig_title(title)
        get_selected_indices(_change_error_text)
        # Make room for the widgets below the image(s)
        fig.subplots_adjust(bottom=0.2)

        # Setup TextBox
        index_input = TextBox(
            plt.axes([0.25, 0.05, 0.15, 0.075]),
            f'Select {row_column} index ')
        indices_cid = index_input.on_submit(select_index)

        # Setup "Reset" button
        reset_btn = Button(plt.axes([0.5, 0.05, 0.15, 0.075]), 'Reset')
        reset_cid = reset_btn.on_clicked(reset)

        # Setup "Confirm" button
        confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        confirm_cid = confirm_btn.on_clicked(confirm)

        plt.show()

        # Disconnect all widget callbacks when figure is closed
        index_input.disconnect(indices_cid)
        reset_btn.disconnect(reset_cid)
        confirm_btn.disconnect(confirm_cid)

        # ... and remove the buttons before returning the figure
        index_input.ax.remove()
        reset_btn.ax.remove()
        confirm_btn.ax.remove()

    fig_title[0].set_in_layout(True)
    fig.tight_layout(rect=(0, 0, 1, 0.95))

    if return_buf:
        buf = fig_to_iobuf(fig)
    else:
        buf = None
    plt.close()
    if indices:
        return buf, tuple(sorted(indices))
    return buf, None
def quick_imshow(
        a, title=None, row_label='row index', column_label='column index',
        path=None, name=None, show_fig=True, save_fig=False,
        return_fig=False, block=None, extent=None, show_grid=False,
        grid_color='w', grid_linewidth=1, colorbar=False, **kwargs):
    """Display and or save a 2D `Matplotlib
    <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html>`__
    image and or return an in-memory object as a byte stream
    representation.

    :param a: Input array.
    :type a: array-like
    :param title: Graph title, defaults to `quick imshow`.
    :type title: str, optional
    :param row_label: Row label title.
    :type row_label: str, optional
    :param column_label: Column label title.
    :type column_label: str, optional
    :param path: File path to save image to (ignored if `save_fig` is
        `False`).
    :type path: str, optional
    :param name: File name of image (ignored if `save_fig` is
        `False`), defaults to the title with whitespace replaced by
        underscores. A `.png` extension is appended unless the name
        already ends in a file type supported by the Matplotlib
        backend.
    :type name: str, optional
    :param show_fig: Display image, defaults to `True`.
    :type show_fig: bool, optional
    :param save_fig: Save image to file, defaults to `False`.
    :type save_fig: bool, optional
    :param return_fig: Return an in-memory object as a byte stream
        representation of the Matplotlib image, defaults to `False`.
    :type return_fig: bool, optional
    :param block: Wait for the image to be closed before returning.
    :type block: bool, optional
    :param extent: Bounding box in data coordinates that the image
        will fill, defaults to the image shape with row 0 at the top.
    :type extent: floats (left, right, bottom, top), optional
    :param show_grid: Show grid lines, defaults to `False`.
    :type show_grid: bool, optional
    :param grid_color: Grid color, defaults to `"w"` or white.
    :type grid_color: str, optional
    :param grid_linewidth: Grid line width, defaults to `1`.
    :type grid_linewidth: int, optional
    :param colorbar: Include a colorbar, defaults to `False`.
    :type colorbar: bool, optional
    :param kwargs: Any additional keyword parameters to pass on to
        `matplotlib.pyplot.imshow
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html>`__.
        A `cmap` entry is dropped (with a warning) for RGB(A) images
        that are not grayscale with uniform alpha.
    :type kwargs: dict, optional
    :raises ValueError: For invalid input parameters.
    :return: The result of `fig_to_iobuf` for the figure if
        `return_fig` is set, `None` otherwise.
    :rtype: io.BytesIO or `None`
    """
    if title is not None and not isinstance(title, str):
        raise ValueError(f'Invalid parameter title ({title})')
    if path is not None and not isinstance(path, str):
        raise ValueError(f'Invalid parameter path ({path})')
    if not isinstance(show_fig, bool):
        raise ValueError(f'Invalid parameter show_fig ({show_fig})')
    if not isinstance(save_fig, bool):
        raise ValueError(f'Invalid parameter save_fig ({save_fig})')
    if not isinstance(return_fig, bool):
        raise ValueError(f'Invalid parameter return_fig ({return_fig})')
    if block is not None and not isinstance(block, bool):
        raise ValueError(f'Invalid parameter block ({block})')
    if not title:
        title = 'quick imshow'

    # Accept any array-like input (asanyarray keeps masked arrays)
    a = np.asanyarray(a)

    if 'cmap' in kwargs and a.ndim == 3 and a.shape[2] in (3, 4):
        # A colormap only applies to single channel data: use it if
        # the RGB(A) image is really grayscale (R == G == B for every
        # pixel) with a uniform alpha channel
        use_cmap = True
        if a.shape[2] == 4 and a[:, :, -1].min() != a[:, :, -1].max():
            use_cmap = False
        if np.any((a[:, :, 0] != a[:, :, 1]) | (a[:, :, 0] != a[:, :, 2])):
            use_cmap = False
        if use_cmap:
            a = a[:, :, 0]
        else:
            logger.warning('Image incompatible with cmap option, ignore cmap')
            kwargs.pop('cmap')
    if extent is None:
        extent = (0, a.shape[1], a.shape[0], 0)

    plt.ioff()
    fig, ax = plt.subplots(figsize=(11, 8.5))
    im = ax.imshow(a, extent=extent, **kwargs)
    ax.set_title(title, fontsize='xx-large')
    ax.set_xlabel(column_label, fontsize='x-large')
    ax.set_ylabel(row_label, fontsize='x-large')
    if colorbar:
        fig.colorbar(im, ax=ax)
    if show_grid:
        ax.grid(color=grid_color, linewidth=grid_linewidth)
    if show_fig:
        plt.show(block=block)
    if save_fig:
        if name is None:
            stem = re.sub(r'\s+', '_', title)
        else:
            stem = name
        if path is None:
            path = stem
        else:
            path = f'{path}/{stem}'
        # The supported file types are keyed without the leading dot
        file_type = os.path.splitext(path)[1].lstrip('.').lower()
        if file_type not in fig.canvas.get_supported_filetypes():
            path += '.png'
        # Save this figure explicitly rather than the "current" one,
        # which may have changed once the window was closed
        fig.savefig(path)
    if return_fig:
        buf = fig_to_iobuf(fig)
    else:
        buf = None
    plt.close(fig)
    return buf
def quick_plot(
        *args, xerr=None, yerr=None, vlines=None, title=None, xlim=None,
        ylim=None, xlabel=None, ylabel=None, legend=None, path=None,
        name=None, show_grid=False, save_fig=False, save_only=False,
        block=False, **kwargs):
    """Display a 2D `Matplotlib
    <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html>`__
    line plot.

    Invalid optional decoration parameters (`title`, `xlim`, `ylim`,
    `xlabel`, `ylabel`, `legend`) are reported through `illegal_value`
    and then ignored; invalid `path`, `show_grid`, `save_fig`,
    `save_only` or `block` values are reported and the function
    returns without plotting.

    :param args: Tuple or tuple of tuples of input x-coordinates
        (optional), y-coordinates, and optional line formatting
        parameters.
    :type args: [x], y, [fmt] [, [x2], y2, [fmt2], ...]
    :param xerr: Errors in x-coordinates (ignored for multiple
        curves).
    :type xerr: array-like or float, optional
    :param yerr: Errors in y-coordinates (ignored for multiple
        curves).
    :type yerr: array-like or float, optional
    :param vlines: List of vertical lines to add to plot.
    :type vlines: float or list[float], optional
    :param title: Graph title, defaults to `quick plot`.
    :type title: str, optional
    :param xlim: x-axis view limits.
    :type xlim: list[float, float], optional
    :param ylim: y-axis view limits.
    :type ylim: list[float, float], optional
    :param xlabel: x-axis label.
    :type xlabel: str, optional
    :param ylabel: y-axis label.
    :type ylabel: str, optional
    :param legend: Legend labels, passed on to
        `matplotlib.pyplot.legend`.
    :type legend: tuple or list, optional
    :param path: File path to save image to (ignored if neither
        `save_fig` nor `save_only` is set).
    :type path: str, optional
    :param name: File name of image (ignored if neither `save_fig`
        nor `save_only` is set), defaults to the title with
        whitespace replaced by underscores plus `.png`.
    :type name: str, optional
    :param show_grid: Show grid lines, defaults to `False`.
    :type show_grid: bool, optional
    :param save_fig: Save image to file, defaults to `False`.
    :type save_fig: bool, optional
    :param save_only: Do not display the figure, only save it to
        file, defaults to `False`.
    :type save_only: bool, optional
    :param block: Wait for the image to be closed before returning.
    :type block: bool, optional
    :param kwargs: Any additional keyword parameters to pass on to
        `matplotlib.pyplot.plot
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html>`__
    :type kwargs: dict, optional
    """
    # FIX: Update with return_buf

    def _is_pair(value):
        """Return `True` if value is a sized object of length 2."""
        try:
            return len(value) == 2
        except TypeError:
            return False

    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_plot')
        title = None
    if xlim is not None and not _is_pair(xlim):
        illegal_value(xlim, 'xlim', 'quick_plot')
        xlim = None
    if ylim is not None and not _is_pair(ylim):
        illegal_value(ylim, 'ylim', 'quick_plot')
        ylim = None
    if xlabel is not None and not isinstance(xlabel, str):
        illegal_value(xlabel, 'xlabel', 'quick_plot')
        xlabel = None
    if ylabel is not None and not isinstance(ylabel, str):
        illegal_value(ylabel, 'ylabel', 'quick_plot')
        ylabel = None
    if legend is not None and not isinstance(legend, (tuple, list)):
        illegal_value(legend, 'legend', 'quick_plot')
        legend = None
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_plot')
        return
    if not isinstance(show_grid, bool):
        illegal_value(show_grid, 'show_grid', 'quick_plot')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_plot')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_plot')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_plot')
        return
    if title is None:
        title = 'quick plot'

    # Construct the output file path
    if name is None:
        ttitle = re.sub(r'\s+', '_', title)
        if path is None:
            path = f'{ttitle}.png'
        else:
            path = f'{path}/{ttitle}.png'
    else:
        if path is None:
            path = name
        else:
            path = f'{path}/{name}'

    args = unwrap_tuple(args)
    if depth_tuple(args) > 1 and (xerr is not None or yerr is not None):
        logger.warning('Error bars ignored for multiple curves')
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    # The title doubles as the figure identifier
    plt.figure(title)
    if depth_tuple(args) > 1:
        for y in args:
            plt.plot(*y, **kwargs)
    else:
        if xerr is None and yerr is None:
            plt.plot(*args, **kwargs)
        else:
            plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
    if vlines is not None:
        if isinstance(vlines, Num):
            vlines = [vlines]
        for v in vlines:
            plt.axvline(v, color='r', linestyle='--', **kwargs)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if show_grid:
        ax = plt.gca()
        ax.grid(color='k')  # , linewidth=1)
    if legend is not None:
        plt.legend(legend)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        plt.show(block=block)
        # NOTE(review): the figure is closed right after show() even
        # for block=False - confirm that this is intended
        plt.close()
def nxcopy(
        nxobject, exclude_nxpaths=None, nxpath_prefix=None,
        nxpathabs_prefix=None, nxpath_copy_abspath=None,
        nxgroup_to_nxdata=False):
    """Function that returns a copy of a NeXus style `NXobject
    <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
    optionally excluding certain child items.

    :param nxobject: Input nexus object to "copy".
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: List of relative paths to child nexus
        objects that should be excluded from the returned "copy". A
        leading `/` is ignored. The list passed in is not modified.
    :type exclude_nxpaths: str or list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only.
    :type nxpath_prefix: str, optional
    :param nxpathabs_prefix: For use in recursive calls from inside
        this function only.
    :type nxpathabs_prefix: str, optional
    :param nxpath_copy_abspath: For use in recursive calls from inside
        this function only.
    :type nxpath_copy_abspath: str, optional
    :param nxgroup_to_nxdata: Create the top level copy of a group as
        an NXdata object instead of an object of the input's own
        class (not passed on to recursive calls), defaults to `False`.
    :type nxgroup_to_nxdata: bool, optional
    :return: Copy of the input `nxobject` with some children
        optionally excluded.
    :rtype: nexusformat.nexus.NXobject
    """
    # Third party modules
    from nexusformat.nexus import (
        NXdata,
        NXentry,
        NXfield,
        NXgroup,
        NXlink,
        NXlinkgroup,
        NXroot,
    )

    if isinstance(nxobject, NXlinkgroup):
        # Top level nxobject is a linked group
        # Create a group with the same name as the top level's target
        nxobject_copy = nxobject[nxobject.nxtarget].__class__(
            name=nxobject.nxname)
    elif isinstance(nxobject, (NXlink, NXfield)):
        # Top level nxobject is a (linked) field: return a copy
        # NOTE(review): attrs is taken from the input object itself,
        # so the pop presumably also removes 'target' from nxobject -
        # confirm that this is intended
        attrs = nxobject.attrs
        attrs.pop('target', None)
        nxobject_copy = NXfield(
            value=nxobject.nxdata, name=nxobject.nxname, attrs=attrs)
        return nxobject_copy
    else:
        # Create a group with the same type/name as the nxobject
        if nxgroup_to_nxdata and isinstance(nxobject, NXgroup):
            nxobject_copy = NXdata(name=nxobject.nxname)
        else:
            nxobject_copy = nxobject.__class__(name=nxobject.nxname)

    # Copy attributes
    if isinstance(nxobject, NXroot):
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.default
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    # Setup paths
    if exclude_nxpaths is None:
        exclude_nxpaths = []
    elif isinstance(exclude_nxpaths, str):
        exclude_nxpaths = [exclude_nxpaths]
    # Strip one leading '/' into a new list, leaving the caller's
    # list untouched
    exclude_nxpaths = [
        p[1:] if p.startswith('/') else p for p in exclude_nxpaths]
    if nxpath_prefix is None:
        nxpath_prefix = ''
    if nxpathabs_prefix is None:
        if isinstance(nxobject, NXentry):
            nxpathabs_prefix = nxobject.nxpath
        else:
            nxpathabs_prefix = nxobject.nxpath.removesuffix(nxobject.nxname)
    if nxpath_copy_abspath is None:
        nxpath_copy_abspath = ''

    # Loop over all nxobject's children
    for k, v in nxobject.items():
        nxpath = os.path.join(nxpath_prefix, k)
        nxpathabs = os.path.join(nxpathabs_prefix, nxpath)
        if nxpath in exclude_nxpaths:
            # Drop a default that points at the excluded child
            if 'default' in nxobject_copy.attrs and nxobject_copy.default == k:
                nxobject_copy.attrs.pop('default')
            continue
        if isinstance(v, NXlinkgroup):
            # Keep the link only if its target is not excluded,
            # otherwise copy the linked group's content
            if nxpathabs == v.nxpath and not any(
                    v.nxtarget.startswith(os.path.join(nxpathabs_prefix, p))
                    for p in exclude_nxpaths):
                nxobject_copy[k] = NXlink(v.nxtarget)
            else:
                nxobject_copy[k] = nxcopy(
                    v, exclude_nxpaths=exclude_nxpaths,
                    nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
                    nxpath_copy_abspath=os.path.join(nxpath_copy_abspath, k))
        elif isinstance(v, NXlink):
            # Keep the link only if its target is not excluded,
            # otherwise copy the linked data and attributes
            if nxpathabs == v.nxpath and not any(
                    v.nxtarget.startswith(os.path.join(nxpathabs_prefix, p))
                    for p in exclude_nxpaths):
                nxobject_copy[k] = v
            else:
                nxobject_copy[k] = v.nxdata
                for kk, vv in v.attrs.items():
                    nxobject_copy[k].attrs[kk] = vv
                nxobject_copy[k].attrs.pop('target', None)
        elif isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(
                v, exclude_nxpaths=exclude_nxpaths,
                nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
                nxpath_copy_abspath=os.path.join(nxpath_copy_abspath, k))
        else:
            nxobject_copy[k] = v.nxdata
            for kk, vv in v.attrs.items():
                nxobject_copy[k].attrs[kk] = vv
            if nxpathabs != os.path.join(nxpath_copy_abspath, k):
                nxobject_copy[k].attrs.pop('target', None)

    return nxobject_copy
def get_default_path(nxobject):
    """Return the relative path to the default plottable NeXus style
    `NXdata
    <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object within the parent `NXobject
    <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
    provided.

    The path is built by following the chain of `default` attributes.
    For an NXroot without a `default` attribute that holds exactly one
    entry, the chain starts at that entry and the returned path starts
    with the entry's name; otherwise it starts at `nxobject` and the
    returned path starts with `/` (or is empty if `nxobject` has no
    `default` attribute).

    :param nxobject: Parent NXobject containing plottable NXdata.
    :type nxobject: nexusformat.nexus.NXobject
    :returns: Path to default NXdata group.
    :rtype: str
    """
    # Third party modules
    from nexusformat.nexus import NXroot

    if (isinstance(nxobject, NXroot) and 'default' not in nxobject.attrs
            and len(nxobject.entries) == 1):
        # Use the only entry itself (not its key) as starting point
        current = next(iter(nxobject.entries.values()))
        path = current.nxname
    else:
        path = ''
        current = nxobject
    while current.attrs.get('default') is not None:
        path += '/' + current.attrs['default']
        current = current[current.attrs['default']]

    return path
def dictionary_update(target, source, merge_key_paths=None, sort=False):
    """Recursively updates a target dictionary with values from a
    source dictionary. Source values supersede target values for
    identical keys unless both values are lists of dictionaries in
    which case they are merged according to the merge_key_paths
    parameter.

    The target dictionary is updated in place and returned.

    :param target: Target dictionary.
    :type target: dict
    :param source: Source dictionary.
    :type source: dict
    :param merge_key_paths: Key path used to merge lists of
        dictionaries, only used if items in the target and source
        dictionary trees are lists of dictionaries. Either a string
        of the form `"<list key>/<merge key>"` or a dictionary with a
        `key_path` item of that form and an optional `type` item to
        enforce on the merge key values. Without it (or for any other
        form of the path), lists are concatenated. Not yet
        implemented in combination with nested dictionaries.
    :type merge_key_paths: str or dict, optional
    :param sort: Sort merged dictionary lists on the merge key.
    :type sort: bool, optional
    :raises ValueError: If `target` or `source` is not a dictionary.
    :raises NotImplementedError: If `merge_key_paths` is given while
        nested dictionaries have to be merged, or if it has an
        unsupported type.
    :return: Updated target dictionary.
    :rtype: dict
    """
    if not isinstance(target, dict):
        raise ValueError(
            f'Invalid parameter type "target" ({type(target)})')
    if not isinstance(source, dict):
        raise ValueError(
            f'Invalid parameter type "source" ({type(source)})')

    for k, v in source.items():
        if (isinstance(v, collections.abc.Mapping)
                and isinstance(target.get(k), collections.abc.Mapping)):
            # Both values are dictionaries: merge them recursively
            if merge_key_paths is not None:
                raise NotImplementedError(
                    f'"merge_key_paths" ({type(merge_key_paths)}) '
                    'for source and target dictionaries not yet implemented')
            target[k] = dictionary_update(target.get(k, {}), v)
        elif (is_dict_series(v, log=False)
                and is_dict_series(target.get(k), log=False)):
            # Both values are lists of dictionaries: merge the lists
            if merge_key_paths is None:
                merge_key_path = None
                merge_key_type = None
            elif isinstance(merge_key_paths, str):
                merge_key_path = merge_key_paths
                merge_key_type = None
            elif isinstance(merge_key_paths, dict):
                merge_key_path = merge_key_paths.get('key_path')
                merge_key_type = merge_key_paths.get('type')
            else:
                raise NotImplementedError(
                    'Invalid/unimplemeted parameter type "merge_key_paths" '
                    f'({type(merge_key_paths)}) for source and target '
                    'lists of dictionaries')
            # Only a two-component key path selects a merge key
            merge_key = None
            if merge_key_path is not None:
                components = merge_key_path.split('/')
                if len(components) == 2:
                    merge_key = components[1]
            target[k] = list_dictionary_update(
                target.get(k), v, key=merge_key, key_type=merge_key_type,
                sort=sort)
        else:
            # Anything else: the source value replaces the target's
            target[k] = v

    return target
def list_dictionary_update(
        target, source, key=None, key_type=None, sort=False):
    """Recursively updates a target list of dictionaries with values
    from a source list of dictionaries. Each list item is updated item
    by item based on the key if given and equal to a key that is
    shared among all sets of source and target list item keys.
    Otherwise the source list followed by the target list is returned.

    :param target: Target list.
    :type target: list
    :param source: Source list.
    :type source: list
    :param key: Selected key to merge the lists of dictionaries.
    :type key: str, optional
    :param key_type: Key type to enforce: the values of `key` are
        converted with it before they are compared or sorted.
    :type key_type: type, optional
    :param sort: Sort the merged list on the key, defaults to
        `False`.
    :type sort: bool, optional
    :raises ValueError: For invalid parameters or if the key is
        present in only part of the source and target list items.
    :return: Updated list.
    :rtype: list
    """
    if not isinstance(target, list):
        raise ValueError(
            f'Invalid parameter type "target" ({type(target)})')
    if not isinstance(source, list):
        raise ValueError(
            f'Invalid parameter type "source" ({type(source)})')
    # Nothing to merge on: concatenate the lists
    if not source or key is None:
        return source + target
    if not isinstance(key, str) or '/' in key:
        raise ValueError(f'Invalid parameter "key" ({key}, {type(key)})')
    if not (key_type is None or isinstance(key_type, type)):
        raise ValueError(
            f'Invalid parameter "key_type" ({key_type}, {type(key_type)})')

    # The key must be present in either all or none of the list items
    all_any_source = all_any(source, key)
    if all_any_source < 0:
        raise ValueError(
            f'Partially shared key ({key}) while trying to merge {source} '
            f'with {target}')
    all_any_target = all_any(target, key)
    if all_any_target < 0 or all_any_source != all_any_target:
        raise ValueError(
            f'Partially shared key ({key}) while trying to merge {source} '
            f'with {target}')
    if not all_any_source and not all_any_target:
        return source + target

    merged = []
    # Work on a copy, matched source items are removed as we go
    unmatched = deepcopy(source)
    for target_dict in target:
        value = target_dict[key]
        if key_type is not None:
            value = key_type(value)
        for i, source_dict in enumerate(unmatched):
            vvalue = source_dict[key]
            if key_type is not None:
                vvalue = key_type(vvalue)
            if value == vvalue:
                merged.append(dictionary_update(
                    target_dict, source_dict, sort=sort))
                unmatched.pop(i)
                break
        else:
            # No source item with the same key value
            merged.append(target_dict)
    merged.extend(unmatched)
    if sort:
        if key_type is None:
            merged.sort(key=lambda x: x[key])
        else:
            merged.sort(key=lambda x: key_type(x[key]))
    return merged