Source code for CHAP.common.processor

#!/usr/bin/env python
#-*- coding: utf-8 -*-
"""Module for generic Processors used in multiple experiment-specific
workflows.
"""

# System modules
from copy import deepcopy
import os
from typing import Optional

# Third party modules
import numpy as np
from pydantic import (
    Field,
    PrivateAttr,
    conint,
    conlist,
    field_validator,
)

# Local modules
from CHAP.common.models.common import ImageProcessorConfig
from CHAP.common.models.map import (
    DetectorConfig,
    MapConfig,
)
from CHAP.pipeline import PipelineData
from CHAP.processor import Processor


class AsyncProcessor(Processor):
    """A Processor to process multiple sets of input data via asyncio
    module.

    :ivar mgr: The `Processor` used to process every set of input
        data.
    :vartype mgr: Processor
    """
    def __init__(self, mgr):
        super().__init__()
        self.mgr = mgr

    def process(self, data):
        """Asynchronously process the input documents with the
        `self.mgr` `Processor`.

        One coroutine is created per input document and all of them
        are awaited together. The values returned by
        `self.mgr.process` are not collected, so this method itself
        returns nothing.

        :param data: Input data.
        :type data: list[PipelineData]
        """
        # System modules
        import asyncio

        async def _process_one(processor, item):
            """Process a single document with the given `Processor`.

            :param processor: Object that will process the document.
            :type processor: Processor
            :param item: Data to process.
            :type item: object
            :return: Processed data.
            :rtype: object
            """
            return processor.process(item)

        async def _process_all(processor, items):
            """Process every document with the given `Processor`.

            :param processor: Object that will process all documents.
            :type processor: Processor
            :param items: Set of data documents to process.
            :type items: iterable
            """
            pending = []
            for item in items:
                pending.append(_process_one(processor, item))
            await asyncio.gather(*pending)

        asyncio.run(_process_all(self.mgr, data))
class BinarizeProcessor(Processor):
    """A Processor to binarize a dataset.

    :ivar nxmemory: Maximum memory usage when reading
        `NeXus <https://www.nexusformat.org>`__ files.
    :vartype nxmemory: int, optional
    """
    nxmemory: Optional[conint(gt=0)] = 100000

    def process(self, data, config=None):
        """Plot and return a binarized dataset from a dataset
        contained in `data`. The dataset must either be `array-like`
        or a NeXus style
        `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
        object with a default plottable data path or a specified path
        to a NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        or
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        object.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            :class:`~CHAP.common.models.BinarizeConfig`.
        :type config: dict, optional
        :raises ValueError: If neither a NeXus object nor an
            array-like object can be loaded from `data`.
        :return: Binarized dataset for an `array-like` input or a
            return type equal that of the input object with the
            binarized dataset added.
        :rtype: numpy.ndarray | nexusformat.nexus.NXobject
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
            NXlink,
            nxsetconfig,
        )

        # Local modules
        from CHAP.utils.general import nxcopy

        nxsetconfig(memory=self.nxmemory)

        # Load the validated processor configuration
        if config is None:
            # Local modules
            from CHAP.common.models.common import BinarizeConfig

            config = BinarizeConfig()
        else:
            config = self.get_config(
                data, config=config, schema='common.models.BinarizeConfig')

        # Load the default data. The array is kept in its own local
        # (`values`) so that the pipeline `data` argument is still
        # intact when the array-like fallback below has to use it.
        try:
            nxobject = self.get_data(data)
            if config.nxpath is None:
                dataset = nxobject.get_default()
            else:
                dataset = nxobject[config.nxpath]
            if isinstance(dataset, NXdata):
                nxsignal = dataset.nxsignal
                values = nxsignal.nxdata
            else:
                values = dataset.nxdata
            if not isinstance(values, np.ndarray):
                # Explicit raise instead of `assert` so that the
                # fallback is still taken when running with -O
                raise TypeError('Loaded data is not a numpy.ndarray')
        except Exception:
            try:
                values = np.asarray(self.get_pipelinedata_item(data))
            except Exception as exc:
                raise ValueError(
                    'Unable to load a valid input data object') from exc
            # Mark the input as array-like so that the plain binarized
            # array is returned below
            dataset = values

        if config.method == 'yen':
            min_ = values.min()
            max_ = values.max()
            values = 1 + (config.num_bin - 1) * (values - min_) / (max_ - min_)

        # Get a histogram of the data
        counts, edges = np.histogram(values, bins=config.num_bin)
        centers = edges[:-1] + 0.5 * np.diff(edges)

        # Calculate the data cutoff threshold
        # pylint: disable=no-name-in-module
        if config.method == 'CHAP':
            weights = np.cumsum(counts)
            means = np.cumsum(counts * centers)
            weights = weights[0:-1] / weights[-1]
            means = means[0:-1] / means[-1]
            variances = (means-weights)**2 / (weights * (1. - weights))
            threshold = centers[np.argmax(variances)]
        elif config.method == 'otsu':
            # Third party modules
            from skimage.filters import threshold_otsu

            threshold = threshold_otsu(hist=(counts, centers))
        elif config.method == 'yen':
            # Third party modules
            from skimage.filters import threshold_yen

            threshold = threshold_yen(hist=(counts, centers))
        elif config.method == 'isodata':
            # Third party modules
            from skimage.filters import threshold_isodata

            threshold = threshold_isodata(hist=(counts, centers))
        else:
            # Third party modules
            from skimage.filters import threshold_minimum

            threshold = threshold_minimum(hist=(counts, centers))
        # pylint: enable=no-name-in-module

        # Apply the data cutoff threshold
        binarized = np.where(values < threshold, 0, 1).astype(np.ubyte)

        # Return the output for array-like or NXfield inputs
        if isinstance(dataset, np.ndarray):
            return binarized
        if isinstance(dataset, NXfield):
            attrs = dataset.attrs
            attrs.pop('target', None)
            return NXfield(
                value=binarized, name=f'{dataset.nxname}_binarized',
                attrs=attrs)

        # Otherwise create a copy of the input data, add the binarized
        # data to the copied original dataset, and remove the original
        # dataset if config.remove_original_data is set
        name = f'{nxsignal.nxname}_binarized'
        nxdefault = nxobject.get_default()
        if isinstance(nxsignal, NXlink):
            link = dataset.nxpath
            path = os.path.split(nxsignal.nxtarget)[0]
        else:
            link = nxdefault.nxpath
            path = os.path.split(nxsignal.nxpath)[0]
        exclude_nxpaths = []
        if config.remove_original_data:
            if link is not None:
                exclude_nxpaths.append(os.path.relpath(
                    f'{link}/{nxsignal.nxname}', nxobject.nxpath))
            exclude_nxpaths.append(os.path.relpath(
                f'{path}/{nxsignal.nxname}', nxobject.nxpath))
        nxobject = nxcopy(nxobject, exclude_nxpaths=exclude_nxpaths)
        attrs = nxsignal.attrs
        attrs.pop('target', None)
        nxobject[f'{path}/{name}'] = NXfield(
            value=binarized, name=name, attrs=attrs)
        nxobject[path].attrs['signal'] = name
        if link is not None:
            nxobject[f'{link}/{name}'] = NXlink(f'{path}/{name}')
            nxobject[link].attrs['signal'] = name

        return nxobject
class ConstructBaseline(Processor):
    """A Processor to construct a baseline for a dataset."""

    def process(
            self, data, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
            save_figures=False):
        """Construct and return the baseline for a dataset.

        :param data: Input data.
        :type data: numpy.ndarray | list[PipelineData]
        :param x: Independent dimension (only used when a figure is
            created).
        :type x: array-like, optional
        :param mask: Mask to apply to the spectrum before baseline
            construction.
        :type mask: array-like, optional
        :param tol: Convergence tolerence, defaults to `1.e-6`.
        :type tol: float, optional
        :param lam: &lambda (smoothness) parameter (the balance
            between the residual of the data and the baseline and the
            smoothness of the baseline). The suggested range is
            between 100 and 10^8, defaults to `10^6`.
        :type lam: float, optional
        :param max_iter: Maximum number of iterations,
            defaults to `20`.
        :type max_iter: int, optional
        :param save_figures: Save .pngs of plots for checking inputs &
            outputs of this Processor, defaults to `False`.
        :type save_figures: bool, optional
        :raises ValueError: If no valid data can be obtained from
            `data`.
        :return: Smoothed baseline, the configuration and a byte
            stream representation of the Matplotlib figure if
            `save_figures` is `True` (`None` otherwise).
        :rtype: numpy.ndarray, dict, io.BytesIO | None
        """
        try:
            y = np.asarray(self.get_pipelinedata_item(data))
        except Exception as exc:
            raise ValueError(f'No valid data found in {data}') from exc

        return self.construct_baseline(
            y, x=x, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
            return_buf=save_figures)

    @staticmethod
    def construct_baseline(
            y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
            title=None, xlabel=None, ylabel=None, interactive=False,
            return_buf=False):
        """Construct and return the baseline for a dataset.

        :param y: Input data.
        :type y: numpy.ndarray
        :param x: Independent dimension (only used when interactive is
            `True` or when return_buf is `True`).
        :type x: array-like, optional
        :param mask: Mask to apply to the spectrum before baseline
            construction.
        :type mask: array-like, optional
        :param tol: Convergence tolerence, defaults to `1.e-6`.
        :type tol: float, optional
        :param lam: &lambda (smoothness) parameter (the balance
            between the residual of the data and the baseline and the
            smoothness of the baseline). The suggested range is
            between 100 and 10^8, defaults to `10^6`.
        :type lam: float, optional
        :param max_iter: Maximum number of iterations,
            defaults to `20`.
        :type max_iter: int, optional
        :param title: Title for the displayed figure.
        :type title: str, optional
        :param xlabel: Label for the x-axis of the displayed figure.
        :type xlabel: str, optional
        :param ylabel: Label for the y-axis of the displayed figure.
        :type ylabel: str, optional
        :param interactive: Allows for user interactions,
            defaults to `False`.
        :type interactive: bool, optional
        :param return_buf: Return an in-memory object as a byte stream
            represention of the Matplotlib figure, defaults to `False`.
        :type return_buf: bool, optional
        :return: Smoothed baseline and the configuration and a byte
            stream represention of the Matplotlib figure if return_buf
            is `True` (`None` otherwise). When running interactively,
            the baseline is the one last displayed to the user.
        :rtype: numpy.ndarray, dict, io.BytesIO | None
        """
        # Third party modules
        from matplotlib.widgets import TextBox, Button
        import matplotlib.pyplot as plt

        # Local modules
        from CHAP.utils.general import (
            baseline_arPLS,
            fig_to_iobuf,
        )

        def _change_fig_subtitle(maxed_out=False, subtitle=None):
            """Change the figure's subtitle."""
            if fig_subtitles:
                fig_subtitles[0].remove()
                fig_subtitles.pop()
            if subtitle is None:
                subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
                if maxed_out:
                    subtitle += f'# iter = {num_iters[-1]} (maxed out) '
                else:
                    subtitle += f'# iter = {num_iters[-1]} '
                subtitle += f'error = {errors[-1]:.2e}'
            fig_subtitles.append(
                plt.figtext(*subtitle_pos, subtitle, **subtitle_props))

        def select_lambda(expression):
            """Callback function for the "Select lambda" TextBox."""
            # Clearing the box below re-triggers this callback with an
            # empty string, which must be ignored
            if not expression:
                return
            try:
                log_lam = float(expression)
                if log_lam < 0:
                    raise ValueError
            except ValueError:
                _change_fig_subtitle(
                    subtitle='Invalid lambda, enter a positive number')
            else:
                lambdas[-1] = 10**log_lam
                new_baseline, _, new_w, num_iter, error = get_baseline(
                    y, mask=mask, tol=tol, lam=lambdas[-1],
                    max_iter=max_iter)
                # Keep the full state in sync with what is displayed so
                # that "Continue smoothing" restarts from these weights
                # and the returned baseline matches the returned config
                baselines[-1] = new_baseline
                weights[-1] = new_w
                num_iters[-1] = num_iter
                errors[-1] = error
                _change_fig_subtitle(maxed_out=num_iter >= max_iter)
                baseline_handle.set_ydata(new_baseline)
            lambda_box.set_val('')
            plt.draw()

        def continue_iter(event):
            """Callback function for the "Continue" button."""
            new_baseline, _, new_w, n_iter, error = get_baseline(
                y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
                max_iter=max_iter)
            num_iters[-1] += n_iter
            errors[-1] = error
            baselines[-1] = new_baseline
            weights[-1] = new_w
            _change_fig_subtitle(maxed_out=n_iter >= max_iter)
            baseline_handle.set_ydata(new_baseline)
            plt.draw()

        def confirm(event):
            """Callback function for the "Confirm" button."""
            plt.close(fig)

        def get_baseline(
                y, *, mask=None, w=None, tol=1.e-6, lam=1.e6, max_iter=20):
            """Get a baseline.

            :param y: Input data.
            :type y: numpy.ndarray
            :param mask: Mask to apply to the spectrum before baseline
                construction.
            :type mask: array-like, optional
            :param w: Weights (allows restart for additional
                iterations).
            :type w: numpy.array, optional
            :param tol: Convergence tolerence, defaults to `1.e-6`.
            :type tol: float, optional
            :param lam: &lambda (smoothness) parameter (the balance
                between the residual of the data and the baseline and
                the smoothness of the baseline). The suggested range
                is between 100 and 10^8, defaults to `10^6`.
            :type lam: float, optional
            :param max_iter: Maximum number of iterations,
                defaults to `20`.
            :type max_iter: int, optional
            :return: Full output of `baseline_arPLS`, unpacked by the
                callers as (baseline, <unused>, weights, number of
                iterations, error).
            """
            return baseline_arPLS(
                y, mask=mask, w=w, tol=tol, lam=lam, max_iter=max_iter,
                full_output=True)

        baseline, _, w, num_iter, error = get_baseline(
            y, mask=mask, tol=tol, lam=lam, max_iter=max_iter)

        if not interactive and not return_buf:
            config = {
                'tol': tol, 'lambda': lam, 'max_iter': max_iter,
                'num_iter': num_iter, 'error': error, 'mask': mask}
            return baseline, config, None

        # Single-element lists act as mutable cells shared with the
        # widget callbacks above
        baselines = [baseline]
        lambdas = [lam]
        weights = [w]
        num_iters = [num_iter]
        errors = [error]
        fig_subtitles = []

        # Check inputs
        if x is None:
            x = np.arange(y.size)
        else:
            x = np.asarray(x)

        # Setup the Matplotlib figure
        title_pos = (0.5, 0.95)
        title_props = {'fontsize': 'xx-large',
                       'horizontalalignment': 'center',
                       'verticalalignment': 'bottom'}
        subtitle_pos = (0.5, 0.90)
        subtitle_props = {'fontsize': 'x-large',
                          'horizontalalignment': 'center',
                          'verticalalignment': 'bottom'}
        fig, ax = plt.subplots(figsize=(11, 8.5))
        if mask is None:
            ax.plot(x, y, label='input data')
        else:
            selected = np.asarray(mask).astype(bool)
            ax.plot(x[selected], y[selected], label='input data')
        baseline_handle = ax.plot(x, baseline, label='baseline')[0]
        ax.legend()
        ax.set_xlabel(xlabel, fontsize='x-large')
        ax.set_ylabel(ylabel, fontsize='x-large')
        ax.set_xlim(x[0], x[-1])
        if title is None:
            fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
        else:
            fig_title = plt.figtext(*title_pos, title, **title_props)
        _change_fig_subtitle(maxed_out=num_iter >= max_iter)
        fig.subplots_adjust(bottom=0.0, top=0.85)

        lambda_box = None
        if interactive:

            fig.subplots_adjust(bottom=0.2)

            # Setup TextBox
            lambda_box = TextBox(
                plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
            lambda_cid = lambda_box.on_submit(select_lambda)

            # Setup "Continue" button
            continue_btn = Button(
                plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
            continue_cid = continue_btn.on_clicked(continue_iter)

            # Setup "Confirm" button
            confirm_btn = Button(
                plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
            confirm_cid = confirm_btn.on_clicked(confirm)

            # Show figure for user interaction
            plt.show()

            # Disconnect all widget callbacks when figure is closed
            lambda_box.disconnect(lambda_cid)
            continue_btn.disconnect(continue_cid)
            confirm_btn.disconnect(confirm_cid)

            # ... and remove the buttons before returning the figure
            lambda_box.ax.remove()
            continue_btn.ax.remove()
            confirm_btn.ax.remove()

        if return_buf:
            fig_title.set_in_layout(True)
            fig_subtitles[-1].set_in_layout(True)
            fig.tight_layout(rect=(0, 0, 1, 0.90))
            buf = fig_to_iobuf(fig)
        else:
            buf = None
        # Close this figure explicitly rather than whichever figure
        # happens to be current
        plt.close(fig)

        config = {
            'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
            'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
        return baselines[-1], config, buf
class ConvertStructuredProcessor(Processor):
    """Processor for converting map data between structured and
    unstructured formats.
    """
    def process(self, data):
        """Return the converted map data.

        The item obtained from the pipeline data is handed to
        `CHAP.utils.converters.convert_structured_unstructured` and
        its result is returned unchanged.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Converted data.
        :rtype: nexusformat.nexus.NXdata
        """
        # Local modules
        from CHAP.utils.converters import convert_structured_unstructured

        return convert_structured_unstructured(
            self.get_pipelinedata_item(data))
class ExpressionProcessor(Processor):
    """Processor to perform an arbitrary expression on input data."""

    def process(self, data, expression, symtable=None, nxprocess=False,
                nxfieldtable=None, nxdata_name='data',
                nxfield_name='result'):
        """Return result of plugging input data into the given
        mathematical expression.

        :param data: Input data.
        :type data: list[PipelineData]
        :param expression: Mathematical expression. May use the
            built-in function `round` and / or numpy functions with
            `np.<function_name>` or `numpy.<function_name>.`
        :type expression: str
        :param symtable: Values to use for names in `expression` that
            should not be obtained from input data. The supplied
            dictionary is copied, not modified. Defaults to `None`.
        :type symtable: dict[str, (float, int)], optional.
        :param nxprocess: Flag to indicate the results should be
            returned as a NeXus style
            `NXprocess <https://manual.nexusformat.org/classes/base_classes/NXprocess.html#index-0>`__
            defaults to `False`.
        :type nxprocess: bool, optional
        :param nxfieldtable: Used only if `nxprocess` is
            `True`. Dictionary of additional
            `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            objects to include in the NXprocess, the result object
            right next to the expression result's NXfield. Dictionary
            keys become NXfield names in the returned
            object. Defaults to `None`.
        :type nxfieldtable: dict[str, nexusformat.nexus.NXfield],
            optional
        :param nxdata_name: Used only if `nxprocess` is `True`. Name
            for the NeXus style
            `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
            object in the returned NXprocess object that contains
            actual result data. Defaults to `'data'`.
        :type nxdata_name: str, optional
        :param nxfield_name: Used only if `nxprocess` is `True`. Name
            for the NXfield dataset that contains the evaluated
            expression results. Defaults to `'result'`.
        :type nxfield_name: str, optional
        :returns: Result of evaluating the expression.
        :rtype: object
        """
        # System modules
        from ast import parse

        # Third party modules
        from asteval import get_ast_names, Interpreter
        try:
            # zarr is optional and the array class only lives at this
            # location in some zarr versions; skip the conversion
            # below when it is not available
            from zarr.core.array import Array as ZarrArray
        except (ImportError, AttributeError):
            ZarrArray = None

        # Only warn when something unusable was actually supplied; the
        # default (None) silently becomes an empty table
        if nxfieldtable is None:
            nxfieldtable = {}
        elif not isinstance(nxfieldtable, dict):
            self.logger.warning('No usable nxfieldtable provided, using {}')
            nxfieldtable = {}

        # Work on a copy so the caller's dictionary is left untouched
        symtable = {} if symtable is None else dict(symtable)
        for name in get_ast_names(parse(expression)):
            if name in symtable:
                continue
            if name == 'round':
                symtable[name] = round
            elif name in ('np', 'numpy'):
                symtable[name] = np
            else:
                symtable[name] = self.get_data(
                    data, name=name, remove=False)

        # Read zarr arrays into memory before evaluating
        if ZarrArray is not None:
            for key, value in symtable.items():
                if not isinstance(value, ZarrArray):
                    continue
                try:
                    symtable[key] = value[()]
                except Exception as exc:
                    # Best effort: keep the zarr array as is
                    self.logger.debug(
                        f'Unable to read zarr array {key} ({exc})')

        self.logger.debug(f'Asteval symtable: {symtable}')
        aeval = Interpreter(symtable=symtable)
        new_data = aeval(expression)

        if not nxprocess:
            return new_data

        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
            NXprocess,
        )

        # NOTE(review): `nxprocess` is documented as a bool flag but is
        # also used as the NXprocess name here -- confirm whether a
        # string name is the intended input.
        return NXprocess(
            name=nxprocess,
            entries={
                nxdata_name: NXdata(
                    signal=NXfield(
                        name=nxfield_name,
                        value=new_data,
                        attrs={'expression': expression}
                    ),
                    **nxfieldtable,
                    attrs={'expression': expression}
                ),
            }
        )
class ImageProcessor(Processor):
    """A Processor to perform various visualization operations on
    images (slices) selected from a NeXus style
    `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__.

    :ivar config: Initialization parameters for an instance of
        :class:`~CHAP.common.models.ImageProcessorConfig`.
    :vartype config: dict, optional
    :ivar nxmemory: Maximum memory usage when reading
        `NeXus <https://www.nexusformat.org>`__ files.
    :vartype nxmemory: int, optional
    :ivar save_figures: Return the plottable image(s) to be written
        to file downstream in the pipeline, defaults to `True`.
    :vartype save_figures: bool, optional
    """
    # NOTE(review): ImageProcessorConfig is imported from
    # CHAP.common.models.common in this module, while the schema
    # string below points at common.models.map -- confirm.
    pipeline_fields: dict = Field(
        default = {
            'config': 'common.models.map.ImageProcessorConfig'},
        init_var=True)
    config: ImageProcessorConfig
    nxmemory: Optional[conint(gt=0)] = 100000
    save_figures: Optional[bool] = True

    _figconfig: dict = PrivateAttr(default={})

    def process(self, data):
        """Plot and/or return image slices from a NeXus style
        `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__.
        object with a default plottable data path.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises ValueError: If the default NXdata object or its signal
            cannot be loaded, if the `axis` parameter is invalid, or
            if the data is neither 2D nor 3D.
        :raises NotImplementedError: For 2D datasets and for 3D
            datasets without usable axes information (both untested).
        :return: `None` if neither saving figures nor running
            interactively. Otherwise, for save_figures = `True`, a
            dictionary with the image data and file format, or a list
            of (image buffer, name) tuples for a set of individual
            figures; for save_figures = `False`, the input default
            NeXus style
            `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
            object.
        :rtype: None | dict | list | nexusformat.nexus.NXdata
        """
        if not self.save_figures and not self.interactive:
            return None

        # Third party modules
        from nexusformat.nexus import nxsetconfig

        nxsetconfig(memory=self.nxmemory)

        # Load the default data
        try:
            nxdata = self.get_data(data).get_default()
        except Exception as exc:
            raise ValueError(
                'Unable to load the default NXdata object from the input '
                f'pipeline ({data})') from exc

        # Get the axes info and image slice(s)
        try:
            data = nxdata.nxsignal
        except Exception as exc:
            raise ValueError('Unable to find the default signal in:\n'
                             f'({nxdata.tree})') from exc
        axis = self.config.axis
        axes = nxdata.attrs.get('axes', None)
        if isinstance(axes, str):
            axes = [axes]
        ndim = nxdata.nxsignal.ndim
        if ndim == 2:
            # Raise instead of exiting the interpreter from library
            # code
            raise NotImplementedError(
                'ImageProcessor not tested yet for a 2D dataset')
        if ndim != 3:
            raise ValueError('Invalid data dimension (must be 2D or 3D)')

        if isinstance(axis, int):
            if not 0 <= axis < ndim:
                raise ValueError(f'axis index out of range ({axis} not in '
                                 f'[0, {ndim-1}])')
        elif isinstance(axis, str):
            if axes is None or axis not in axes:
                raise ValueError(
                    f'Unable to match axis = {axis} in {nxdata.tree}')
            axis = axes.index(axis)
        else:
            raise ValueError(f'Invalid parameter axis ({axis})')
        if axis:
            # Make the slicing axis the leading one
            data = np.moveaxis(data, axis, 0)
        if axes is None or not hasattr(nxdata, axes[axis]):
            raise NotImplementedError('No axes attribute not tested yet')
        if axis == 1:
            axes = [axes[1], axes[0], axes[2]]
        elif axis:
            axes = [axes[2], axes[0], axes[1]]
        axis_name = axes[0]
        if 'units' in nxdata[axis_name].attrs:
            axis_unit = f' ({nxdata[axis_name].units})'
        else:
            axis_unit = ''
        row_label = axes[2]
        row_coords = nxdata[row_label].nxdata
        column_label = axes[1]
        column_coords = nxdata[column_label].nxdata
        if 'units' in nxdata[row_label].attrs:
            row_label += f' ({nxdata[row_label].units})'
        if 'units' in nxdata[column_label].attrs:
            column_label += f' ({nxdata[column_label].units})'
        axis_coords = nxdata[axis_name].nxdata

        # Select the slice(s), a coordinate range takes precedence
        # over an index range
        if self.config.coord_range is None:
            index_range = self.config.index_range
        else:
            # Local modules
            from CHAP.utils.general import (
                index_nearest_down,
                index_nearest_up,
            )

            if self.config.index_range is not None:
                self.logger.warning('Ignoring parameter index_range')
            if isinstance(self.config.coord_range, (int, float)):
                index_range = index_nearest_up(
                    axis_coords, self.config.coord_range)
            elif len(self.config.coord_range) == 2:
                index_range = [
                    index_nearest_up(axis_coords, self.config.coord_range[0]),
                    index_nearest_down(
                        axis_coords, self.config.coord_range[1])]
            else:
                index_range = [
                    index_nearest_up(axis_coords, self.config.coord_range[0]),
                    index_nearest_down(
                        axis_coords, self.config.coord_range[1]),
                    int(max(1, self.config.coord_range[2] /
                        ((axis_coords[-1]-axis_coords[0])/data.shape[0])))]
        if index_range == -1:
            index_range = nxdata.nxsignal.shape[axis] // 2
        # Also accept NumPy integers as a single index
        if isinstance(index_range, (int, np.integer)):
            data = data[index_range]
            axis_coords = [axis_coords[index_range]]
        elif index_range is not None:
            slice_ = slice(*tuple(index_range))
            data = data[slice_]
            axis_coords = axis_coords[slice_]
        if self.config.vrange is None:
            vrange = (float(data.min()), float(data.max()))
        else:
            # Copy so that filling in missing limits does not modify
            # the processor configuration
            vrange = list(self.config.vrange)
            if vrange[0] is None:
                vrange[0] = float(data.min())
            if vrange[1] is None:
                vrange[1] = float(data.max())

        # Create the figure configuration
        self._figconfig = {
            'title': f'{nxdata.nxpath}/{nxdata.signal}',
            'axis_name': axis_name,
            'axis_unit': axis_unit,
            'axis_coords': axis_coords,
            'row_label': row_label,
            'column_label': column_label,
            'extent': (row_coords[0], row_coords[-1],
                       column_coords[-1], column_coords[0]),
            'vrange': vrange,
        }
        self.logger.debug(f'figure configuration:\n{self._figconfig}')

        if len(axis_coords) == 1:
            # Create a figure for a single image slice
            if self.config.animation:
                self.logger.warning(
                    'Ignoring animation parameter for a single image')
            if self.config.fileformat is None:
                fileformat = 'png'
            else:
                fileformat = self.config.fileformat
            fig, plt = self._create_figure(np.squeeze(data))
            if self.interactive:
                plt.show()
            if self.save_figures:
                # Local modules
                from CHAP.utils.general import fig_to_iobuf

                # Return a binary image of the figure
                buf, fileformat = fig_to_iobuf(fig, fileformat=fileformat)
            else:
                buf = None
            plt.close()
            if self.save_figures:
                return {'image_data': buf, 'fileformat': fileformat}
            return nxdata

        # Create an animation for a set of image slices
        if self.interactive or self.config.animation:
            ani = self._create_animation(data)
        else:
            ani = None

        if not self.save_figures:
            return nxdata

        if self.config.animation:
            # Return the animation object
            if (self.config.fileformat is not None
                    and self.config.fileformat != 'gif'):
                self.logger.warning(
                    'Ignoring inconsistent file extension')
            return {'image_data': ani, 'fileformat': 'gif'}

        # Local modules
        from CHAP.utils.general import fig_to_iobuf

        if self.config.fileformat in ('png', 'tif', 'tifstack'):
            fileformat = self.config.fileformat
        else:
            if self.config.fileformat is not None:
                self.logger.warning('Ignoring invalid fileformat '
                                    f'({self.config.fileformat})')
            if data.shape[0] == 1:
                fileformat = 'tif'
            else:
                fileformat = 'tifstack'
        if fileformat != 'tifstack':
            # Return the set of image slices as individual figs
            num_digit = len(str(data.shape[0]))
            images = []
            for i in range(data.shape[0]):
                # Pass the slice index so that each figure is titled
                # with its own coordinate
                fig, plt = self._create_figure(data[i], index=i)
                images.append((
                    fig_to_iobuf(fig, fileformat=fileformat),
                    f'{self.config.basename}_'
                    f'{str(i).zfill(num_digit)}'))
                plt.close()
            return images

        # Return the set of image slices as a tif stack
        data = 255.0*((data - vrange[0])/
                      (vrange[1] - vrange[0]))
        return {'image_data': data.astype(np.uint8), 'fileformat': 'tif'}

    def _set_title(self, i):
        """Return the title for the image of slice `i`, built from
        the current figure configuration.
        """
        return self._figconfig['axis_name'] +\
            f' = {self._figconfig["axis_coords"][i]:.3f}' +\
            self._figconfig['axis_unit']

    def _create_animation(self, data):
        """Create a Matplotlib animation from a set of images."""
        # System modules
        from functools import partial

        # Third party modules
        from matplotlib import animation

        def _animate(i, plt, title):
            """Create a single Matplotlib image for the animation."""
            im.set_array(data[i])
            title.set_text(self._set_title(i))
            plt.draw()
            return im,

        fig, im, plt, title = self._create_figure(data[0], animated=True)
        ani = animation.FuncAnimation(
            fig, partial(_animate, plt=plt, title=title),
            frames=data.shape[0], interval=50, blit=True)

        if self.interactive:
            plt.show()

        plt.close()

        return ani

    def _create_figure(self, image, animated=False, index=0):
        """Create a Matplotlib figure from an image.

        :param image: Image to display.
        :param animated: Create the image for use in an animation and
            also return the image and title artists, defaults to
            `False`.
        :type animated: bool, optional
        :param index: Index of the image slice, used to select the
            coordinate shown in the title, defaults to `0`.
        :type index: int, optional
        :return: The figure and the pyplot module, with the image and
            title artists in between for animated = `True`.
        """
        # Third party modules
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        im = plt.imshow(
            image, extent=self._figconfig['extent'], origin='lower',
            vmin=self._figconfig['vrange'][0],
            vmax=self._figconfig['vrange'][1],
            cmap='gray', animated=animated)
        fig.suptitle(self._figconfig['title'], fontsize='x-large')
        title = ax.set_title(
            self._set_title(index), fontsize='x-large', pad=10)
        ax.set_xlabel(self._figconfig['row_label'], fontsize='x-large')
        ax.set_ylabel(self._figconfig['column_label'], fontsize='x-large')
        plt.colorbar()
        fig.tight_layout()
        if animated:
            return fig, im, plt, title
        return fig, plt
class MapProcessor(Processor):
    """A Processor that takes a map configuration and returns a NeXus
    style
    `NXentry <https://manual.nexusformat.org/classes/base_classes/NXentry.html#index-0>`__
    object representing that map's metadata and any scalar-valued raw
    data requested by the supplied map configuration.

    :ivar config: Map configuration parameters to initialize an
        instance of :class:`~CHAP.common.models.map.MapConfig`. Any
        values in `'config'` supplant their corresponding values
        obtained from the pipeline data configuration.
    :vartype config: dict | MapConfig
    :ivar detector_config: Detector configurations of the detectors to
        include raw data for in the returned NXentry object
        (overruling detector info in the pipeline data, if present).
    :vartype detector_config: dict | DetectorConfig
    :ivar num_proc: Number of processors used to read map,
        defaults to `1`.
    :vartype num_proc: int, optional
    """
    # Schema names for the fields of this processor
    pipeline_fields: dict = Field(
        default = {
            'config': 'common.models.map.MapConfig',
            'detector_config': 'common.models.map.DetectorConfig'},
        init_var=True)
    # Map configuration, None when not supplied
    config: Optional[MapConfig] = None
    # Detector configuration, defaults to an empty detector list
    detector_config: DetectorConfig = DetectorConfig(detectors=[])
    # Number of processors, constrained to be strictly positive
    num_proc: Optional[conint(gt=0)] = 1
[docs] @field_validator('num_proc') @classmethod def validate_num_proc(cls, num_proc, info): """Validate the number of processors. :param num_proc: Number of processors used to read map, defaults to `1`. :type num_proc: int, optional :param info: Model parameter validation information. :type info: pydantic.ValidationInfo :return: Validated number of processors :rtype: str """ if num_proc > 1: logger = info.data['logger'] try: # Third party modules # pylint: disable=unused-import from mpi4py import MPI if num_proc > os.cpu_count(): logger.warning( f'The requested number of processors ({num_proc}) ' 'exceeds the maximum number of processors ' f'({os.cpu_count()}): reset it to {os.cpu_count()}') num_proc = os.cpu_count() except ImportError: logger.warning('Unable to load mpi4py, running serially') num_proc = 1 logger.debug(f'Number of processors: {num_proc}') return num_proc
    def process(
            self, data, placeholder_data=False, fill_data=True, comm=None):
        """Process that takes a map configuration and returns a NeXus
        style
        `NXentry <https://manual.nexusformat.org/classes/base_classes/NXentry.html#index-0>`__
        object representing the map.

        :param data: Pipeline data list with an optional item for the
            map configuration parameters with
            `'common.models.map.MapConfig'` as its `'schema'` key.
        :type data: list[PipelineData]
        :param placeholder_data: For SMB EDD maps only. Value to use
            for missing detector data frames, or `False` if missing
            data should raise an error, defaults to `False`.
        :type placeholder_data: object, optional
        :param fill_data: Flag to indicate whether or not to fill out
            datasets with real data; defaults to `True`.
        :type fill_data: bool, optional
        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :raises NotImplementedError: If more than one processor is
            requested (untested).
        :return: Map data and metadata: the object returned by
            `self._get_nxroot`, or, when both metadata and provenance
            records are found in the pipeline data, a tuple of three
            `PipelineData` items (metadata, provenance and the map).
            `None` is returned on MPI ranks other than 0.
        :rtype: object | tuple[PipelineData] | None
        """
        # System modules
        import logging

        # Third party modules
        import yaml

        # Update metadata
#        self._metadata['user_metadata'].update({
#            'map': self.config.model_dump()})

        # Check for available metadata
        metadata = {}
        provenance = {}
        if data:
            metadata, provenance = self._get_metadata_provenance(data)

        if fill_data:
            # Create the sub-pipeline configuration for each processor
            # FIX: catered to EDD with one spec scan
            # NOTE(review): this input check is an assert and is
            # skipped when running with -O
            assert len(self.config.spec_scans) == 1
            spec_scans = self.config.spec_scans[0]
            scan_numbers = spec_scans.scan_numbers
            num_scan = len(scan_numbers)
            if num_scan < self.num_proc:
                self.logger.warning(
                    f'Requested number of processors ({self.num_proc}) exceeds '
                    f'the number of scans ({num_scan}): reset it to {num_scan}')
                self.num_proc = num_scan
            if self.num_proc == 1:
                common_comm = comm
                offsets = [0]
            else:
                # System modules
                from tempfile import NamedTemporaryFile

                # Local modules
                from CHAP.models import RunConfig

                # Everything below this raise in this branch is
                # currently unreachable
                raise NotImplementedError(
                    'MapProcessor needs testing for num_proc>1')
                # Divide the scans over the processors, the first
                # processors each take one of the remaining scans
                scans_per_proc = num_scan//self.num_proc
                num = scans_per_proc
                if num_scan - scans_per_proc*self.num_proc > 0:
                    num += 1
                spec_scans.scan_numbers = scan_numbers[:num]
                n_scan = num
                pipeline_config = []
                offsets = [0]
                for n_proc in range(1, self.num_proc):
                    num = scans_per_proc
                    if n_proc < num_scan - scans_per_proc*self.num_proc:
                        num += 1
                    config = self.config.model_dump()
                    config['spec_scans'][0]['scan_numbers'] = \
                        scan_numbers[n_scan:n_scan+num]
                    pipeline_config.append([{
                        'common.MapProcessor': {
                            'config': config,
                            'detector_config':
                                self.detector_config.model_dump()}}])
                    offsets.append(n_scan)
                    n_scan += num

                # Spawn the workers to run the sub-pipeline
                run_config = RunConfig(
                    log_level=logging.getLevelName(self.logger.level),
                    spawn=1)
                tmp_names = []
                with NamedTemporaryFile(delete=False) as fp:
                    # pylint: disable=c-extension-no-member
                    fp_name = fp.name
                    tmp_names.append(fp_name)
                    with open(fp_name, 'w', encoding='utf-8') as f:
                        yaml.dump({'config': {'spawn': 1}}, f, sort_keys=False)
                    for n_proc in range(1, self.num_proc):
                        f_name = f'{fp_name}_{n_proc}'
                        tmp_names.append(f_name)
                        with open(f_name, 'w', encoding='utf-8') as f:
                            yaml.dump(
                                # FIX once comm is a field of RunConfig
                                #{'config': run_config.model_dump(exclude='comm'),
                                {'config': run_config.model_dump(),
                                 'pipeline': pipeline_config[n_proc-1]},
                                f, sort_keys=False)
                    # NOTE(review): MPI is not imported in this
                    # method -- confirm before enabling this branch
                    # pylint: disable=used-before-assignment
                    sub_comm = MPI.COMM_SELF.Spawn(
                        'CHAP', args=[fp_name], maxprocs=self.num_proc-1)
                    common_comm = sub_comm.Merge(False)
                    # Align with the barrier in RunConfig() on common_comm
                    # called from the spawned main() in common_comm
                    common_comm.barrier()
                    # Align with the barrier in run() on common_comm
                    # called from the spawned main()
                    common_comm.barrier()

            # From here on the number of processors and the rank
            # follow from the communicator, if any
            if common_comm is None:
                self.num_proc = 1
                rank = 0
            else:
                self.num_proc = common_comm.Get_size()
                rank = common_comm.Get_rank()
            if self.num_proc == 1:
                offset = 0
            else:
                num_scan = common_comm.bcast(num_scan, root=0)
                offset = common_comm.scatter(offsets, root=0)

            # Read the raw data
            if self.config.experiment_type == 'EDD':
                data, independent_dimensions, all_scalar_data = \
                    self._read_raw_data_edd(
                        common_comm, num_scan, offset, placeholder_data)
            else:
                data, independent_dimensions, all_scalar_data = \
                    self._read_raw_data(common_comm, num_scan, offset)
            if not rank:
                self.logger.debug(f'Data shape: {data.shape}')
                if independent_dimensions is not None:
                    self.logger.debug('Independent dimensions shape: '
                                      f'{independent_dimensions.shape}')
                if all_scalar_data is not None:
                    self.logger.debug('Scalar data shape: '
                                      f'{all_scalar_data.shape}')

            # Only rank 0 constructs and returns the map
            if rank:
                return None

            if self.num_proc > 1:
                # Reset the scan_numbers to the original full set
                spec_scans.scan_numbers = scan_numbers
                # Align with the barrier in main() on common_comm
                # when disconnecting the spawned worker
                common_comm.barrier()
                # Disconnect spawned workers and cleanup temporary files
                sub_comm.Disconnect()
                for tmp_name in tmp_names:
                    os.remove(tmp_name)
        else:
            # fill_data is False, just use empty arrays
            map_len = 0
            _independent_dimensions = {
                dim.label: [] for dim in self.config.independent_dimensions}
            # Detector data shapes, obtained from the first scan only
            det_shapes = False
            for scans in self.config.spec_scans:
                for scan_number in scans.scan_numbers:
                    scanparser = scans.get_scanparser(scan_number)
                    map_len += scanparser.spec_scan_npts
                    for dim in self.config.independent_dimensions:
                        val = dim.get_value(
                            scans, scan_number, -1, self.config.scalar_data)
                        if not isinstance(val, list):
                            val = [val]
                        _independent_dimensions[dim.label].extend(val)
                    if not det_shapes:
                        det_shapes = {}
                        for detector in self.detector_config.detectors:
                            ddata_init = scanparser.get_detector_data(
                                detector.get_id(), 0)
                            if isinstance(ddata_init, tuple):
                                ddata_init = ddata_init[0].squeeze()
                            det_shapes[detector.get_id()] = ddata_init.shape
            all_scalar_data = np.empty(
                (len(self.config.all_scalar_data), map_len))
            if len(self.detector_config.detectors) > 0:
                # All detectors are given the shape of the first one
                data = np.empty(
                    (len(self.detector_config.detectors), map_len,
                     *det_shapes[self.detector_config.detectors[0].get_id()]))
            else:
                data = None
            independent_dimensions = np.asarray(
                [_independent_dimensions[dim.label]
                 for dim in self.config.independent_dimensions])

        # Construct and return the NXroot object
        nxroot = self._get_nxroot(
            data, independent_dimensions, all_scalar_data, placeholder_data)

        if metadata and provenance:
            return (
                PipelineData(
                    name=self.name, data=metadata,
                    schema='foxden.reader.FoxdenMetadataReader'),
                PipelineData(
                    name=self.name, data=provenance,
                    schema='foxden.reader.FoxdenProvenanceReader'),
                PipelineData(
                    name=self.name, data=nxroot, schema=self.get_schema()))
        return nxroot
def _get_metadata_provenance(self, data):
    """Get experiment specific configurational data from the FOXDEN
    metadata and provenance records.

    :param data: Input data.
    :type data: list[PipelineData]
    :raises ValueError: If the beamline or the experiment type found
        in the metadata record is not supported.
    :return: Experiment specific metadata and provenance.
    :rtype: dict, dict
    """
    # Local modules
    from CHAP.tomo.processor import (
        read_metadata_provenance,
        create_metadata_provenance,
    )

    # Read metadata and provenance from the pipeline data
    metadata, provenance = read_metadata_provenance(data, self.logger)
    if not metadata:
        return metadata, provenance

    # Update metadata and provenance
    experiment_type = [v.lower() for v in metadata.get('technique')]
    if 'tomography' not in experiment_type:
        raise ValueError(
            f'Experiment type {experiment_type} not implemented yet')

    station = f'id{metadata.get("beamline")[0].lower()}'
    if station not in ('id1a3', 'id3a'):
        raise ValueError(f'Invalid beamline parameter ({station})')
    spec_file = os.path.join(
        metadata.get('data_location_raw'), 'spec.log')
    sample_name = metadata.get('sample_name')
    # FIX We could add full self.config and self.detector_config
    # That would allow us to create a full map from just metadata
    user_metadata = {
        'map_config': {
            'experiment_type': 'TOMO',
            'sample': {'name': sample_name,
                       'description': metadata.get('description')},
            'station': station,
            'title': sample_name,
            'spec_file': spec_file}}
    metadata, provenance = create_metadata_provenance(
        'map', metadata=metadata, provenance=provenance,
        user_metadata=user_metadata, logger=self.logger,
        update=False, read=False)
    # Re-insert 'user_metadata' so that it becomes the last key
    user_metadata = metadata.pop('user_metadata', {})
    metadata['user_metadata'] = user_metadata

    return metadata, provenance

def _get_nxroot(
        self, data, independent_dimensions, all_scalar_data,
        placeholder_data):
    """Use a `MapConfig` to construct a NeXus style `NXroot
    <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
    object.

    Independent dimensions that hold a single unique value over the
    whole map are moved from `self.config.independent_dimensions` to
    `self.config.all_scalar_data` (the configuration is modified in
    place).

    :param data: Map's raw data.
    :type data: numpy.ndarray
    :param independent_dimensions: Map's independent coordinates.
    :type independent_dimensions: numpy.ndarray
    :param all_scalar_data: Map's scalar data.
    :type all_scalar_data: numpy.ndarray
    :param placeholder_data: For SMB EDD maps only. Value to use for
        missing detector data frames, or `False` if missing data
        should raise an error.
    :type placeholder_data: object
    :return: Map's data and metadata.
    :rtype: nexusformat.nexus.NXroot
    """
    # Third party modules
    # pylint: disable=no-name-in-module
    from nexusformat.nexus import (
        NXcollection,
        NXdata,
        NXentry,
        NXfield,
        NXlinkfield,
        NXsample,
        NXroot,
    )
    # pylint: enable=no-name-in-module

    # Local modules:
    from CHAP.common.models.map import PointByPointScanData

    def linkdims(nxgroup, nxdata_source):
        """Link the dimensions for an `NXgroup
        <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__.
        """
        axes = list(nxdata_source.keys())
        for dim in axes:
            if isinstance(nxdata_source[dim], NXlinkfield):
                nxgroup[dim] = nxdata_source[dim]
            else:
                nxgroup.makelink(nxdata_source[dim])
            if f'{dim}_indices' in nxdata_source.attrs:
                nxgroup.attrs[f'{dim}_indices'] = \
                    nxdata_source.attrs[f'{dim}_indices']
        if len(axes) == 1:
            nxgroup.attrs['axes'] = axes
        else:
            nxgroup.attrs['unstructured_axes'] = axes

    def dim_attrs(dim):
        """Return the NXfield attributes describing a dimension."""
        return {'units': dim.units,
                'long_name': f'{dim.label} ({dim.units})',
                'data_type': dim.data_type,
                'local_name': dim.name}

    # Set up NXroot/NXentry and add CHESS-specific metadata
    nxroot = NXroot()
    nxentry = NXentry(name=self.config.title)
    nxroot[nxentry.nxname] = nxentry
    nxentry.map_config = self.config.model_dump_json()
    nxentry.detector_config = self.detector_config.model_dump_json()
    nxentry.attrs['station'] = self.config.station
    for k, v in self.config.attrs.items():
        nxentry.attrs[k] = v
    nxentry.spec_scans = NXcollection()
    for scans in self.config.spec_scans:
        # A 64-bit integer type is used since scan numbers can exceed
        # the range of an 8-bit signed integer (-128..127)
        nxentry.spec_scans[scans.scanparsers[0].scan_name] = \
            NXfield(value=scans.scan_numbers,
                    dtype='int64',
                    attrs={'spec_file': str(scans.spec_file)})

    # Add sample metadata
    nxentry[self.config.sample.name] = NXsample(
        **self.config.sample.model_dump())

    # Set up independent dimensions NXdata group
    # (squeeze out constant dimensions)
    constant_dim = [
        i for i in range(len(self.config.independent_dimensions))
        if np.unique(independent_dimensions[i]).size == 1]
    nxentry.independent_dimensions = NXdata()
    if len(constant_dim) < len(self.config.independent_dimensions):
        for i, dim in enumerate(self.config.independent_dimensions):
            if i not in constant_dim:
                nxentry.independent_dimensions[dim.label] = NXfield(
                    independent_dimensions[i], dim.label,
                    attrs=dim_attrs(dim))
    else:
        # Every dimension is constant: fall back on a running index
        nxentry.independent_dimensions.index = NXfield(
            np.arange(independent_dimensions[0].size), 'index')

    # Set up scalar data NXdata group
    # (add the constant independent dimensions)
    if all_scalar_data is not None:
        self.logger.debug(
            f'all_scalar_data.shape = {all_scalar_data.shape}\n\n')
    scalar_signals = []
    scalar_data = []
    for i, dim in enumerate(self.config.all_scalar_data):
        scalar_signals.append(dim.label)
        scalar_data.append(NXfield(
            value=all_scalar_data[i],
            units=dim.units,
            attrs={'long_name': f'{dim.label} ({dim.units})',
                   'data_type': dim.data_type,
                   'local_name': dim.name}))
    if (self.config.experiment_type == 'EDD'
            and placeholder_data is not False):
        # The last row of all_scalar_data holds the placeholder flags
        scalar_signals.append('placeholder_data_used')
        scalar_data.append(NXfield(
            value=all_scalar_data[-1],
            attrs={'description':
                   'Indicates whether placeholder data may be present '
                   'for the corresponding frames of detector data.'}))
    # Iterate over a copy, the configuration is modified in the loop
    for i, dim in enumerate(deepcopy(self.config.independent_dimensions)):
        if i in constant_dim:
            scalar_signals.append(dim.label)
            scalar_data.append(NXfield(
                independent_dimensions[i], dim.label,
                attrs=dim_attrs(dim)))
            self.config.all_scalar_data.append(
                PointByPointScanData(**dim.model_dump()))
            self.config.independent_dimensions.remove(dim)
    if scalar_signals:
        nxentry.scalar_data = NXdata()
        for k, v in zip(scalar_signals, scalar_data):
            nxentry.scalar_data[k] = v
        if 'SCAN_N' in scalar_signals:
            nxentry.scalar_data.attrs['signal'] = 'SCAN_N'
        else:
            nxentry.scalar_data.attrs['signal'] = scalar_signals[0]
        scalar_signals.remove(nxentry.scalar_data.attrs['signal'])
        nxentry.scalar_data.attrs['auxiliary_signals'] = scalar_signals

    # Add detector data
    nxdata = NXdata()
    nxentry.data = nxdata
    nxentry.data.set_default()
    detector_ids = []
    for k, v in self.config.attrs.items():
        nxdata.attrs[k] = v
    if data is not None:
        # Per detector extrema over all remaining axes
        reduce_axes = tuple(range(1, data.ndim))
        min_ = np.min(data, axis=reduce_axes)
        max_ = np.max(data, axis=reduce_axes)
        for i, detector in enumerate(self.detector_config.detectors):
            nxdata[detector.get_id()] = NXfield(
                value=data[i],
                attrs={**detector.attrs,
                       'min': min_[i], 'max': max_[i]})
            detector_ids.append(detector.get_id())
    linkdims(nxdata, nxentry.independent_dimensions)
    if len(self.detector_config.detectors) == 1:
        nxdata.attrs['signal'] = \
            self.detector_config.detectors[0].get_id()
    nxentry.detector_ids = detector_ids

    return nxroot

def _read_raw_data_edd(
        self, comm, num_scan, offset, placeholder_data):
    """Read the raw EDD data for a given map configuration.

    :param comm: MPI communicator.
    :type comm: mpi4py.MPI.Comm, optional
    :param num_scan: Number of scans in the map.
    :type num_scan: int
    :param offset: Offset scan number of current processor.
    :type offset: int
    :param placeholder_data: Value to use for missing detector data
        frames, or `False` if missing data should raise an error.
    :type placeholder_data: object
    :return: Map's raw data, independent dimensions and scalar data.
    :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
    """
    # Third party modules
    # (mpi4py is only used when an MPI communicator is provided)
    try:
        from mpi4py import MPI
        from mpi4py.util import dtlib
    except ImportError:
        pass

    # Local modules
    from CHAP.utils.general import list_to_string

    if comm is None:
        self.num_proc = 1
        rank = 0
    else:
        self.num_proc = comm.Get_size()
        rank = comm.Get_rank()
    if not rank:
        self.logger.debug(f'Number of processors: {self.num_proc}')
        self.logger.debug(f'Number of scans: {num_scan}')

    # Create the shared data buffers
    # FIX: just one spec scan at this point
    assert len(self.config.spec_scans) == 1
    scan = self.config.spec_scans[0]
    scan_numbers = scan.scan_numbers
    scanparser = scan.get_scanparser(scan_numbers[0])
    detector_ids = [
        int(d.get_id()) for d in self.detector_config.detectors]
    # The first scan is read up front to size the buffers
    ddata, placeholder_used = scanparser.get_detector_data(
        detector_ids, placeholder_data=placeholder_data)
    spec_scan_shape = scanparser.spec_scan_shape
    num_dim = np.prod(spec_scan_shape)
    num_id = len(self.config.independent_dimensions)
    num_sd = len(self.config.all_scalar_data)
    if placeholder_data is not False:
        # Extra scalar data row flagging the use of placeholder data
        num_sd += 1
    if self.num_proc == 1:
        assert num_scan == len(scan_numbers)
        data = np.empty((num_scan, *ddata.shape), dtype=ddata.dtype)
        independent_dimensions = np.empty(
            (num_id, num_scan*num_dim), dtype=np.float64)
        all_scalar_data = np.empty(
            (num_sd, num_scan*num_dim), dtype=np.float64)
    else:
        self.logger.debug(f'Scan offset on processor {rank}: {offset}')
        self.logger.debug(f'Scan numbers on processor {rank}: '
                          f'{list_to_string(scan_numbers)}')
        # Only rank 0 requests memory (nbytes stays 0 on the other
        # ranks), every rank queries the window of rank 0
        datatype = dtlib.from_numpy_dtype(ddata.dtype)
        itemsize = datatype.Get_size()
        if not rank:
            nbytes = num_scan * np.prod(ddata.shape) * itemsize
        else:
            nbytes = 0
        win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf, itemsize = win.Shared_query(0)
        assert itemsize == datatype.Get_size()
        data = np.ndarray(
            buffer=buf, dtype=ddata.dtype,
            shape=(num_scan, *ddata.shape))
        datatype = dtlib.from_numpy_dtype(np.float64)
        itemsize = datatype.Get_size()
        if not rank:
            nbytes = num_id * num_scan * num_dim * itemsize
        win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf_id, _ = win_id.Shared_query(0)
        independent_dimensions = np.ndarray(
            buffer=buf_id, dtype=np.float64,
            shape=(num_id, num_scan*num_dim))
        if not rank:
            nbytes = num_sd * num_scan * num_dim * itemsize
        win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf_sd, _ = win_sd.Shared_query(0)
        all_scalar_data = np.ndarray(
            buffer=buf_sd, dtype=np.float64,
            shape=(num_sd, num_scan*num_dim))

    # Read the raw data
    init = True
    for scan in self.config.spec_scans:
        for scan_number in scan.scan_numbers:
            if init:
                # The data of the first scan was already read above
                init = False
            else:
                scanparser = scan.get_scanparser(scan_number)
                assert spec_scan_shape == scanparser.spec_scan_shape
                ddata, placeholder_used = \
                    scanparser.get_detector_data(
                        detector_ids,
                        placeholder_data=placeholder_data)
            data[offset] = ddata
            start_dim = offset * num_dim
            end_dim = start_dim + num_dim
            for i, dim in enumerate(self.config.independent_dimensions):
                independent_dimensions[i][start_dim:end_dim] = \
                    dim.get_value(
                        scan, scan_number, scan_step_index=-1,
                        relative=False)
            for i, dim in enumerate(self.config.all_scalar_data):
                all_scalar_data[i][start_dim:end_dim] = dim.get_value(
                    scan, scan_number, scan_step_index=-1,
                    relative=False)
            if placeholder_data is not False:
                all_scalar_data[-1][start_dim:end_dim] = \
                    placeholder_used
            offset += 1

    # Merge the two leading axes and swap the result with the next one
    return (
        np.swapaxes(
            data.reshape((np.prod(data.shape[:2]), *data.shape[2:])),
            0, 1),
        independent_dimensions,
        all_scalar_data)

# @profile
def _read_raw_data(self, comm, num_scan, offset):
    """Read the raw data for a given map configuration.

    :param comm: MPI communicator.
    :type comm: mpi4py.MPI.Comm, optional
    :param num_scan: Number of scans in the map.
    :type num_scan: int
    :param offset: Offset scan number of current processor.
    :type offset: int
    :raises ValueError: If the detector configuration does not
        contain exactly one detector.
    :return: Map's raw data, independent dimensions and scalar data
        (`None` if the map has no scalar data).
    :rtype: numpy.ndarray, numpy.ndarray, Union[numpy.ndarray, None]
    """
    # Third party modules
    # (mpi4py is only used when an MPI communicator is provided)
    try:
        from mpi4py import MPI
        from mpi4py.util import dtlib
    except ImportError:
        pass

    # Local modules
    from CHAP.utils.general import list_to_string

    def read_detector_data(parser, detector):
        """Read the data of one detector for the scan of `parser`."""
        # FIX eliminate need for testing for
        # self.config.experiment_type in scanparser
        if self.config.experiment_type == 'TOMO':
            if self.detector_config.roi is None:
                detector_roi = [slice(None), slice(None)]
            else:
                detector_roi = self.detector_config.roitoslice()
            return parser.get_detector_data(
                detector.get_id(), detector_roi=detector_roi,
                dtype=dtype)
        return parser.get_detector_data(detector.get_id())

    if comm is None:
        self.num_proc = 1
        rank = 0
    else:
        self.num_proc = comm.Get_size()
        rank = comm.Get_rank()
    if not rank:
        self.logger.debug(f'Number of processors: {self.num_proc}')
        self.logger.debug(f'Number of scans: {num_scan}')

    # Create the shared data buffers
    assert len(self.config.spec_scans) == 1
    scans = self.config.spec_scans[0]
    scan_numbers = scans.scan_numbers
    scanparser = scans.get_scanparser(scan_numbers[0])
    #RV only correct for multiple detectors if the same image sizes
    detectors = self.detector_config.detectors
    if len(detectors) != 1:
        raise ValueError('Multiple detectors not tested yet')
    if self.config.experiment_type == 'TOMO':
        dtype = np.float32
    else:
        dtype = None
    # The first scan is read up front to size the buffers
    ddata = read_detector_data(scanparser, detectors[0])
    num_det = len(detectors)
    num_dim = ddata.shape[0]
    num_id = len(self.config.independent_dimensions)
    num_sd = len(self.config.all_scalar_data)
    all_scalar_data = None
    if self.num_proc == 1:
        assert num_scan == len(scan_numbers)
        # One independent list per detector (no shared inner list)
        data = [num_scan*[None] for _ in range(num_det)]
        independent_dimensions = np.empty(
            (num_scan, num_id, num_dim), dtype=np.float64)
        if num_sd:
            all_scalar_data = np.empty(
                (num_scan, num_sd, num_dim), dtype=np.float64)
    else:
        self.logger.debug(f'Scan offset on processor {rank}: {offset}')
        self.logger.debug(f'Scan numbers on processor {rank}: '
                          f'{list_to_string(scan_numbers)}')
        # NOTE(review): dtype is None for non-TOMO maps, so the shared
        # buffer presumably falls back on numpy's default float type
        # instead of ddata.dtype - confirm that this is intended
        datatype = dtlib.from_numpy_dtype(dtype)
        itemsize = datatype.Get_size()
        # Only rank 0 requests memory (nbytes stays 0 on the other
        # ranks), every rank queries the window of rank 0
        if not rank:
            nbytes = num_det * num_scan * np.prod(ddata.shape) * itemsize
        else:
            nbytes = 0
        win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf, _ = win.Shared_query(0)
        #RV improve memory requirements ala single processor case?
        data = np.ndarray(
            buffer=buf, dtype=dtype,
            shape=(num_det, num_scan, *ddata.shape))
        datatype = dtlib.from_numpy_dtype(np.float64)
        itemsize = datatype.Get_size()
        if not rank:
            nbytes = num_scan * num_id * num_dim * itemsize
        else:
            nbytes = 0
        win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf_id, _ = win_id.Shared_query(0)
        independent_dimensions = np.ndarray(
            buffer=buf_id, dtype=np.float64,
            shape=(num_scan, num_id, num_dim))
        if num_sd:
            if not rank:
                nbytes = num_scan * num_sd * num_dim * itemsize
            win_sd = MPI.Win.Allocate_shared(
                nbytes, itemsize, comm=comm)
            buf_sd, _ = win_sd.Shared_query(0)
            all_scalar_data = np.ndarray(
                buffer=buf_sd, dtype=np.float64,
                shape=(num_scan, num_sd, num_dim))

    # Read the raw data
    init = True
    for scans in self.config.spec_scans:
        for scan_number in scans.scan_numbers:
            for i, detector in enumerate(detectors):
                if init:
                    # The data of the first scan was already read
                    init = False
                    data[i][offset] = ddata
                    del ddata
                else:
                    scanparser = scans.get_scanparser(scan_number)
                    data[i][offset] = read_detector_data(
                        scanparser, detector)
            for i, dim in enumerate(self.config.independent_dimensions):
                if dim.data_type in ['scan_column',
                                     'detector_log_timestamps']:
                    independent_dimensions[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1,
                        relative=False)[:num_dim]
                elif dim.data_type in ['smb_par', 'spec_motor',
                                       'expression']:
                    independent_dimensions[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1,
                        relative=False,
                        scalar_data=self.config.scalar_data)
                else:
                    independent_dimensions[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1)
            for i, dim in enumerate(self.config.all_scalar_data):
                all_scalar_data[offset,i] = dim.get_value(
                    scans, scan_number, scan_step_index=-1,
                    relative=False)
            offset += 1

    if self.num_proc == 1:
        data = np.asarray(data)

    # Merge the scan and scan step axes for each returned quantity
    data = data.reshape(
        (data.shape[0], np.prod(data.shape[1:3]), *data.shape[3:]))
    independent_dimensions = np.stack(tuple(
        independent_dimensions[:,i].flatten() for i in range(num_id)))
    if num_sd:
        all_scalar_data = np.stack(tuple(
            all_scalar_data[:,i].flatten() for i in range(num_sd)))
    else:
        all_scalar_data = None
    return data, independent_dimensions, all_scalar_data
class MPICollectProcessor(Processor):
    """A Processor that gathers, on the root node, the data produced
    by the distributed workers of an MPIMapProcessor.
    """
    def process(self, data, comm, root_as_worker=True):
        """Collect data on root node.

        :param data: Input data.
        :type data: list[PipelineData]
        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :param root_as_worker: Use the root node as a worker,
            defaults to `True`.
        :type root_as_worker: bool, optional
        :return: Distributed worker data on the root node.
        :rtype: list
        """
        size = comm.Get_size()
        my_rank = comm.Get_rank()

        if not root_as_worker:
            # The root only collects: every worker sends its item to
            # rank 0, which receives them in order of worker rank
            for worker in range(1, size):
                if my_rank == worker:
                    comm.send(self.get_pipelinedata_item(data), dest=0)
                    data = None
                elif my_rank == 0:
                    if worker == 1:
                        data = [comm.recv(source=worker)]
                    else:
                        data.append(comm.recv(source=worker))
            #FIX RV TODO Merge the list of data items in some generic fashion
            return data

        # The root contributes its own item as well
        data = self.get_pipelinedata_item(data)
        if size > 1:
            data = comm.gather(data, root=0)
        #FIX RV TODO Merge the list of data items in some generic fashion
        return data
class MPIMapProcessor(Processor):
    """A Processor that applies a parallel generic sub-pipeline to a
    map configuration.
    """
    def process(self, data, config=None, sub_pipeline=None):
        """Run a parallel generic sub-pipeline.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            :class:`~CHAP.common.models.map.MapConfig`.
        :type config: dict, optional
        :param sub_pipeline: Sub-pipeline.
        :type sub_pipeline: Pipeline, optional
        :raises NotImplementedError: Always, this processor is
            currently disabled.
        :return: `data` field of the first item in the returned
           list of sub-pipeline items.
        :rtype: Any
        """
        # Third party modules
        from mpi4py import MPI

        # Local modules
        from CHAP.models import RunConfig
        from CHAP.runner import run
        from CHAP.common.models.map import SpecScans

        # The processor is disabled on purpose, the code below is
        # kept as the starting point for the update
        raise NotImplementedError('MPIMapProcessor needs updating and testing')
        # pylint: disable=unreachable
        # pylint: disable=c-extension-no-member
        comm = MPI.COMM_WORLD
        num_proc = comm.Get_size()
        rank = comm.Get_rank()

        # Get the validated map configuration
        map_config = self.get_config(
            data=data, config=config, schema='common.models.map.MapConfig')

        # Create the spec reader configuration for each processor
        # FIX: catered to EDD with one spec scan
        assert len(map_config.spec_scans) == 1
        spec_scans = map_config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        # The first num_extra processors each get one additional scan.
        # The offset of this rank has to account for the additional
        # scans of all lower ranks.
        scans_per_proc, num_extra = divmod(num_scan, num_proc)
        n_scan = 0
        for n_proc in range(num_proc):
            num = scans_per_proc
            if n_proc < num_extra:
                num += 1
            if n_proc == rank:
                scan_numbers = scan_numbers[n_scan:n_scan+num]
                break
            n_scan += num
        spec_config = {
            'station': map_config.station,
            'experiment_type': map_config.experiment_type,
            'spec_scans': [SpecScans(
                spec_file=spec_scans.spec_file,
                scan_numbers=scan_numbers)]}

        # Get the run configuration to use for the sub-pipeline
        if sub_pipeline is None:
            sub_pipeline = {}
        run_config = {'inputdir': self.inputdir,
                      'outputdir': self.outputdir,
                      'interactive': self.interactive,
                      'log_level': self.log_level}
        run_config.update(sub_pipeline.get('config', {}))
        run_config = RunConfig(**run_config, comm=comm)
        pipeline_config = []
        # Work on a copy to leave the input sub-pipeline untouched
        for item in deepcopy(sub_pipeline.get('pipeline', [])):
            if isinstance(item, dict):
                for k, v in deepcopy(item).items():
                    if k.endswith('Reader'):
                        v['config'] = spec_config
                        item[k] = v
                    if num_proc > 1 and k.endswith('Writer'):
                        # Give each processor its own output file
                        r, e = os.path.splitext(v['filename'])
                        v['filename'] = f'{r}_{rank}{e}'
                        item[k] = v
            pipeline_config.append(item)

        # Run the sub-pipeline on each processor
        return run(run_config, pipeline_config, logger=self.logger, comm=comm)
class MPISpawnMapProcessor(Processor):
    """A Processor that applies a parallel generic sub-pipeline to a
    map configuration by spawning workers processes.
    """
    def process(
            self, data, num_proc=1, root_as_worker=True,
            collect_on_root=False, sub_pipeline=None):
        """Spawn workers running a parallel generic sub-pipeline.

        :param data: Input data.
        :type data: list[PipelineData]
        :param num_proc: Number of spawned processors, defaults to `1`.
        :type num_proc: int, optional
        :param root_as_worker: Use the root node as a worker,
            defaults to `True`.
        :type root_as_worker: bool, optional
        :param collect_on_root: Collect the result of the spawned
            workers on the root node, defaults to `False`.
        :type collect_on_root: bool, optional
        :param sub_pipeline: Sub-pipeline.
        :type sub_pipeline: Pipeline, optional
        :raises NotImplementedError: Always, this processor is
            currently disabled.
        :return: `data` field of the first item in the returned
           list of sub-pipeline items.
        """
        # Third party modules
        from mpi4py import MPI
        import yaml

        # Local modules
        from CHAP.models import RunConfig
        from CHAP.runner import runner
        from CHAP.common.models.map import SpecScans

        # The processor is disabled on purpose, the code below is
        # kept as the starting point for the update
        raise NotImplementedError(
            'MPISpawnMapProcessor needs updating and testing')
        # pylint: disable=unreachable

        # Get the map configuration from data
        map_config = self.get_config(
            data=data, schema='common.models.map.MapConfig')

        # Get the run configuration to use for the sub-pipeline
        # Optionally include the root node as a worker node
        if sub_pipeline is None:
            sub_pipeline = {}
        run_config = {'inputdir': self.inputdir,
                      'outputdir': self.outputdir,
                      'interactive': self.interactive,
                      'log_level': self.log_level}
        run_config.update(sub_pipeline.get('config', {}))
        if root_as_worker:
            first_proc = 1
            spawn = 1
        else:
            first_proc = 0
            spawn = -1
        run_config = RunConfig(**run_config, logger=self.logger, spawn=spawn)

        # Create the sub-pipeline configuration for each processor
        spec_scans = map_config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        # The first num_extra processors each get one additional scan
        scans_per_proc, num_extra = divmod(num_scan, num_proc)
        n_scan = 0
        pipeline_config = []
        for n_proc in range(num_proc):
            num = scans_per_proc
            if n_proc < num_extra:
                num += 1
            spec_config = {
                'station': map_config.station,
                'experiment_type': map_config.experiment_type,
                'spec_scans': [SpecScans(
                    spec_file=spec_scans.spec_file,
                    scan_numbers=scan_numbers[n_scan:n_scan+num]).__dict__]}
            sub_pipeline_config = []
            for item in deepcopy(sub_pipeline.get('pipeline', [])):
                if isinstance(item, dict):
                    for k, v in deepcopy(item).items():
                        if k.endswith('Reader'):
                            v['config'] = spec_config
                            item[k] = v
                        if num_proc > 1 and k.endswith('Writer'):
                            # Give each processor its own output file
                            r, e = os.path.splitext(v['filename'])
                            v['filename'] = f'{r}_{n_proc}{e}'
                            item[k] = v
                sub_pipeline_config.append(item)
            if collect_on_root and (not root_as_worker or num_proc > 1):
                sub_pipeline_config += [
                    {'common.MPICollectProcessor': {
                        'root_as_worker': root_as_worker}}]
            pipeline_config.append(sub_pipeline_config)
            n_scan += num

        # Spawn the workers to run the sub-pipeline
        if num_proc > first_proc:
            # System modules
            from tempfile import NamedTemporaryFile

            tmp_names = []
            with NamedTemporaryFile(delete=False) as fp:
                # pylint: disable=c-extension-no-member
                fp_name = fp.name
                tmp_names.append(fp_name)
                with open(fp_name, 'w', encoding='utf-8') as f:
                    yaml.dump(
                        {'config': {'spawn': run_config.spawn}},
                        f, sort_keys=False)
                # One configuration file per spawned worker
                for n_proc in range(first_proc, num_proc):
                    f_name = f'{fp_name}_{n_proc}'
                    tmp_names.append(f_name)
                    with open(f_name, 'w', encoding='utf-8') as f:
                        yaml.dump(
                            #FIX once comm is a field of RunConfig
                            #{'config': run_config.model_dump(exclude='comm'),
                            {'config': run_config.model_dump(),
                             'pipeline': pipeline_config[n_proc]},
                            f, sort_keys=False)
                # pylint: disable=used-before-assignment
                sub_comm = MPI.COMM_SELF.Spawn(
                    'CHAP', args=[fp_name], maxprocs=num_proc-first_proc)
                common_comm = sub_comm.Merge(False)
                if run_config.spawn > 0:
                    # Align with the barrier in RunConfig() on common_comm
                    # called from the spawned main()
                    common_comm.barrier()
        else:
            common_comm = None

        # Run the sub-pipeline on the root node
        if root_as_worker:
            data = runner(run_config, pipeline_config[0], comm=common_comm)
        elif collect_on_root:
            run_config.spawn = 0
            pipeline_config = [{'common.MPICollectProcessor': {
                'root_as_worker': root_as_worker}}]
            data = runner(run_config, pipeline_config, common_comm)
        else:
            # Align with the barrier in run() on common_comm
            # called from the spawned main()
            common_comm.barrier()
            data = None

        # Disconnect spawned workers and cleanup temporary files
        if num_proc > first_proc:
            # Align with the barrier in main() on common_comm
            # when disconnecting the spawned worker
            common_comm.barrier()
            # Disconnect spawned workers and cleanup temporary files
            sub_comm.Disconnect()
            for tmp_name in tmp_names:
                os.remove(tmp_name)

        return data
class NexusToNumpyProcessor(Processor):
    """A Processor that extracts the default plottable data of a
    NeXus style `NXobject
    <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
    object as a `numpy.ndarray`.
    """
    def process(self, data):
        """Return, as a `numpy.ndarray`, the default plottable data
        signal of the NeXus style `NXobject
        <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
        object contained in `data`.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises ValueError: If `data` has no default plottable data
            signal.
        :return: Default plottable data signal.
        :rtype: numpy.ndarray
        """
        # Third party modules
        from nexusformat.nexus import NXdata

        nxobject = self.get_pipelinedata_item(data)

        # Locate the NXdata group holding the default signal
        if isinstance(nxobject, NXdata):
            nxdata = nxobject
        else:
            nxdata = nxobject.plottable_data
            if nxdata is None:
                # Fall back on the object's 'default' attribute
                nxdata = nxobject.get(nxobject.attrs.get('default'))
            if nxdata is None:
                raise ValueError(
                    f'The structure of {nxobject} contains no default data')

        try:
            signal_name = nxdata.attrs['signal']
        except Exception as exc:
            raise ValueError(
                f'The signal of {nxdata} is unknown') from exc

        return nxdata[signal_name].nxdata
class NexusToXarrayProcessor(Processor):
    """A Processor that extracts the default plottable data of a
    NeXus style `NXobject
    <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
    object as an `xarray.DataArray`.
    """
    def process(self, data):
        """Return, as an `xarray.DataArray`, the default plottable
        data signal of the NeXus style `NXobject
        <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
        object contained in `data`.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises ValueError: If metadata for `xarray` is absent from
            `data`
        :return: Default plottable data signal.
        :rtype: xarray.DataArray
        """
        # Third party modules
        from nexusformat.nexus import NXdata
        # pylint: disable=import-error
        from xarray import DataArray
        # pylint: enable=import-error

        nxobject = self.get_pipelinedata_item(data)

        # Locate the NXdata group holding the default signal
        if isinstance(nxobject, NXdata):
            nxdata = nxobject
        else:
            nxdata = nxobject.plottable_data
            if nxdata is None:
                # Fall back on the object's 'default' attribute
                nxdata = nxobject.get(nxobject.attrs.get('default'))
            if nxdata is None:
                raise ValueError(
                    f'The structure of {nxobject} contains no default data')

        try:
            signal_name = nxdata.attrs['signal']
        except Exception as exc:
            raise ValueError(
                f'The signal of {nxdata} is unknown') from exc
        values = nxdata[signal_name].nxdata

        # A single axis may be given as a plain string
        axis_names = nxdata.attrs['axes']
        if isinstance(axis_names, str):
            axis_names = [axis_names]

        # Each coordinate is a (dimension, values, attributes) tuple
        coords = {}
        for axis_name in axis_names:
            nxaxis = nxdata[axis_name]
            coords[axis_name] = (axis_name, nxaxis.nxdata, nxaxis.attrs)

        return DataArray(
            data=values,
            coords=coords,
            dims=tuple(axis_names),
            name=signal_name,
            attrs=nxdata[signal_name].attrs)
class NexusToZarrProcessor(Processor):
    """Converter for `NeXus <https://www.nexusformat.org>`__ to `Zarr
    <https://zarr.readthedocs.io/en/stable/>`__ format.
    """
    def process(self, data, chunks='auto'):
        """Copy and return a `Zarr group
        <https://zarr.readthedocs.io/en/stable/api/zarr/group/#zarr.Group>`__
        object from a NeXus style `NXgroup
        <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
        object.

        Datasets that cannot be copied are reported through the
        logger and skipped.

        :param data: Input data.
        :type data: list[PipelineData]
        :param chunks: Chunking used for the Zarr arrays, defaults to
            `'auto'`. A single integer is treated as a one-element
            list. A list shorter than the number of dimensions of a
            dataset is completed with the dataset's shape along the
            remaining dimensions, a list that is longer is replaced
            by `'auto'` for that dataset.
        :type chunks: Union[int, list[int], str], optional
        :return: Zarr style group object.
        :rtype: zarr.Group
        """
        # Third party modules
        from nexusformat.nexus import (
            NXfield,
            NXgroup,
        )
        # pylint: disable=import-error
        import zarr
        from zarr.storage import MemoryStore
        # pylint: enable=import-error

        nexus_group = self.get_data(data)
        if isinstance(chunks, int):
            chunks = [chunks]
        zarr_group = zarr.create_group(store=MemoryStore({}))

        def attr_value(nxattr):
            """Return the value of a NeXus attribute, with arrays
            converted to (nested) lists.
            """
            value = nxattr.nxvalue
            if isinstance(value, np.ndarray):
                return value.tolist()
            return value

        def get_chunks(shape):
            """Return the chunking for a dataset of a given shape."""
            if not isinstance(chunks, list):
                return chunks
            if len(chunks) < len(shape):
                return (*chunks, *shape[len(chunks):])
            if len(chunks) > len(shape):
                return 'auto'
            return chunks

        def copy_group(nexus_group, zarr_group):
            """Copy a NeXus style `NXgroup
            <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
            object to a `Zarr group
            <https://zarr.readthedocs.io/en/stable/api/zarr/group/#zarr.Group>`__
            object.

            :param nexus_group: Source NeXus style `NXgroup`.
            :type nexus_group: nexusformat.nexus.NXgroup
            :param zarr_group: Destination Zarr group.
            :type zarr_group: zarr.Group
            """
            self.logger.info(f'Copying {nexus_group.nxpath}')

            # Copy attributes
            for attr_key, nxattr in nexus_group.attrs.items():
                zarr_group.attrs[attr_key] = attr_value(nxattr)

            # Copy datasets and sub-groups
            for key, item in nexus_group.items():
                if isinstance(item, NXfield):
                    # Read the field values only once
                    values = item.nxdata
                    if isinstance(values, np.ndarray):
                        try:
                            # Copy dataset
                            zarr_dset = zarr_group.create_array(
                                name=key,
                                shape=values.shape,
                                dtype=values.dtype,
                                attributes={
                                    k: attr_value(v)
                                    for k, v in item.attrs.items()},
                                chunks=get_chunks(values.shape),
                            )
                            self.logger.info(f'Copying {item.nxpath}')
                            zarr_dset[:] = values
                        except Exception as exc:
                            # Best effort: report and keep on copying
                            self.logger.error(f'{item.nxpath}: {exc}')
                    else:
                        self.logger.warning(f'Ignoring {item.nxpath}')
                elif isinstance(item, NXgroup):
                    # Recursively copy subgroup
                    zarr_subgroup = zarr_group.create_group(key)
                    copy_group(item, zarr_subgroup)

        copy_group(nexus_group, zarr_group)
        return zarr_group
class NormalizeNexusProcessor(Processor):
    """Processor for scaling one or more `NXfield
    <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
    objects in the input NeXus style `NXgroup
    <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
    object by the values of another NXfield in the same object.
    """
    def process(self, data, normalize_nxfields, normalize_by_nxfield):
        """Return copy of the original input NeXus style `NXgroup
        <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
        object with additional fields containing the normalized data
        of each field in `normalize_nxfields`.

        The normalized data of a field at path `<path>` is added at
        path `<path>_normalized`. Fields that are missing, that are
        not an `NXfield` or whose leading dimensions do not match the
        shape of the normalization data are reported through the
        logger and skipped.

        :param data: Input data.
        :type data: list[PipelineData]
        :param normalize_nxfields: Paths in `data` to the `NXfield
            <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            objects to normalize.
        :type normalize_nxfields: list[str]
        :param normalize_by_nxfield: Path in `data` to the `NXfield
            <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            containing normalization data
        :type normalize_by_nxfield: str
        :raises TypeError: If the input data is not an `NXgroup` or
            if `normalize_by_nxfield` is not an `NXfield`.
        :raises ValueError: If `normalize_by_nxfield` is not present
            in the input data.
        :returns: Copy of input data with additional normalized fields
        :rtype: nexusformat.nexus.NXgroup
        """
        # Third party modules
        from nexusformat.nexus import (
            NXgroup,
            NXfield,
        )

        # Local modules
        from CHAP.utils.general import nxcopy

        # Check input data
        data = self.get_pipelinedata_item(data)
        data = nxcopy(data)
        if not isinstance(data, NXgroup):
            raise TypeError(f'Expected NXgroup, got {type(data)}')

        # Check normalize_by_nxfield
        if normalize_by_nxfield not in data:
            raise ValueError(
                f'{normalize_by_nxfield} not present in input data')
        if not isinstance(data[normalize_by_nxfield], NXfield):
            raise TypeError(
                f'{normalize_by_nxfield} is {type(data[normalize_by_nxfield])}'
                + ', expected NXfield')
        normalization_data = data[normalize_by_nxfield].nxdata

        # Process normalize_nxfields
        for nxfield in normalize_nxfields:
            if nxfield not in data:
                self.logger.error(f'{nxfield} not present in input data')
                continue
            if not isinstance(data[nxfield], NXfield):
                self.logger.error(
                    f'{nxfield} is {type(data[nxfield])}, expected NXfield')
                continue
            # Read the field values only once
            field_data = data[nxfield].nxdata
            field_shape = field_data.shape
            if (normalization_data.shape
                    != field_shape[:normalization_data.ndim]):
                self.logger.error(
                    f'Incompatible dataset shapes: {normalize_by_nxfield} '
                    + f'is {normalization_data.shape}, '
                    + f'{nxfield} is {field_shape}'
                )
                continue
            self.logger.info(f'Normalizing {nxfield}')
            # Make shapes compatible: append singleton axes so that
            # the normalization data broadcasts over the trailing
            # dimensions of the field
            _normalization_data = normalization_data.reshape(
                normalization_data.shape
                + (1,) * (field_data.ndim - normalization_data.ndim))
            data[f'{nxfield}_normalized'] = NXfield(
                value=field_data / _normalization_data,
                attrs={**data[nxfield].attrs,
                       'normalized_by': normalize_by_nxfield}
            )
        return data
class NormalizeMapProcessor(Processor):
    """Processor for calling
    :class:`~CHAP.common.processor.NormalizeNexusProcessor` for
    (usually all) detector data in a NeXus style
    `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
    object created by :class:`~CHAP.common.processor.MapProcessor`
    """

    def process(self, data, normalize_by_nxfield, detector_ids=None):
        """Return copy of the original input map with additional
        fields containing normalized detector data.

        The first `NXentry` found in the input is taken as the map
        entry; the fields to normalize are `<entry>/data/<id>` for
        each detector id.

        :param data: Input data.
        :type data: list[PipelineData]
        :param normalize_by_nxfield: Path in `data` to the
            `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            containing normalization data.
        :type normalize_by_nxfield: str
        :param detector_ids: Names of the fields under
            `<entry>/data` to normalize. Defaults to every item in
            `<entry>/data` that is not an `NXlink`.
        :type detector_ids: list, optional
        :raises ValueError: If the input data contains no `NXentry`.
        :returns: Copy of input data with additional normalized
            fields.
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        from nexusformat.nexus import (
            NXentry,
            NXlink,
        )

        # Check input data
        data = self.get_pipelinedata_item(data)
        map_title = next(
            (k for k, v in data.items() if isinstance(v, NXentry)), None)
        if map_title is None:
            # Without a map entry every path built below would be
            # meaningless, so fail here with a clear message
            msg = 'Input data contains no NXentry'
            self.logger.error(msg)
            raise ValueError(msg)
        self.logger.info(f'Got map_title: {map_title}')

        # Check detector_ids
        if detector_ids is None:
            nxdata = data[map_title].data
            detector_ids = [k for k in nxdata.keys()
                            if not isinstance(nxdata[k], NXlink)]
        self.logger.info(f'Using detector_ids: {detector_ids}')
        normalize_nxfields = [
            f'{map_title}/data/{_id}' for _id in detector_ids]

        # Normalize
        normalizer = NormalizeNexusProcessor()
        normalizer.logger = self.logger
        return normalizer.process(
            data, normalize_nxfields, normalize_by_nxfield)
class PandasToXarrayProcessor(Processor):
    """Processor that turns a `pandas.DataFrame` into its
    `xarray.DataArray` or `xarray.Dataset` counterpart.
    """

    def process(self, data):
        """Convert the dataframe obtained from the input data to
        xarray by calling its `to_xarray` method.

        :param data: Input data.
        :type data: list[PipelineData]
        :returns: Input dataframe as xarray.
        :rtype: xarray.DataArray | xarray.Dataset
        """
        return self.get_data(data).to_xarray()
class PrintProcessor(Processor):
    """A Processor that writes the input data to stdout and hands
    the very same input data back, untouched.
    """

    def process(self, data):
        """Print and return the input data.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: `data`
        :rtype: Any
        """
        # An object exposing a callable `_str_tree` is likely a
        # NXobject: print its tree representation, since the str
        # representation of a NXobject is just its nxname
        if not callable(getattr(data, '_str_tree', None)):
            # System modules
            from pprint import pprint

            pprint(data)
        else:
            print(data._str_tree(attrs=True, recursive=True))

        return data
class PyfaiAzimuthalIntegrationProcessor(Processor):
    """Processor to azimuthally integrate one or more frames of 2d
    detector data using the
    `pyFAI <https://pyfai.readthedocs.io/en/stable>`__ package.
    """

    def process(
            self, data, poni_file, npt, mask_file=None,
            integrate1d_kwargs=None):
        """Azimuthally integrate every frame of the detector data
        provided and return the integration results, one per frame.

        Relative `poni_file` and `mask_file` names are resolved
        against `self.inputdir`.

        :param data: Detector data to integrate.
        :type data: PipelineData | list[np.ndarray]
        :param poni_file: Name of the pyFAI PONI file containing the
            detector properties pyFAI needs to perform azimuthal
            integration.
        :type poni_file: str
        :param npt: Number of points in the output pattern.
        :type npt: int
        :param mask_file: File to use for masking the input data.
        :type mask_file: str, optional
        :param integrate1d_kwargs: Optional dictionary of keyword
            arguments passed on to `integrate1d`. A `'mask'` item, if
            present, is replaced by the mask loaded from `mask_file`
            (or `None`). The dictionary itself is not modified.
        :type integrate1d_kwargs: Optional[dict]
        :returns: List with the value returned by `integrate1d` for
            each frame, in input order.
        :rtype: list
        """
        # Third party modules
        from pyFAI import load

        if not os.path.isabs(poni_file):
            poni_file = os.path.join(self.inputdir, poni_file)
        ai = load(poni_file)

        if mask_file is None:
            mask = None
        else:
            # Third party modules
            import fabio

            if not os.path.isabs(mask_file):
                mask_file = os.path.join(self.inputdir, mask_file)
            mask = fabio.open(mask_file).data

        # Fall back on the raw input if it is not pipeline data
        try:
            det_data = self.get_pipelinedata_item(data)
        except ValueError:
            det_data = data

        # Build a fresh dictionary so the caller's keyword dictionary
        # is never altered by adding the mask
        kwargs = {} if integrate1d_kwargs is None else dict(
            integrate1d_kwargs)
        kwargs['mask'] = mask

        return [ai.integrate1d(d, npt, **kwargs) for d in det_data]
class RawDetectorDataMapProcessor(Processor):
    """A Processor to return a map of raw detector data in a
    NeXus style
    `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
    object.

    NOTE: :meth:`get_nxroot` currently raises `RuntimeError`
    unconditionally, so :meth:`process` cannot complete.
    """

    def process(self, data, detector_name, detector_shape):
        """Process configurations for a map and return the raw
        detector data collected over the map.

        :param data: Input data.
        :type data: list[PipelineData]
        :param detector_name: Detector prefix.
        :type detector_name: str
        :param detector_shape: Detector data shape for a single scan
            step.
        :type detector_shape: list
        :raises ValueError: If no map configuration is found in
            `data` (see :meth:`get_config`).
        :raises RuntimeError: Always, as raised by
            :meth:`get_nxroot` in its current state.
        :return: Map of raw detector data.
        :rtype: nexusformat.nexus.NXroot
        """
        map_config = self.get_config(data)
        nxroot = self.get_nxroot(map_config, detector_name, detector_shape)

        return nxroot

    def get_config(self, data):
        """Get an instance of the map configuration object needed by
        this `Processor`.

        `data` is scanned for dictionary items whose `'schema'` value
        is `'common.models.map.MapConfig'`; the loop does not stop at
        the first hit, so the `'data'` value of the last matching
        item is the one used.

        :param data: Result of `Reader.read` where at least one item
            has the value `'common.models.map.MapConfig'` for the
            `'schema'` key.
        :type data: list[PipelineData]
        :raises ValueError: If `data` is not a list, has no matching
            item, or the matching item's `'data'` value is empty.
        :return: Valid instance of the map configuration object with
            field values taken from `data`.
        :rtype: MapConfig
        """
        # False doubles as the "nothing found" marker tested below
        map_config = False
        if isinstance(data, list):
            for item in data:
                if isinstance(item, dict):
                    if item.get('schema') == 'common.models.map.MapConfig':
                        map_config = item.get('data')

        if not map_config:
            raise ValueError('No map configuration found in input data')

        return MapConfig(**map_config)

    def get_nxroot(self, map_config, detector_name, detector_shape):
        """Get a map of the detector data collected by the scans in
        `map_config`. The data will be returned along with some
        relevant metadata in the form of a NeXus style
        `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
        structure.

        NOTE: this method is currently disabled. It raises
        `RuntimeError` right after its imports and everything below
        the `raise` statement is unreachable.

        :param map_config: Map configuration.
        :type map_config: MapConfig
        :param detector_name: Detector prefix.
        :type detector_name: str
        :param detector_shape: Detector data shape for a single scan
            step.
        :type detector_shape: list
        :raises RuntimeError: Always.
        :return: Map of the raw detector data (never reached at
            present).
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        # pylint: disable=no-name-in-module
        from nexusformat.nexus import (
            NXdata,
            NXdetector,
            NXinstrument,
            NXroot,
        )
        # pylint: enable=no-name-in-module

        # Everything after this statement is unreachable
        raise RuntimeError('Not updated for the new MapProcessor')
        nxroot = NXroot()

        # NOTE(review): MapProcessor is not imported in this block;
        # presumably it is defined elsewhere in this module -- confirm
        # before re-enabling this code
        nxroot[map_config.title] = MapProcessor.get_nxentry(map_config)
        nxentry = nxroot[map_config.title]

        # Container for the raw data: map shape followed by the
        # per-step detector shape
        nxentry.instrument = NXinstrument()
        nxentry.instrument.detector = NXdetector()
        nxentry.instrument.detector.data = NXdata()
        nxdata = nxentry.instrument.detector.data
        nxdata.raw = np.empty((*map_config.shape, *detector_shape))
        nxdata.raw.attrs['units'] = 'counts'
        for i, det_axis_size in enumerate(detector_shape):
            nxdata[f'detector_axis_{i}_index'] = np.arange(det_axis_size)

        # Fill in the detector data one map point at a time
        for map_index in np.ndindex(map_config.shape):
            scans, scan_number, scan_step_index = \
                map_config.get_scan_step_index(map_index)
            scanparser = scans.get_scanparser(scan_number)
            self.logger.debug(
                f'Adding data to nxroot for map point {map_index}')
            nxdata.raw[map_index] = scanparser.get_detector_data(
                detector_name,
                scan_step_index)

        # Link the raw data and its index axes into the entry's data
        # group and extend that group's axes attribute
        nxentry.data.makelink(
            nxdata.raw,
            name=detector_name)
        for i, det_axis_size in enumerate(detector_shape):
            nxentry.data.makelink(
                nxdata[f'detector_axis_{i}_index'],
                name=f'{detector_name}_axis_{i}_index'
            )
            if isinstance(nxentry.data.attrs['axes'], str):
                nxentry.data.attrs['axes'] = [
                    nxentry.data.attrs['axes'],
                    f'{detector_name}_axis_{i}_index']
            else:
                nxentry.data.attrs['axes'] += [
                    f'{detector_name}_axis_{i}_index']

        nxentry.data.attrs['signal'] = detector_name

        return nxroot
class SetupNXdataProcessor(Processor):
    """Processor to set up and return an "empty" representation of a
    structured dataset. This representation will be an instance of a
    NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object that has:

    1. A NeXus style
       `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
       entry for every coordinate/signal specified.
    2. `nxaxes` that are the `NXfield` entries for the coordinates
       and contain the values provided for each coordinate.
    3. `NXfield` entries of appropriate shape, but containing all
       zeros, for every signal.
    4. Attributes that define the axes, plus any additional
       attributes specified by the user.

    This `Processor` is most useful as a "setup" step for
    constructing a representation of / container for a complete
    dataset that will be filled out in pieces later by
    :class:`~CHAP.common.processor.UpdateNXdataProcessor`.
    """

    def process(
            self, data, nxname='data', coords=None, signals=None,
            attrs=None, data_points=None, extra_nxfields=None,
            duplicates='overwrite'):
        """Return a NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        object that has the requisite axes and
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        entries to represent a structured dataset with the
        properties provided. Properties may be provided either
        through the `data` argument (from an appropriate
        `PipelineItem` that immediately precedes this one in a
        `Pipeline`), or through the `coords`, `signals`, `attrs`,
        and/or `data_points` arguments. If any of the latter are
        used (non-empty), their values completely override any
        values for these parameters found from `data`.

        :param data: Input data.
        :type data: list[PipelineData]
        :param nxname: Name for the returned NXdata object,
            defaults to `'data'`.
        :type nxname: str, optional
        :param coords: List of dictionaries defining the coordinates
            of the dataset. Each dictionary must have the keys
            `'name'` and `'values'`, whose values are the name of
            the coordinate axis (a string) and all the unique values
            of that coordinate for the structured dataset (a list of
            numbers), respectively. A third item in the dictionary
            is optional, but highly recommended: `'attrs'` may
            provide a dictionary of attributes to attach to the
            coordinate axis that assist in interpreting the returned
            NXdata representation of the dataset. It is strongly
            recommended to provide the units of the values along an
            axis in the `attrs` dictionary.
        :type coords: list[dict[str, object]], optional
        :param signals: List of dictionaries defining the signals of
            the dataset. Each dictionary must have the keys `'name'`
            and `'shape'`, whose values are the name of the signal
            field (a string) and the shape of the signal's value at
            each point in the dataset (a list of zero or more
            integers), respectively. A third item in the dictionary
            is optional, but highly recommended: `'attrs'` may
            provide a dictionary of attributes to attach to the
            signal field that assist in interpreting the returned
            NXdata representation of the dataset. It is strongly
            recommended to provide the units of the signal's values
            in the `attrs` dictionary.
        :type signals: list[dict[str, object]], optional
        :param attrs: An arbitrary dictionary of attributes to
            assign to the returned NXdata object.
        :type attrs: dict[str, object], optional
        :param data_points: Data points to partially (or even
            entirely) fill out the "empty" signal NXfields before
            returning the NXdata object.
        :type data_points: list[dict[str, object]], optional
        :param extra_nxfields: List of "extra" NXfields to include
            that can be described neither as a signal of the
            dataset, nor a dedicated coordinate. This parameter is
            good for including "alternate" values for one of the
            coordinate dimensions -- the same coordinate axis
            expressed in different units, for instance. Each item in
            the list should be a dictionary of parameters for the
            `nexusformat.nexus.NXfield` constructor.
        :type extra_nxfields: list[dict[str, object]], optional
        :param duplicates: Behavior to use if any new data points
            occur at the same point in the dataset's coordinate
            space as an existing data point. Allowed values for
            `duplicates` are: `'overwrite'` and `'block'`. Defaults
            to `'overwrite'`. The value is stored on the instance;
            the methods of this class do not consult it.
        :type duplicates: Literal['overwrite', 'block']
        :returns: Structured dataset as specified.
        :rtype: nexusformat.nexus.NXdata
        """
        self.nxname = nxname

        self.coords = [] if coords is None else coords
        self.signals = [] if signals is None else signals
        self.attrs = {} if attrs is None else attrs
        self.data_points = data_points
        if extra_nxfields is None:
            extra_nxfields = []

        # Fall back on setup parameters from the pipeline for any
        # item that was not supplied (or supplied empty) by keyword
        try:
            setup_params = self.get_pipelinedata_item(data)
        except Exception:
            setup_params = None
        if isinstance(setup_params, dict):
            for a in ('coords', 'signals', 'attrs', 'data_points'):
                setup_param = setup_params.get(a)
                if not getattr(self, a) and setup_param is not None:
                    self.logger.info(f'Using input data from pipeline for {a}')
                    setattr(self, a, setup_param)
                else:
                    self.logger.info(
                        f'Ignoring input data from pipeline for {a}')
        else:
            self.logger.warning('Ignoring all input data from pipeline')

        # One dimension per coordinate axis
        self.shape = tuple(len(c['values']) for c in self.coords)

        self.extra_nxfields = extra_nxfields
        self.duplicates = duplicates
        self.init_nxdata()

        if self.data_points is not None:
            for d in self.data_points:
                self.add_data_point(d)

        return self.nxdata

    def add_data_point(self, data_point):
        """Add a data point to this dataset.

        1. Validate `data_point`.
        2. If valid, update the signal NXfields in `self.nxdata`;
           otherwise log an error and leave the dataset unchanged.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        """
        # The point counter is informational only: do not let a
        # missing 'dataset_point_index' key abort before validation
        point_index = data_point.get('dataset_point_index')
        if point_index is None:
            self.logger.info('Adding data point')
        else:
            self.logger.info(
                f'Adding data point no. {point_index+1} of '
                f'{len(self.data_points)}')
        self.logger.debug(f'New data point: {data_point}')
        valid, msg = self.validate_data_point(data_point)
        if not valid:
            self.logger.error(f'Cannot add data point: {msg}')
        else:
            self.update_nxdata(data_point)

    def validate_data_point(self, data_point):
        """Return `True` if `data_point` provides a value for every
        coordinate of this structured dataset and every signal value
        it contains has the expected shape, `False` otherwise.

        The checks are made on a converted copy of `data_point`;
        the input dictionary is not modified.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        :returns: Validity of `data_point`, message
        :rtype: bool, str
        """
        valid = True
        msg = ''
        # Convert all values to numpy types
        data_point = {k: np.asarray(v) for k, v in data_point.items()}
        # Ensure data_point defines a specific point in the dataset's
        # coordinate space
        if not all(c['name'] in data_point for c in self.coords):
            valid = False
            msg = 'Missing coordinate values'
        # Check the shape of every signal value that is present
        for s in self.signals:
            name = s['name']
            if name not in data_point:
                # Zero-filled placeholder in the local copy only
                data_point[name] = np.full(s['shape'], 0)
            elif data_point[name].shape != tuple(s['shape']):
                valid = False
                msg = f'Shape mismatch for signal {s}'
        return valid, msg

    def init_nxdata(self):
        """Initialize an empty NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html>`__
        representing this dataset to `self.nxdata`; values for axes
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        objects are filled out, values for the signals' NXfields are
        all zeros and can be filled out later.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        axes = tuple(NXfield(
            value=c['values'],
            name=c['name'],
            attrs=c.get('attrs'),
            dtype=c.get('dtype', 'float64')) for c in self.coords)
        entries = {s['name']: NXfield(
            value=np.full((*self.shape, *s['shape']), 0),
            name=s['name'],
            attrs=s.get('attrs'),
            dtype=s.get('dtype', 'float64')) for s in self.signals}
        extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
        extra_nxfields = {f.nxname: f for f in extra_nxfields}
        entries.update(extra_nxfields)
        self.nxdata = NXdata(
            name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)

    def update_nxdata(self, data_point):
        """Update `self.nxdata`'s
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        values for every signal present in `data_point`.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        """
        index = self.get_index(data_point)
        for s in self.signals:
            name = s['name']
            if name in data_point:
                self.nxdata[name][index] = data_point[name]

    def get_index(self, data_point):
        """Return a tuple representing the array index of
        `data_point` in the coordinate space of the dataset.

        :param data_point: Data point defining a point in the
            dataset's coordinate space.
        :type data_point: dict[str, object]
        :returns: Multi-dimensional index of `data_point` in the
            dataset's coordinate space.
        :rtype: tuple
        """
        return tuple(c['values'].index(data_point[c['name']])
                     for c in self.coords)
class UnstructuredToStructuredProcessor(Processor):
    """Processor to reshape data in a NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object from an "unstructured" to a "structured" representation.
    """

    def process(self, data, nxpath=None):
        """Reshape the input data from an "unstructured" to a
        "structured" representation.

        If the input object is itself an `NXdata`, it is converted
        directly and `nxpath` is ignored. Otherwise `nxpath` is used
        to look up the object to convert inside the input object.

        :param data: Input data.
        :type data: list[PipelineData]
        :param nxpath: Path, within the input object, to the
            `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
            object to convert. Required when the input object is not
            an `NXdata` itself.
        :type nxpath: str, optional
        :raises ValueError: If `nxpath` cannot be resolved in the
            input object, or if the input is not an `NXdata` and no
            `nxpath` is given.
        :return: Converted data.
        :rtype: nexusformat.nexus.NXdata
        """
        # Third party modules
        from nexusformat.nexus import NXdata

        try:
            nxobject = self.get_data(data)
        except Exception:
            nxobject = self.get_pipelinedata_item(data)
        if isinstance(nxobject, NXdata):
            return self._convert_nxdata(nxobject)
        if nxpath is not None:
            try:
                nxobject = nxobject[nxpath]
            except Exception as exc:
                raise ValueError(
                    f'Invalid parameter nxpath ({nxpath})') from exc
        else:
            raise ValueError(f'Invalid input data ({data})')
        return self._convert_nxdata(nxobject)

    def _convert_nxdata(self, nxdata):
        """Convert a NeXus style `NXdata` object from an
        "unstructured" to a "structured" representation.

        The fields of `nxdata` whose first dimension equals the
        number of unstructured points are treated as signals. A new
        `NXdata` named `<nxname>_structured` is created with the
        sorted unique values of each axis as coordinates, and each
        unstructured point is copied to the grid position given by
        its coordinate values.

        :param nxdata: Unstructured dataset.
        :type nxdata: nexusformat.nexus.NXdata
        :raises ValueError: If an axis is missing or is not an
            `NXfield`, if the axes or the signal have an invalid
            shape, or if the unstructured axes cannot be determined.
        :return: Structured dataset.
        :rtype: nexusformat.nexus.NXdata
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        # Local modules
        from CHAP.common.map_utils import get_axes

        # Extract axes from the NXdata attributes
        axes = get_axes(nxdata)
        for a in axes:
            if a not in nxdata:
                raise ValueError(f'Missing coordinates for {a}')

        # Check the independent dimensions and axes: the first 1D
        # axis fixes the number of unstructured points, later 1D axes
        # are kept only if they have that same size
        unstructured_axes = []
        unstructured_dim = None
        for a in axes:
            if not isinstance(nxdata[a], NXfield):
                raise ValueError(
                    f'Invalid axis field type ({type(nxdata[a])})')
            if len(nxdata[a].shape) == 1:
                if not unstructured_axes:
                    unstructured_axes.append(a)
                    unstructured_dim = nxdata[a].size
                else:
                    if nxdata[a].size == unstructured_dim:
                        unstructured_axes.append(a)
                    elif 'unstructured_axes' in nxdata.attrs:
                        raise ValueError('Inconsistent axes dimensions')
            elif 'unstructured_axes' in nxdata.attrs:
                raise ValueError(
                    f'Invalid unstructered axis shape ({nxdata[a].shape})')
        # Without any declared axes, take the first dimension of the
        # signal as the number of points and use every 1D field of
        # that length as an axis
        if not axes and hasattr(nxdata, 'signal'):
            if len(nxdata[nxdata.signal].shape) < 2:
                raise ValueError(
                    f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
            unstructured_dim = nxdata[nxdata.signal].shape[0]
            for k, v in nxdata.items():
                if (isinstance(v, NXfield) and len(v.shape) == 1
                        and v.shape[0] == unstructured_dim):
                    unstructured_axes.append(k)
        if unstructured_dim is None:
            raise ValueError('Unable to determine the unstructered axes')
        axes = unstructured_axes

        # Identify unique coordinate points for each axis
        unique_coords = {}
        # NOTE(review): `coords` is filled below but not read again
        # in this method
        coords = {}
        axes_attrs = {}
        for a in axes:
            coords[a] = nxdata[a].nxdata
            unique_coords[a] = np.sort(np.unique(nxdata[a].nxdata))
            axes_attrs[a] = deepcopy(nxdata[a].attrs)
            if 'target' in axes_attrs[a]:
                del axes_attrs[a]['target']

        # Calculate the total number of unique coordinate points
        unique_npts = np.prod([len(v) for k, v in unique_coords.items()])
        if unique_npts != unstructured_dim:
            self.logger.warning('The unstructered grid does not fully map to '
                                'a structered one (there are missing points)')

        # Identify the signals and the data point axes
        signals = []
        data_point_axes = []
        data_point_shape = []
        if hasattr(nxdata, 'signal'):
            if (len(nxdata[nxdata.signal].shape) < 2
                    or nxdata[nxdata.signal].shape[0] != unstructured_dim):
                raise ValueError(
                    f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
            signals = [nxdata.signal]
            data_point_shape = [nxdata[nxdata.signal].shape[1:]]
        # Every other non-axis field with the unstructured dimension
        # first is treated as a signal too
        for k, v in nxdata.items():
            if (isinstance(v, NXfield) and k not in axes
                    and k not in signals
                    and v.shape[0] == unstructured_dim):
                signals.append(k)
                if not data_point_shape:
                    data_point_shape.append(v.shape[1:])
        if len(data_point_shape) == 1:
            data_point_shape = data_point_shape[0]
        else:
            data_point_shape = []
        # Non-axis fields whose shape equals the per-point signal
        # shape are collected as data point axes (the outer loop runs
        # once per dimension of that shape)
        for _ in data_point_shape:
            for k, v in nxdata.items():
                if (isinstance(v, NXfield) and k not in axes
                        and v.shape == data_point_shape):
                    data_point_axes.append(k)

        # Create the structured NXdata object
        structured_shape = tuple(len(unique_coords[a]) for a in axes)
        attrs = deepcopy(nxdata.attrs)
        if 'unstructured_axes' in attrs:
            attrs.pop('unstructured_axes')
        attrs['axes'] = axes
        nxdata_structured = NXdata(
            name=f'{nxdata.nxname}_structured',
            **{a: NXfield(
                value=unique_coords[a],
                attrs=axes_attrs[a])
               for a in axes},
            **{s: NXfield(
#                value=np.reshape( # FIX not always a sound way to reshape.
#                    nxdata[s], (*structured_shape, *nxdata[s].shape[1:])),
                dtype=nxdata[s].dtype,
                shape=(*structured_shape, *nxdata[s].shape[1:]),
                attrs=nxdata[s].attrs)
               for s in signals},
            attrs=attrs)
        if len(data_point_axes) == 1:
            # Note that `axes` is rebound here to the value read back
            # from the new object's attributes
            axes = nxdata_structured.attrs['axes']
            if isinstance(axes, str):
                axes = [axes]
            nxdata_structured.attrs['axes'] = axes + data_point_axes
        for a in data_point_axes:
            nxdata_structured[a] = NXfield(
                value=nxdata[a], attrs=nxdata[a].attrs)

        # Populate the structured NXdata object with values: each
        # unstructured point goes to the position of its coordinate
        # values within the sorted unique values of each axis
        for i, coord in enumerate(zip(*tuple(nxdata[a].nxdata for a in axes))):
            structured_index = tuple(
                np.asarray(
                    coord[ii] == unique_coords[axes[ii]]).nonzero()[0][0]
                for ii in range(len(axes)))
            for s in signals:
                nxdata_structured[s][structured_index] = nxdata[s][i]

        return nxdata_structured
class UpdateNXvalueProcessor(Processor):
    """Processor to fill in part(s) of an object representing a
    structured dataset that's already been written to a
    `NeXus <https://www.nexusformat.org>`__ file.

    This Processor is most useful as an "update" step for a NeXus
    style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object created by
    :class:`~CHAP.common.processor.SetupNXdataProcessor`, and is most
    easy to use in a `Pipeline` immediately after another
    `PipelineItem` designed specifically to return a value that can
    be used as input to this `Processor`.
    """

    def process(self, data, nxfilename, data_points=None):
        """Write new data values to an existing object representing
        a dataset in a `NeXus <https://www.nexusformat.org>`__ file.
        Return the list of data points used to update the dataset.

        A data point that cannot be written is reported with a
        logged error and does not stop the remaining updates. The
        file is closed even if an unexpected error occurs.

        :param data: Data from the previous item in a `Pipeline`.
            May contain a list of data points that will extend the
            list of data points optionally provided with the
            `data_points` argument.
        :type data: list[PipelineData]
        :param nxfilename: Name of the NeXus file containing the
            object to update. A relative name is resolved against
            `self.inputdir`.
        :type nxfilename: str
        :param data_points: List of data points, each one a
            dictionary with the keys `'nxpath'` (path of the field
            to update), `'index'` (index of the data point in the
            dataset) and `'value'` (the data value). The list passed
            in is not modified.
        :type data_points: Optional[list[dict[str, object]]]
        :returns: Complete list of data points used to update the
            dataset.
        :rtype: list[dict[str, object]]
        """
        # Third party modules
        from nexusformat.nexus import NXFile

        # Copy so that the caller's list is never extended in place
        data_points = [] if data_points is None else list(data_points)
        self.logger.debug(f'Got {len(data_points)} data points from keyword')
        ddata_points = self.get_pipelinedata_item(data)
        if isinstance(ddata_points, list):
            self.logger.debug(f'Got {len(ddata_points)} from pipeline data')
            data_points.extend(ddata_points)
        self.logger.info(f'Updating a total of {len(data_points)} data points')

        if not os.path.isabs(nxfilename):
            nxfilename = os.path.join(self.inputdir, nxfilename)
        nxfile = NXFile(nxfilename, 'rw')
        try:
            for data_point in data_points:
                try:
                    nxfile.writevalue(
                        data_point['nxpath'],
                        np.asarray(data_point['value']),
                        data_point['index'])
                except Exception as exc:
                    # Look the keys up defensively: a missing key may
                    # be the very reason for ending up here
                    if isinstance(data_point, dict):
                        nxpath = data_point.get('nxpath')
                        index = data_point.get('index')
                    else:
                        nxpath = index = None
                    self.logger.error(
                        f'Error updating {nxpath} for data point '
                        f'{index}: {exc}')
                else:
                    self.logger.debug(f'Updated data point {data_point}')
        finally:
            nxfile.close()

        return data_points
class UpdateNXdataProcessor(Processor):
    """Processor to fill in part(s) of a NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    representing a structured dataset that's already been written to
    a `NeXus <https://www.nexusformat.org>`__ file.

    This Processor is most useful as an "update" step for a `NXdata`
    object created by
    :class:`~CHAP.common.processor.SetupNXdataProcessor`, and is most
    easy to use in a `Pipeline` immediately after another
    `PipelineItem` designed specifically to return a value that can
    be used as input to this `Processor`.
    """

    def process(
            self, data, nxfilename, nxdata_path, data_points=None,
            allow_approximate_coordinates=False):
        """Write new data points to the signal fields of an existing
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        object representing a structured dataset in a
        `NeXus <https://www.nexusformat.org>`__ file. Return the
        list of data points used to update the dataset.

        A data point that lacks a value for one of the axes, or
        whose position in the dataset cannot be determined, is
        skipped with a logged error. The file is closed even if an
        unexpected error occurs.

        :param data: Data from the previous item in a `Pipeline`.
            May contain a list of data points that will extend the
            list of data points optionally provided with the
            `data_points` argument.
        :type data: list[PipelineData]
        :param nxfilename: Name of the NeXus file containing the
            NXdata object to update. A relative name is resolved
            against `self.inputdir`.
        :type nxfilename: str
        :param nxdata_path: Path to the NXdata object to update in
            the file.
        :type nxdata_path: str
        :param data_points: List of data points, each one a
            dictionary whose keys are the names of the coordinates
            and signals, and whose values are the values of each
            coordinate / signal at a single point in the dataset.
            Defaults to None. The list passed in is not modified.
        :type data_points: Optional[list[dict[str, object]]]
        :param allow_approximate_coordinates: Parameter to allow the
            nearest existing match for the new data points'
            coordinates to be used if an exact match cannot be found
            (sometimes this is due simply to differences in rounding
            conventions). Defaults to False.
        :type allow_approximate_coordinates: bool, optional
        :returns: Complete list of data points used to update the
            dataset.
        :rtype: list[dict[str, object]]
        """
        # Third party modules
        from nexusformat.nexus import NXFile

        # Copy so that the caller's list is never extended in place
        data_points = [] if data_points is None else list(data_points)
        self.logger.debug(f'Got {len(data_points)} data points from keyword')
        _data_points = self.get_pipelinedata_item(data)
        if isinstance(_data_points, list):
            self.logger.debug(f'Got {len(_data_points)} from pipeline data')
            data_points.extend(_data_points)
        self.logger.info(f'Updating {len(data_points)} data points total')

        if not os.path.isabs(nxfilename):
            nxfilename = os.path.join(self.inputdir, nxfilename)
        nxfile = NXFile(nxfilename, 'rw')
        try:
            nxdata = nxfile.readfile()[nxdata_path]
            axes_names = [a.nxname for a in nxdata.nxaxes]

            data_points_used = []
            for i, d in enumerate(data_points):
                # Verify that the data point contains a value for all
                # coordinates in the dataset.
                if not all(a in d for a in axes_names):
                    self.logger.error(
                        f'Data point {i} is missing a value for at least one '
                        f'axis. Skipping. Axes are: {", ".join(axes_names)}')
                    continue
                self.logger.debug(
                    f'Coordinates for data point {i}: '
                    + ', '.join([f'{a}={d[a]}' for a in axes_names]))
                # Get the index of the data point in the dataset
                # based on its values for each coordinate.
                try:
                    index = tuple(np.where(a.nxdata == d[a.nxname])[0][0]
                                  for a in nxdata.nxaxes)
                except Exception:
                    if not allow_approximate_coordinates:
                        self.logger.error(
                            f'Cannot get the index of data point {i}. '
                            'Skipping.')
                        continue
                    # No exact match: use the nearest existing value
                    # along every axis instead
                    try:
                        index = tuple(
                            np.argmin(np.abs(a.nxdata - d[a.nxname]))
                            for a in nxdata.nxaxes)
                        self.logger.warning(
                            'Nearest match for coordinates of data '
                            f'point {i}: '
                            + ', '.join(
                                f'{a.nxname}={a[_i]}'
                                for _i, a in zip(index, nxdata.nxaxes)))
                    except Exception:
                        self.logger.error(
                            f'Cannot get the index of data point {i}. '
                            'Skipping.')
                        continue
                self.logger.debug(f'Index of data point {i}: {index}')
                # Update the signals contained in this data point at
                # the proper index in the dataset's signal NXfields
                for k, v in d.items():
                    if k in axes_names:
                        continue
                    try:
                        nxfile.writevalue(
                            os.path.join(nxdata_path, k), np.asarray(v),
                            index)
                    except Exception as exc:
                        self.logger.error(
                            f'Error updating signal {k} for new data point '
                            f'{i} (dataset index {index}): {exc}')
                data_points_used.append(d)
        finally:
            nxfile.close()

        return data_points_used
class NumpyStackProcessor(Processor):
    """Processor for joining a sequence of arrays along a new axis.
    Uses
    `numpy.stack <https://numpy.org/doc/stable/reference/generated/numpy.stack.html>`__.

    :ivar stack_order: List of names of input data arrays to
        determine order of stacking. If not specified, data arrays
        are stacked in the exact order input to the Processor.
        Defaults to `None`.
    :vartype stack_order: list[str], optional
    :ivar kwargs: Dictionary of keyword arguments to
        `numpy.stack <https://numpy.org/doc/stable/reference/generated/numpy.stack.html>`__;
        defaults to `{}`.
    :vartype kwargs: dict, optional
    """
    stack_order: Optional[conlist(item_type=str, min_length=1)] = None
    kwargs: Optional[dict] = {}

    def process(self, data):
        """Stack the input data arrays along a new axis.

        Without `stack_order`, every input item whose `'data'` value
        can be converted to an array is stacked in input order; items
        that cannot be converted are omitted with a logged warning.
        With `stack_order`, the named items are looked up in the
        input data and stacked in the order given.

        :param data: Input data.
        :type data: list[PipelineData]
        :returns: Stacked array.
        :rtype: numpy.ndarray
        """
        if self.stack_order is None:
            arrays = []
            for d in data:
                try:
                    arrays.append(np.asarray(d['data']))
                except Exception:
                    # Describe the item defensively so that building
                    # the warning cannot raise in turn
                    if isinstance(d, dict):
                        name, value = d.get('name'), d.get('data')
                    else:
                        name, value = None, d
                    self.logger.warning(
                        f'Omitting input data {name} '
                        + f'(type: {type(value)}).'
                    )
        else:
            # The named items must be looked up in the input data
            arrays = [self.get_data(data, name=name)
                      for name in self.stack_order]
        # `kwargs` is optional and may have been set to None
        return np.stack(arrays, **(self.kwargs or {}))
class NumpySumProcessor(Processor):
    """Processor for summing an array of elements over a given axis.
    Uses
    `numpy.sum <https://numpy.org/doc/stable/reference/generated/numpy.sum.html>`__.

    :ivar kwargs: Dictionary of keyword arguments to
        `numpy.sum <https://numpy.org/doc/stable/reference/generated/numpy.sum.html>`__;
        defaults to `{}`.
    :vartype kwargs: dict, optional
    """
    kwargs: Optional[dict] = {}

    def process(self, data):
        """Sum the last array-like item of the input data.

        The input items are searched from last to first for one
        whose `'data'` value can be converted to an array.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises TypeError: If no array-like input data is found.
        :returns: Result of `numpy.sum` for the selected array.
        :rtype: numpy.ndarray | numpy.generic
        """
        _data = None
        for d in data[::-1]:
            try:
                _data = np.asarray(d['data'])
            except Exception:
                continue
            break
        if _data is None:
            err = 'No array-like input data found.'
            self.logger.error(err)
            raise TypeError(err)
        self.logger.debug(f'_data.shape = {_data.shape}')
        # `kwargs` is optional and may have been set to None
        return np.sum(_data, **(self.kwargs or {}))
class NumpyToNXfieldProcessor(Processor):
    """Processor for converting a numpy array into an `NXfield`.

    :ivar value: Name of input data array to use as the field's
        values. If unspecified, use the last array-like data object
        in the input data list. Defaults to `None`.
    :vartype value: str, optional
    :ivar kwargs: Dictionary of keyword arguments to
        `nexusformat.nexus.tree.NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__;
        defaults to `{}`.
    :vartype kwargs: dict, optional
    """
    value: Optional[str] = None
    kwargs: Optional[dict] = {}

    def process(self, data):
        """Return an `NXfield` holding the selected input array.

        With `value` set, the input item of that name is used.
        Otherwise the input items are searched from last to first
        for one whose `'data'` value can be converted to an array.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises TypeError: If `value` is not set and no array-like
            input data is found.
        :returns: Field with the selected array as its value.
        :rtype: nexusformat.nexus.NXfield
        """
        # Third party modules
        from nexusformat.nexus import NXfield

        if self.value is None:
            _data = None
            for d in data[::-1]:
                try:
                    _data = np.asarray(d['data'])
                except Exception:
                    continue
                # Log outside of the try block so that an item
                # without a name cannot discard the array just found
                name = d.get('name') if hasattr(d, 'get') else None
                self.logger.debug(f'Using {name}')
                break
            if _data is None:
                err = 'No array-like input data found.'
                self.logger.error(err)
                raise TypeError(err)
        else:
            _data = self.get_data(data, name=self.value)
        self.logger.debug(f'_data.shape = {_data.shape}')
        # `kwargs` is optional and may have been set to None
        return NXfield(value=_data, **(self.kwargs or {}))
class NXdataToDataPointsProcessor(Processor):
    """Flatten a NeXus style `NXdata
    <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object into a list of dictionaries, one per point in the
    coordinate space spanned by the dataset axes. Each dictionary is
    keyed by the axes and signal names: an axis entry holds that
    axis' value at the point, a signal entry holds the signal's value
    at the point (which may itself be an array of any shape when the
    signal has more dimensions than there are axes).
    """

    def process(self, data):
        """Return the coordinate and signal values of every point in
        the provided dataset.

        :param data: Input data.
        :type data: list[PipelineData]
        :returns: List of all data points in the dataset.
        :rtype: list[dict[str,object]]
        """
        nxdata = self.get_pipelinedata_item(data)

        axes_names = [axis.nxname for axis in nxdata.nxaxes]
        self.logger.info(f'Dataset axes: {axes_names}')
        dataset_shape = tuple(axis.size for axis in nxdata.nxaxes)
        self.logger.info(f'Dataset shape: {dataset_shape}')

        # A non-axis entry is a signal when its leading dimensions
        # match the shape spanned by the axes
        num_dim = len(dataset_shape)
        signal_names = []
        for key, entry in nxdata.entries.items():
            if key in axes_names:
                continue
            if entry.shape[:num_dim] == dataset_shape:
                signal_names.append(key)
        self.logger.info(f'Dataset signals: {signal_names}')

        # Anything that is neither an axis nor a signal is ignored
        recognized = axes_names + signal_names
        other_fields = []
        for key in nxdata.entries:
            if key not in recognized:
                other_fields.append(key)
        if other_fields:
            self.logger.warning(
                'Ignoring the following fields that cannot be interpreted as '
                f'either dataset coordinates or signals: {other_fields}')

        data_points = []
        for index in np.ndindex(dataset_shape):
            point = {}
            for axis_name, axis_index in zip(axes_names, index):
                point[axis_name] = nxdata[axis_name][axis_index]
            for signal_name in signal_names:
                point[signal_name] = nxdata[signal_name].nxdata[index]
            data_points.append(point)
        return data_points
class XarrayToNexusProcessor(Processor):
    """A Processor that repackages an `xarray` structure as a NeXus
    style `NXdata
    <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object.
    """

    def process(self, data):
        """Build a NeXus style `NXdata
        <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        object from the input: its values, name and attributes become
        the signal and each of its coordinates becomes an axis.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Data and metadata in `data`.
        :rtype: nexusformat.nexus.NXdata
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        data = self.get_pipelinedata_item(data)

        nxsignal = NXfield(value=data.data, name=data.name, attrs=data.attrs)
        # One axis field per coordinate, in the coordinates' own order
        nxaxes = tuple(
            NXfield(value=coord.data, name=name, attrs=coord.attrs)
            for name, coord in data.coords.items())

        return NXdata(signal=nxsignal, axes=nxaxes)
class XarrayToNumpyProcessor(Processor):
    """A Processor that extracts the values of an `xarray.DataArray`
    structure as a `numpy.ndarray`.
    """

    def process(self, data):
        """Return only the signal values held in `data`, i.e. the
        `data` attribute of the selected pipeline item.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Data in `data`.
        :rtype: numpy.ndarray
        """
        dataarray = self.get_pipelinedata_item(data)
        return dataarray.data
class ZarrToNexusProcessor(Processor):
    """Processor for converting `Zarr
    <https://zarr.readthedocs.io/en/stable/>`__ data to the `NeXus
    <https://www.nexusformat.org>`__ file format.
    """

    def process(self, data, zarr_filename, nexus_filename):
        """Copy every group, dataset and attribute of a Zarr file into
        a newly created NeXus (HDF5) file.

        :param data: Input data (not used by this method).
        :type data: list[PipelineData]
        :param zarr_filename: Zarr input file name.
        :type zarr_filename: str
        :param nexus_filename: NeXus output file name.
        :type nexus_filename: str
        """
        # Third party modules
        # pylint: disable=import-error
        import h5py
        import zarr
        # pylint: enable=import-error

        def copy_attrs(source, target):
            """Copy all attributes of `source` onto `target`."""
            for attr_key, attr_value in source.attrs.items():
                target.attrs[attr_key] = attr_value

        def copy_group(zarr_group, nexus_group):
            """Recursively copy the attributes, datasets and
            sub-groups of a `Zarr group
            <https://zarr.readthedocs.io/en/latest/api/zarr/group/#zarr.Group>`__
            into an HDF5 group of the output file.

            :param zarr_group: Zarr style group to copy from.
            :type zarr_group: zarr.Group
            :param nexus_group: Group in the output file to copy to.
            :type nexus_group: h5py.Group
            """
            self.logger.info(f'Copying {zarr_group.path}')
            copy_attrs(zarr_group, nexus_group)

            for key, item in zarr_group.members():
                if isinstance(item, zarr.Array):
                    self.logger.info(f'Copying {zarr_group.path}/{key}')
                    # Datasets are written gzip compressed (level 4)
                    # FIXME: the source chunking (item.chunks) is not
                    # carried over to the output dataset
                    nexus_dset = nexus_group.create_dataset(
                        name=key,
                        data=item.__array__(),
                        compression='gzip',
                        compression_opts=4,
                    )
                    copy_attrs(item, nexus_dset)
                    continue
                if isinstance(item, zarr.Group):
                    copy_group(item, nexus_group.create_group(key))

        # Relative file names are resolved against the input directory
        # NOTE(review): the output file is also placed relative to
        # `self.inputdir` -- confirm that this is intended
        if not os.path.isabs(zarr_filename):
            zarr_filename = os.path.join(self.inputdir, zarr_filename)
        if not os.path.isabs(nexus_filename):
            nexus_filename = os.path.join(self.inputdir, nexus_filename)

        zarr_root = zarr.open(zarr_filename, mode='r')
        with h5py.File(nexus_filename, 'w') as nexus_root:
            # Start copying from the root group
            copy_group(zarr_root, nexus_root)
# Script entry point: when this module is executed directly, hand
# control to the `main` function imported from `CHAP.processor`.
if __name__ == '__main__':
    # Local modules
    from CHAP.processor import main

    main()