#!/usr/bin/env python
#-*- coding: utf-8 -*-
"""Module for generic Processors used in multiple experiment-specific
workflows.
"""
# System modules
from copy import deepcopy
import os
from typing import Optional
# Third party modules
import numpy as np
from pydantic import (
Field,
PrivateAttr,
conint,
conlist,
field_validator,
)
# Local modules
from CHAP.common.models.common import ImageProcessorConfig
from CHAP.common.models.map import (
DetectorConfig,
MapConfig,
)
from CHAP.pipeline import PipelineData
from CHAP.processor import Processor
[docs]
class AsyncProcessor(Processor):
    """Processor that runs a wrapped `Processor` over several input
    documents through the asyncio module.

    :ivar mgr: The `Processor` applied to every input document.
    :vartype mgr: Processor
    """

    def __init__(self, mgr):
        super().__init__()
        self.mgr = mgr

    def process(self, data):
        """Run `self.mgr.process` on each input document inside a
        single asyncio event loop.

        The per-document results are gathered but not returned.

        :param data: Input data.
        :type data: list[PipelineData]
        """
        # System modules
        import asyncio

        async def _run_one(worker, item):
            """Process a single document with the given `Processor`.

            :param worker: Object that will process the document.
            :type worker: Processor
            :param item: Data to process.
            :type item: object
            :return: Processed data.
            :rtype: object
            """
            return worker.process(item)

        async def _run_all(worker, items):
            """Process every document with the given `Processor`.

            :param worker: Object that will process all documents.
            :type worker: Processor
            :param items: Set of data documents to process.
            :type items: iterable
            """
            pending = [_run_one(worker, item) for item in items]
            await asyncio.gather(*pending)

        asyncio.run(_run_all(self.mgr, data))
[docs]
class BinarizeProcessor(Processor):
    """A Processor to binarize a dataset.

    :ivar nxmemory: Maximum memory usage when reading
        `NeXus <https://www.nexusformat.org>`__ files.
    :vartype nxmemory: int, optional
    """
    nxmemory: Optional[conint(gt=0)] = 100000

    def process(self, data, config=None):
        """Plot and return a binarized dataset from a dataset contained
        in `data`. The dataset must either be `array-like` or a NeXus
        style
        `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
        object with a default plottable data path or a
        specified path to a NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        or
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        object.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            :class:`~CHAP.common.models.BinarizeConfig`.
        :type config: dict, optional
        :raises ValueError: If no valid input data can be loaded, or
            if the selected NeXus object is neither an NXfield nor an
            NXdata with a signal.
        :return: Binarized dataset for an `array-like` input or
            a return type equal that of the input object with the
            binarized dataset added.
        :rtype: numpy.ndarray | nexusformat.nexus.NXobject
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
            NXlink,
            nxsetconfig,
        )

        # Local modules
        from CHAP.utils.general import nxcopy

        nxsetconfig(memory=self.nxmemory)

        # Load the validated processor configuration
        if config is None:
            # Local modules
            from CHAP.common.models.common import BinarizeConfig
            config = BinarizeConfig()
        else:
            config = self.get_config(
                data, config=config,
                schema='common.models.BinarizeConfig')

        # Load the default data. The pipeline input `data` is never
        # rebound here so that the array-like fallback below always
        # operates on the original pipeline data, and `dataset` /
        # `nxsignal` are pre-initialized so that they are always
        # defined when the output type is selected further down.
        nxobject = None
        dataset = None
        nxsignal = None
        try:
            nxobject = self.get_data(data)
            if config.nxpath is None:
                dataset = nxobject.get_default()
            else:
                dataset = nxobject[config.nxpath]
            if isinstance(dataset, NXdata):
                nxsignal = dataset.nxsignal
                values = nxsignal.nxdata
            else:
                values = dataset.nxdata
            if not isinstance(values, np.ndarray):
                raise TypeError('NeXus dataset is not a numpy array')
        except Exception:
            # Not a usable NeXus object: treat the input as array-like
            try:
                values = np.asarray(self.get_pipelinedata_item(data))
            except Exception as exc:
                raise ValueError(
                    'Unable the load a valid input data object') from exc
            dataset = values
            nxsignal = None

        if config.method == 'yen':
            min_ = values.min()
            max_ = values.max()
            values = 1 + (config.num_bin - 1) * (values - min_) / (max_ - min_)

        # Get a histogram of the data
        counts, edges = np.histogram(values, bins=config.num_bin)
        centers = edges[:-1] + 0.5 * np.diff(edges)

        # Calculate the data cutoff threshold
        # pylint: disable=no-name-in-module
        if config.method == 'CHAP':
            weights = np.cumsum(counts)
            means = np.cumsum(counts * centers)
            weights = weights[0:-1] / weights[-1]
            means = means[0:-1] / means[-1]
            variances = (means-weights)**2 / (weights * (1. - weights))
            threshold = centers[np.argmax(variances)]
        elif config.method == 'otsu':
            # Third party modules
            from skimage.filters import threshold_otsu
            threshold = threshold_otsu(hist=(counts, centers))
        elif config.method == 'yen':
            # Third party modules
            from skimage.filters import threshold_yen
            threshold = threshold_yen(hist=(counts, centers))
        elif config.method == 'isodata':
            # Third party modules
            from skimage.filters import threshold_isodata
            threshold = threshold_isodata(hist=(counts, centers))
        else:
            # Third party modules
            from skimage.filters import threshold_minimum
            threshold = threshold_minimum(hist=(counts, centers))
        # pylint: enable=no-name-in-module

        # Apply the data cutoff threshold
        binarized = np.where(values < threshold, 0, 1).astype(np.ubyte)

        # Return the output for array-like or NXfield inputs
        if isinstance(dataset, np.ndarray):
            return binarized
        if isinstance(dataset, NXfield):
            # NOTE(review): this pops 'target' from the input field's
            # own attrs object (no copy is made) -- confirm intended.
            attrs = dataset.attrs
            attrs.pop('target', None)
            return NXfield(
                value=binarized, name=f'{dataset.nxname}_binarized',
                attrs=attrs)
        if nxsignal is None:
            raise ValueError(
                'Unable to find a signal to binarize in the input data '
                f'object ({type(dataset).__name__})')

        # Otherwise create a copy of the input data, add the binarized
        # data to the copied original dataset, and remove the original
        # dataset if config.remove_original_data is set
        name = f'{nxsignal.nxname}_binarized'
        nxdefault = nxobject.get_default()
        if isinstance(nxsignal, NXlink):
            link = dataset.nxpath
            path = os.path.split(nxsignal.nxtarget)[0]
        else:
            link = nxdefault.nxpath
            path = os.path.split(nxsignal.nxpath)[0]
        exclude_nxpaths = []
        if config.remove_original_data:
            if link is not None:
                exclude_nxpaths.append(os.path.relpath(
                    f'{link}/{nxsignal.nxname}', nxobject.nxpath))
            exclude_nxpaths.append(os.path.relpath(
                f'{path}/{nxsignal.nxname}', nxobject.nxpath))
        nxobject = nxcopy(nxobject, exclude_nxpaths=exclude_nxpaths)
        attrs = nxsignal.attrs
        attrs.pop('target', None)
        nxobject[f'{path}/{name}'] = NXfield(
            value=binarized, name=name, attrs=attrs)
        nxobject[path].attrs['signal'] = name
        if link is not None:
            nxobject[f'{link}/{name}'] = NXlink(f'{path}/{name}')
            nxobject[link].attrs['signal'] = name

        return nxobject
[docs]
class ConstructBaseline(Processor):
    """A Processor to construct a baseline for a dataset."""

    def process(
            self, data, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
            save_figures=False):
        """Construct and return the baseline for a dataset.

        :param data: Input data.
        :type data: numpy.ndarray | list[PipelineData]
        :param x: Independent dimension (only used when running
            interactively or when filename is set).
        :type x: array-like, optional
        :param mask: Mask to apply to the spectrum before baseline
            construction.
        :type mask: array-like, optional
        :param tol: Convergence tolerence, defaults to `1.e-6`.
        :type tol: float, optional
        :param lam: &lambda (smoothness) parameter (the balance
            between the residual of the data and the baseline and the
            smoothness of the baseline). The suggested range is between
            100 and 10^8, defaults to `10^6`.
        :type lam: float, optional
        :param max_iter: Maximum number of iterations,
            defaults to `20`.
        :type max_iter: int, optional
        :param save_figures: Save .pngs of plots for checking inputs &
            outputs of this Processor, defaults to `False`.
        :type save_figures: bool, optional
        :raises ValueError: If no valid data is found in `data`.
        :return: Smoothed baseline, the configuration, and the figure
            output of :meth:`construct_baseline` when `save_figures`
            is `True` (`None` otherwise).
        :rtype: numpy.ndarray, dict, io.BytesIO | None
        """
        try:
            y = np.asarray(self.get_pipelinedata_item(data))
        except Exception as exc:
            raise ValueError(f'No valid data found in {data}') from exc

        return self.construct_baseline(
            y, x=x, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
            return_buf=save_figures)

    @staticmethod
    def construct_baseline(
            y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20, title=None,
            xlabel=None, ylabel=None, interactive=False, return_buf=False):
        """Construct and return the baseline for a dataset.

        :param y: Input data.
        :type y: numpy.ndarray
        :param x: Independent dimension (only used when interactive is
            `True` of when filename is set).
        :type x: array-like, optional
        :param mask: Mask to apply to the spectrum before baseline
            construction.
        :type mask: array-like, optional
        :param tol: Convergence tolerence, defaults to `1.e-6`.
        :type tol: float, optional
        :param lam: &lambda (smoothness) parameter (the balance
            between the residual of the data and the baseline and the
            smoothness of the baseline). The suggested range is between
            100 and 10^8, defaults to `10^6`.
        :type lam: float, optional
        :param max_iter: Maximum number of iterations,
            defaults to `20`.
        :type max_iter: int, optional
        :param title: Title for the displayed figure.
        :type title: str, optional
        :param xlabel: Label for the x-axis of the displayed figure.
        :type xlabel: str, optional
        :param ylabel: Label for the y-axis of the displayed figure.
        :type ylabel: str, optional
        :param interactive: Allows for user interactions,
            defaults to `False`.
        :type interactive: bool, optional
        :param return_buf: Return an in-memory object as a byte stream
            represention of the Matplotlib figure, defaults to `False`.
        :type return_buf: bool, optional
        :return: Smoothed baseline and the configuration and a
            byte stream represention of the Matplotlib figure if
            return_buf is `True` (`None` otherwise)
        :rtype: numpy.ndarray, dict, io.BytesIO | None
        """
        # Local modules
        from CHAP.utils.general import (
            baseline_arPLS,
            fig_to_iobuf,
        )

        def get_baseline(
                y, *, mask=None, w=None, tol=1.e-6, lam=1.e6, max_iter=20):
            """Get a baseline.

            :param y: Input data.
            :type y: numpy.ndarray
            :param mask: Mask to apply to the spectrum before baseline
                construction.
            :type mask: array-like, optional
            :param w: Weights (allows restart for additional
                iterations).
            :type w: numpy.array, optional
            :param tol: Convergence tolerence, defaults to `1.e-6`.
            :type tol: float, optional
            :param lam: &lambda (smoothness) parameter (the balance
                between the residual of the data and the baseline and
                the smoothness of the baseline). The suggested range is
                between 100 and 10^8, defaults to `10^6`.
            :type lam: float, optional
            :param max_iter: Maximum number of iterations,
                defaults to `20`.
            :type max_iter: int, optional
            :return: Full output of `baseline_arPLS`, unpacked by the
                callers as (baseline, _, weights, number of
                iterations, error).
            """
            return baseline_arPLS(
                y, mask=mask, w=w, tol=tol, lam=lam, max_iter=max_iter,
                full_output=True)

        baseline, _, w, num_iter, error = get_baseline(
            y, mask=mask, tol=tol, lam=lam, max_iter=max_iter)

        if not interactive and not return_buf:
            config = {
                'tol': tol, 'lambda': lam, 'max_iter': max_iter,
                'num_iter': num_iter, 'error': error, 'mask': mask}
            return baseline, config, None

        # Third party modules
        # (imported only once a figure is actually needed)
        from matplotlib.widgets import TextBox, Button
        import matplotlib.pyplot as plt

        # Single-element lists hold the state that the widget
        # callbacks must be able to update. The baseline is kept here
        # too: a plain assignment inside a callback would only bind a
        # local name and the stale initial baseline would be returned.
        lambdas = [lam]
        weights = [w]
        num_iters = [num_iter]
        errors = [error]
        baselines = [baseline]
        fig_subtitles = []

        def _change_fig_subtitle(maxed_out=False, subtitle=None):
            """Change the figure's subtitle."""
            if fig_subtitles:
                fig_subtitles[0].remove()
                fig_subtitles.pop()
            if subtitle is None:
                subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
                if maxed_out:
                    subtitle += f'# iter = {num_iters[-1]} (maxed out) '
                else:
                    subtitle += f'# iter = {num_iters[-1]} '
                subtitle += f'error = {errors[-1]:.2e}'
            fig_subtitles.append(
                plt.figtext(*subtitle_pos, subtitle, **subtitle_props))

        def select_lambda(expression):
            """Callback function for the "Select lambda" TextBox."""
            # Clearing the box below can re-trigger this callback with
            # an empty string
            if not expression:
                return
            try:
                log_lam = float(expression)
                if log_lam < 0:
                    raise ValueError
            except ValueError:
                _change_fig_subtitle(
                    subtitle='Invalid lambda, enter a positive number')
            else:
                lambdas[-1] = 10**log_lam
                new_baseline, _, new_w, n_iter, new_error = get_baseline(
                    y, mask=mask, tol=tol, lam=lambdas[-1],
                    max_iter=max_iter)
                baselines[-1] = new_baseline
                # Keep the weights in sync with the current lambda so
                # that "Continue smoothing" restarts from this result
                weights[-1] = new_w
                num_iters[-1] = n_iter
                errors[-1] = new_error
                _change_fig_subtitle(maxed_out=not n_iter < max_iter)
                baseline_handle.set_ydata(new_baseline)
            lambda_box.set_val('')
            plt.draw()

        def continue_iter(event):
            """Callback function for the "Continue" button."""
            new_baseline, _, new_w, n_iter, new_error = get_baseline(
                y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
                max_iter=max_iter)
            baselines[-1] = new_baseline
            weights[-1] = new_w
            num_iters[-1] += n_iter
            errors[-1] = new_error
            _change_fig_subtitle(maxed_out=not n_iter < max_iter)
            baseline_handle.set_ydata(new_baseline)
            plt.draw()

        def confirm(event):
            """Callback function for the "Confirm" button."""
            plt.close()

        # Check inputs (x and mask are documented as array-like, so
        # convert them before any NumPy style indexing)
        if x is None:
            x = np.arange(y.size)
        else:
            x = np.asarray(x)

        # Setup the Matplotlib figure
        title_pos = (0.5, 0.95)
        title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                       'verticalalignment': 'bottom'}
        subtitle_pos = (0.5, 0.90)
        subtitle_props = {'fontsize': 'x-large',
                          'horizontalalignment': 'center',
                          'verticalalignment': 'bottom'}
        fig, ax = plt.subplots(figsize=(11, 8.5))
        if mask is None:
            ax.plot(x, y, label='input data')
        else:
            plot_mask = np.asarray(mask).astype(bool)
            ax.plot(x[plot_mask], y[plot_mask], label='input data')
        baseline_handle = ax.plot(x, baseline, label='baseline')[0]
        ax.legend()
        ax.set_xlabel(xlabel, fontsize='x-large')
        ax.set_ylabel(ylabel, fontsize='x-large')
        ax.set_xlim(x[0], x[-1])
        if title is None:
            fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
        else:
            fig_title = plt.figtext(*title_pos, title, **title_props)
        _change_fig_subtitle(maxed_out=not num_iter < max_iter)
        fig.subplots_adjust(bottom=0.0, top=0.85)

        lambda_box = None
        if interactive:

            fig.subplots_adjust(bottom=0.2)

            # Setup TextBox
            lambda_box = TextBox(
                plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
            lambda_cid = lambda_box.on_submit(select_lambda)

            # Setup "Continue" button
            continue_btn = Button(
                plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
            continue_cid = continue_btn.on_clicked(continue_iter)

            # Setup "Confirm" button
            confirm_btn = Button(
                plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
            confirm_cid = confirm_btn.on_clicked(confirm)

            # Show figure for user interaction
            plt.show()

            # Disconnect all widget callbacks when figure is closed
            lambda_box.disconnect(lambda_cid)
            continue_btn.disconnect(continue_cid)
            confirm_btn.disconnect(confirm_cid)

            # ... and remove the buttons before returning the figure
            lambda_box.ax.remove()
            continue_btn.ax.remove()
            confirm_btn.ax.remove()

        if return_buf:
            fig_title.set_in_layout(True)
            fig_subtitles[-1].set_in_layout(True)
            fig.tight_layout(rect=(0, 0, 1, 0.90))
            # NOTE(review): ImageProcessor unpacks the result of
            # fig_to_iobuf as (buf, fileformat) -- confirm whether a
            # single object or a tuple is expected downstream here.
            buf = fig_to_iobuf(fig)
        else:
            buf = None
        plt.close()

        config = {
            'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
            'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
        return baselines[-1], config, buf
[docs]
class ConvertStructuredProcessor(Processor):
    """Processor that converts map data between the structured and
    the unstructured format.
    """

    def process(self, data):
        """Convert the map data found in the pipeline and return it.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Converted data.
        :rtype: nexusformat.nexus.NXdata
        """
        # Local modules
        from CHAP.utils.converters import (
            convert_structured_unstructured as _convert,
        )

        map_data = self.get_pipelinedata_item(data)
        return _convert(map_data)
[docs]
class ExpressionProcessor(Processor):
    """Processor to perform an arbitrary expression on input data."""

    def process(self, data, expression, symtable=None, nxprocess=False,
                nxfieldtable=None, nxdata_name='data',
                nxfield_name='result'):
        """Return result of plugging input data into the given
        mathematical expression.

        :param data: Input data.
        :type data: list[PipelineData]
        :param expression: Mathemetical expression. May use the
            built-in function `round` and / or numpy functions with
            `np.<function_name>` or `numpy.<function_name>.`
        :type expression: str
        :param symtable: Values to use for names in `expression` that
            should not be obtained from input data. The supplied
            dictionary is copied and never modified.
            Defaults to `None`.
        :type symtable: dict[str, (float, int)], optional.
        :param nxprocess: Flag to indicate the results should be
            returned as a NeXus style
            `NXprocess <https://manual.nexusformat.org/classes/base_classes/NXprocess.html#index-0>`__
            defaults to `False`.
        :type nxprocess: bool, optional
        :param nxfieldtable: Used only if `nxprocess` is
            `True`. Dictionary of additional
            `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            objects to include in the NXprocess, the result object
            right next to the expression result's NXfield.
            Dictionary keys become NXfield names in the returned
            object. Defaults to `None`.
        :type nxfieldtable: dict[str, nexusformat.nexus.NXfield],
            optional
        :param nxdata_name: Used only if `nxprocess` is
            `True`. Name for the NeXus style
            `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
            object in the returned NXprocess object that contains
            actual result data. Defaults to `'data'`.
        :type nxdata_name: str, optional
        :param nxfield_name: Used only if `nxprocess` is
            `True`. Name for the NXfield dataset that contains the
            evaluated expression results. Defaults to `'result'`.
        :type nxfield_name: str, optional
        :returns: Result of evaluating the expression.
        :rtype: object
        """
        # System modules
        from ast import parse

        # Third party modules
        from asteval import get_ast_names, Interpreter
        try:
            # zarr is optional and only some versions expose
            # `zarr.core.array.Array`; without it no zarr specific
            # handling is performed.
            from zarr.core.array import Array as ZarrArray
        except ImportError:
            ZarrArray = None

        if not isinstance(nxfieldtable, dict):
            self.logger.warning('No usable nxfieldtable provided, using {}')
            nxfieldtable = {}

        # Work on a copy so the caller's dictionary is never mutated
        symtable = {} if symtable is None else dict(symtable)

        # Resolve every name in the expression that the caller did
        # not supply explicitly
        for name in get_ast_names(parse(expression)):
            if name in symtable:
                continue
            if name == 'round':
                symtable[name] = round
            elif name in ('np', 'numpy'):
                symtable[name] = np
            else:
                symtable[name] = self.get_data(
                    data, name=name, remove=False)

        # Replace zarr arrays by the result of indexing them with `()`
        if ZarrArray is not None:
            for key, value in symtable.items():
                if isinstance(value, ZarrArray):
                    symtable[key] = value[()]

        self.logger.debug(f'Asteval symtable: {symtable}')
        aeval = Interpreter(symtable=symtable)
        new_data = aeval(expression)

        if not nxprocess:
            return new_data

        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
            NXprocess,
        )

        # NOTE(review): `nxprocess` is documented as a boolean flag
        # but is also passed on as the NXprocess name -- confirm
        # whether callers are expected to supply a string here.
        return NXprocess(
            name=nxprocess,
            entries={
                nxdata_name: NXdata(
                    signal=NXfield(
                        name=nxfield_name,
                        value=new_data,
                        attrs={'expression': expression}
                    ),
                    **nxfieldtable,
                    attrs={'expression': expression}
                ),
            }
        )
[docs]
class ImageProcessor(Processor):
    """A Processor to perform various visualization operations on
    images (slices) selected from a NeXus style
    `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__.

    :ivar config: Initialization parameters for an instance of
        :class:`~CHAP.common.models.ImageProcessorConfig`.
    :vartype config: dict, optional
    :ivar nxmemory: Maximum memory usage when reading
        `NeXus <https://www.nexusformat.org>`__ files.
    :vartype nxmemory: int, optional
    :ivar save_figures: Return the plottable image(s) to be written
        to file downstream in the pipeline, defaults to `True`.
    :vartype save_figures: bool, optional
    """
    pipeline_fields: dict = Field(
        default = {
            'config': 'common.models.map.ImageProcessorConfig'}, init_var=True)
    config: ImageProcessorConfig
    nxmemory: Optional[conint(gt=0)] = 100000
    save_figures: Optional[bool] = True

    # Figure settings shared by process() and the plotting helpers
    _figconfig: dict = PrivateAttr(default={})

    def process(self, data):
        """Plot and/or return image slices from a NeXus style
        `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__.
        object with a default plottable data path.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises ValueError: If the default NXdata object or its signal
            cannot be loaded, if the axis parameter is invalid, or if
            the data is neither 2D nor 3D.
        :raises NotImplementedError: For 2D datasets and for datasets
            without a usable axes attribute (not supported yet).
        :return: Plottable image(s) (for save_figures = `True`)
            or the input default NeXus style
            `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
            object (for save_figures = `False`). `None` when neither
            saving figures nor running interactively.
        :rtype: bytes | nexusformat.nexus.NXdata | numpy.ndarray
        """
        if not self.save_figures and not self.interactive:
            return None

        # Third party modules
        from nexusformat.nexus import nxsetconfig

        nxsetconfig(memory=self.nxmemory)

        # Load the default data
        try:
            nxdata = self.get_data(data).get_default()
        except Exception as exc:
            raise ValueError(
                'Unable the load the default NXdata object from the input '
                f'pipeline ({data})') from exc

        # Get the axes info and image slice(s)
        try:
            data = nxdata.nxsignal
        except Exception as exc:
            raise ValueError('Unable the find the default signal in:\n'
                             f'({nxdata.tree})') from exc
        if data is None:
            raise ValueError('Unable the find the default signal in:\n'
                             f'({nxdata.tree})')
        axis = self.config.axis
        axes = nxdata.attrs.get('axes', None)
        if isinstance(axes, str):
            axes = [axes]
        if nxdata.nxsignal.ndim == 2:
            # Raise instead of calling exit(): library code must not
            # terminate the interpreter
            raise NotImplementedError(
                'ImageProcessor not tested yet for a 2D dataset')
        if nxdata.nxsignal.ndim == 3:
            if isinstance(axis, int):
                if not 0 <= axis < nxdata.nxsignal.ndim:
                    raise ValueError(f'axis index out of range ({axis} not in '
                                     f'[0, {nxdata.nxsignal.ndim-1}])')
            elif isinstance(axis, str):
                if axes is None or axis not in axes:
                    raise ValueError(
                        f'Unable to match axis = {axis} in {nxdata.tree}')
                axis = axes.index(axis)
            else:
                raise ValueError(f'Invalid parameter axis ({axis})')
            if axis:
                # Put the slicing axis first
                data = np.moveaxis(data, axis, 0)
            if axes is not None and hasattr(nxdata, axes[axis]):
                # Reorder the axes names to match the moved data axes
                if axis == 1:
                    axes = [axes[1], axes[0], axes[2]]
                elif axis:
                    axes = [axes[2], axes[0], axes[1]]
                axis_name = axes[0]
                if 'units' in nxdata[axis_name].attrs:
                    axis_unit = f' ({nxdata[axis_name].units})'
                else:
                    axis_unit = ''
                row_label = axes[2]
                row_coords = nxdata[row_label].nxdata
                column_label = axes[1]
                column_coords = nxdata[column_label].nxdata
                if 'units' in nxdata[row_label].attrs:
                    row_label += f' ({nxdata[row_label].units})'
                if 'units' in nxdata[column_label].attrs:
                    column_label += f' ({nxdata[column_label].units})'
            else:
                raise NotImplementedError('No axes attribute not tested yet')
            axis_coords = nxdata[axis_name].nxdata
        else:
            raise ValueError('Invalid data dimension (must be 2D or 3D)')

        # Select the image slice(s) by index or by coordinate
        if self.config.coord_range is None:
            index_range = self.config.index_range
        else:
            # Local modules
            from CHAP.utils.general import (
                index_nearest_down,
                index_nearest_up,
            )

            if self.config.index_range is not None:
                self.logger.warning('Ignoring parameter index_range')
            if isinstance(self.config.coord_range, (int, float)):
                index_range = index_nearest_up(
                    axis_coords, self.config.coord_range)
            elif len(self.config.coord_range) == 2:
                index_range = [
                    index_nearest_up(axis_coords, self.config.coord_range[0]),
                    index_nearest_down(
                        axis_coords, self.config.coord_range[1])]
            else:
                index_range = [
                    index_nearest_up(axis_coords, self.config.coord_range[0]),
                    index_nearest_down(
                        axis_coords, self.config.coord_range[1]),
                    int(max(1, self.config.coord_range[2] /
                        ((axis_coords[-1]-axis_coords[0])/data.shape[0])))]
        if index_range == -1:
            # -1 selects the central slice along the chosen axis
            index_range = nxdata.nxsignal.shape[axis] // 2
        if isinstance(index_range, int):
            data = data[index_range]
            axis_coords = [axis_coords[index_range]]
        elif index_range is not None:
            slice_ = slice(*tuple(index_range))
            data = data[slice_]
            axis_coords = axis_coords[slice_]
        if self.config.vrange is None:
            vrange = (float(data.min()), float(data.max()))
        else:
            # Copy so that filling in missing bounds does not modify
            # the stored configuration
            vrange = list(self.config.vrange)
            if vrange[0] is None:
                vrange[0] = float(data.min())
            if vrange[1] is None:
                vrange[1] = float(data.max())

        # Create the figure configuration
        self._figconfig = {
            'title': f'{nxdata.nxpath}/{nxdata.signal}',
            'axis_name': axis_name,
            'axis_unit': axis_unit,
            'axis_coords': axis_coords,
            'row_label': row_label,
            'column_label': column_label,
            'extent': (row_coords[0], row_coords[-1],
                       column_coords[-1], column_coords[0]),
            'vrange': vrange,
        }
        self.logger.debug(f'figure configuration:\n{self._figconfig}')

        if len(axis_coords) == 1:
            # Create a figure for a single image slice
            if self.config.animation:
                self.logger.warning(
                    'Ignoring animation parameter for a single image')
            if self.config.fileformat is None:
                fileformat = 'png'
            else:
                fileformat = self.config.fileformat
            fig, plt = self._create_figure(np.squeeze(data))
            if self.interactive:
                plt.show()
            if self.save_figures:
                # Local modules
                from CHAP.utils.general import fig_to_iobuf

                # Return a binary image of the figure
                buf, fileformat = fig_to_iobuf(fig, fileformat=fileformat)
            else:
                buf = None
            plt.close()
            if self.save_figures:
                return {'image_data': buf, 'fileformat': fileformat}
            return nxdata

        # Create an animation for a set of image slices
        if self.interactive or self.config.animation:
            ani = self._create_animation(data)
        else:
            ani = None

        if self.save_figures:
            if self.config.animation:
                # Return the animation object
                if (self.config.fileformat is not None
                        and self.config.fileformat != 'gif'):
                    self.logger.warning(
                        'Ignoring inconsistent file extension')
                fileformat = 'gif'
                image_data = ani
            else:
                # Local modules
                from CHAP.utils.general import fig_to_iobuf

                if self.config.fileformat in ('png', 'tif', 'tifstack'):
                    fileformat = self.config.fileformat
                else:
                    if self.config.fileformat is not None:
                        self.logger.warning('Ignoring invalid fileformat '
                                            f'({self.config.fileformat})')
                    if data.shape[0] == 1:
                        fileformat = 'tif'
                    else:
                        fileformat = 'tifstack'
                if fileformat != 'tifstack':
                    # Return the set of image slices as individual figs
                    num_digit = len(str(data.shape[0]))
                    images = []
                    for i in range(data.shape[0]):
                        fig, plt = self._create_figure(data[i])
                        images.append((
                            fig_to_iobuf(fig, fileformat=fileformat),
                            f'{self.config.basename}_'
                            f'{str(i).zfill(num_digit)}'))
                        plt.close()
                    return images
                # Return the set of image slices as a tif stack.
                # Values outside vrange are clipped so that the 8-bit
                # conversion saturates instead of wrapping around.
                span = vrange[1] - vrange[0]
                if span == 0:
                    # Degenerate range: avoid a division by zero
                    span = 1.0
                scaled = 255.0 * ((data - vrange[0]) / span)
                fileformat = 'tif'
                image_data = np.clip(
                    np.asarray(scaled), 0.0, 255.0).astype(np.uint8)
            return {'image_data': image_data, 'fileformat': fileformat}
        return nxdata

    def _set_title(self, i):
        """Return the title for image slice `i`, built from the
        current figure configuration (axis name, coordinate, unit).

        :param i: Index of the image slice.
        :type i: int
        :return: Title of the image slice.
        :rtype: str
        """
        return self._figconfig['axis_name'] +\
            f' = {self._figconfig["axis_coords"][i]:.3f}' +\
            self._figconfig['axis_unit']

    def _create_animation(self, data):
        """Create a Matplotlib animation from a set of images.

        :param data: Stack of images, one frame per index along the
            first axis.
        :return: Animation object.
        :rtype: matplotlib.animation.FuncAnimation
        """
        # System modules
        from functools import partial

        # Third party modules
        from matplotlib import animation

        def _animate(i, plt, title):
            """Create a single Matplotlib image for the animation."""
            im.set_array(data[i])
            title.set_text(self._set_title(i))
            plt.draw()
            return im,

        fig, im, plt, title = self._create_figure(data[0], animated=True)
        ani = animation.FuncAnimation(
            fig, partial(_animate, plt=plt, title=title), frames=data.shape[0],
            interval=50, blit=True)

        if self.interactive:
            plt.show()

        plt.close()

        return ani

    def _create_figure(self, image, animated=False):
        """Create a Matplotlib figure from an image.

        :param image: Image to display.
        :param animated: Also return the image artist and the title
            so an animation can update them, defaults to `False`.
        :type animated: bool, optional
        :return: `(fig, im, plt, title)` if `animated` is `True`,
            `(fig, plt)` otherwise.
        :rtype: tuple
        """
        # Third party modules
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        im = plt.imshow(
            image, extent=self._figconfig['extent'], origin='lower',
            vmin=self._figconfig['vrange'][0],
            vmax=self._figconfig['vrange'][1], cmap='gray', animated=animated)
        fig.suptitle(self._figconfig['title'], fontsize='x-large')
        title = ax.set_title(self._set_title(0), fontsize='x-large', pad=10)
        ax.set_xlabel(self._figconfig['row_label'], fontsize='x-large')
        ax.set_ylabel(self._figconfig['column_label'], fontsize='x-large')
        plt.colorbar()
        fig.tight_layout()
        if animated:
            return fig, im, plt, title
        return fig, plt
[docs]
class MapProcessor(Processor):
"""A Processor that takes a map configuration and returns a NeXus
style
`NXentry <https://manual.nexusformat.org/classes/base_classes/NXentry.html#index-0>`__
object representing that map's metadata and any scalar-valued raw
data requested by the supplied map configuration.
:ivar config: Map configuration parameters to initialize an
instance of :class:`~CHAP.common.models.map.MapConfig`.
Any values in `'config'` supplant their corresponding values
obtained from the pipeline data configuration.
:vartype config: dict | MapConfig
:ivar detector_config: Detector configurations of the detectors to
include raw data for in the returned NXentry object
(overruling detector info in the pipeline data, if present).
:vartype detector_config: dict | DetectorConfig
:ivar num_proc: Number of processors used to read map,
defaults to `1`.
:vartype num_proc: int, optional
"""
pipeline_fields: dict = Field(
default = {
'config': 'common.models.map.MapConfig',
'detector_config': 'common.models.map.DetectorConfig'},
init_var=True)
config: Optional[MapConfig] = None
detector_config: DetectorConfig = DetectorConfig(detectors=[])
num_proc: Optional[conint(gt=0)] = 1
[docs]
@field_validator('num_proc')
@classmethod
def validate_num_proc(cls, num_proc, info):
    """Validate the number of processors.

    A request for more than one processor is reduced to the number
    of available CPUs when that number is known and exceeded, and
    to `1` when mpi4py cannot be imported.

    :param num_proc: Number of processors used to read map,
        defaults to `1`.
    :type num_proc: int, optional
    :param info: Model parameter validation information.
    :type info: pydantic.ValidationInfo
    :return: Validated number of processors
    :rtype: int
    """
    if num_proc > 1:
        logger = info.data['logger']
        try:
            # Third party modules
            # pylint: disable=unused-import
            from mpi4py import MPI
        except ImportError:
            logger.warning('Unable to load mpi4py, running serially')
            num_proc = 1
        else:
            # os.cpu_count() returns None when the number of CPUs
            # cannot be determined; skip the cap in that case
            max_proc = os.cpu_count()
            if max_proc is not None and num_proc > max_proc:
                logger.warning(
                    f'The requested number of processors ({num_proc}) '
                    'exceeds the maximum number of processors '
                    f'({max_proc}): reset it to {max_proc}')
                num_proc = max_proc
        logger.debug(f'Number of processors: {num_proc}')
    return num_proc
[docs]
def process(
self, data, placeholder_data=False, fill_data=True, comm=None):
"""Process that takes a map configuration and returns a NeXus
style
`NXentry <https://manual.nexusformat.org/classes/base_classes/NXentry.html#index-0>`__
object representing the map.
:param data: Pipeline data list with an optional item for the
map configuration parameters with
`'common.models.map.MapConfig'` as its `'schema'` key.
:type data: list[PipelineData]
:param placeholder_data: For SMB EDD maps only. Value to use
for missing detector data frames, or `False` if missing
data should raise an error, defaults to `False`.
:type placeholder_data: object, optional
:param fill_data: Flag to indicate whether or not to fill out
datasets with real data; defaults to `True`.
:type fill_data: bool, optional
:param comm: MPI communicator.
:type comm: mpi4py.MPI.Comm, optional
:return: Map data and metadata.
:rtype: Union[nexusformat.nexus.NXentry
"""
# System modules
import logging
# Third party modules
import yaml
# Update metadata
# self._metadata['user_metadata'].update({
# 'map': self.config.model_dump()})
# Check for available metadata
metadata = {}
provenance = {}
if data:
metadata, provenance = self._get_metadata_provenance(data)
if fill_data:
# Create the sub-pipeline configuration for each processor
# FIX: catered to EDD with one spec scan
assert len(self.config.spec_scans) == 1
spec_scans = self.config.spec_scans[0]
scan_numbers = spec_scans.scan_numbers
num_scan = len(scan_numbers)
if num_scan < self.num_proc:
self.logger.warning(
f'Requested number of processors ({self.num_proc}) exceeds '
f'the number of scans ({num_scan}): reset it to {num_scan}')
self.num_proc = num_scan
if self.num_proc == 1:
common_comm = comm
offsets = [0]
else:
# System modules
from tempfile import NamedTemporaryFile
# Local modules
from CHAP.models import RunConfig
raise NotImplementedError(
'MapProcessor needs testing for num_proc>1')
scans_per_proc = num_scan//self.num_proc
num = scans_per_proc
if num_scan - scans_per_proc*self.num_proc > 0:
num += 1
spec_scans.scan_numbers = scan_numbers[:num]
n_scan = num
pipeline_config = []
offsets = [0]
for n_proc in range(1, self.num_proc):
num = scans_per_proc
if n_proc < num_scan - scans_per_proc*self.num_proc:
num += 1
config = self.config.model_dump()
config['spec_scans'][0]['scan_numbers'] = \
scan_numbers[n_scan:n_scan+num]
pipeline_config.append([{
'common.MapProcessor': {
'config': config,
'detector_config':
self.detector_config.model_dump()}}])
offsets.append(n_scan)
n_scan += num
# Spawn the workers to run the sub-pipeline
run_config = RunConfig(
log_level=logging.getLevelName(self.logger.level), spawn=1)
tmp_names = []
with NamedTemporaryFile(delete=False) as fp:
# pylint: disable=c-extension-no-member
fp_name = fp.name
tmp_names.append(fp_name)
with open(fp_name, 'w', encoding='utf-8') as f:
yaml.dump({'config': {'spawn': 1}}, f, sort_keys=False)
for n_proc in range(1, self.num_proc):
f_name = f'{fp_name}_{n_proc}'
tmp_names.append(f_name)
with open(f_name, 'w', encoding='utf-8') as f:
yaml.dump(
# FIX once comm is a field of RunConfig
#processor.py {'config': run_config.model_dump(exclude='comm'),
{'config': run_config.model_dump(),
'pipeline': pipeline_config[n_proc-1]},
f, sort_keys=False)
# pylint: disable=used-before-assignment
sub_comm = MPI.COMM_SELF.Spawn(
'CHAP', args=[fp_name], maxprocs=self.num_proc-1)
common_comm = sub_comm.Merge(False)
# Align with the barrier in RunConfig() on common_comm
# called from the spawned main() in common_comm
common_comm.barrier()
# Align with the barrier in run() on common_comm
# called from the spawned main()
common_comm.barrier()
if common_comm is None:
self.num_proc = 1
rank = 0
else:
self.num_proc = common_comm.Get_size()
rank = common_comm.Get_rank()
if self.num_proc == 1:
offset = 0
else:
num_scan = common_comm.bcast(num_scan, root=0)
offset = common_comm.scatter(offsets, root=0)
# Read the raw data
if self.config.experiment_type == 'EDD':
data, independent_dimensions, all_scalar_data = \
self._read_raw_data_edd(
common_comm, num_scan, offset, placeholder_data)
else:
data, independent_dimensions, all_scalar_data = \
self._read_raw_data(common_comm, num_scan, offset)
if not rank:
self.logger.debug(f'Data shape: {data.shape}')
if independent_dimensions is not None:
self.logger.debug('Independent dimensions shape: '
f'{independent_dimensions.shape}')
if all_scalar_data is not None:
self.logger.debug('Scalar data shape: '
f'{all_scalar_data.shape}')
if rank:
return None
if self.num_proc > 1:
# Reset the scan_numbers to the original full set
spec_scans.scan_numbers = scan_numbers
# Align with the barrier in main() on common_comm
# when disconnecting the spawned worker
common_comm.barrier()
# Disconnect spawned workers and cleanup temporary files
sub_comm.Disconnect()
for tmp_name in tmp_names:
os.remove(tmp_name)
else:
# fill_data is False, just use empty arrays
map_len = 0
_independent_dimensions = {
dim.label: [] for dim in self.config.independent_dimensions}
det_shapes = False
for scans in self.config.spec_scans:
for scan_number in scans.scan_numbers:
scanparser = scans.get_scanparser(scan_number)
map_len += scanparser.spec_scan_npts
for dim in self.config.independent_dimensions:
val = dim.get_value(
scans, scan_number, -1,
self.config.scalar_data)
if not isinstance(val, list):
val = [val]
_independent_dimensions[dim.label].extend(val)
if not det_shapes:
det_shapes = {}
for detector in self.detector_config.detectors:
ddata_init = scanparser.get_detector_data(
detector.get_id(), 0)
if isinstance(ddata_init, tuple):
ddata_init = ddata_init[0].squeeze()
det_shapes[detector.get_id()] = ddata_init.shape
all_scalar_data = np.empty(
(len(self.config.all_scalar_data), map_len))
if len(self.detector_config.detectors) > 0:
data = np.empty(
(len(self.detector_config.detectors),
map_len,
*det_shapes[self.detector_config.detectors[0].get_id()]))
else:
data = None
independent_dimensions = np.asarray(
[_independent_dimensions[dim.label]
for dim in self.config.independent_dimensions])
# Construct and return the NXroot object
nxroot = self._get_nxroot(
data, independent_dimensions, all_scalar_data, placeholder_data)
if metadata and provenance:
return (
PipelineData(
name=self.name, data=metadata,
schema='foxden.reader.FoxdenMetadataReader'),
PipelineData(
name=self.name, data=provenance,
schema='foxden.reader.FoxdenProvenanceReader'),
PipelineData(
name=self.name, data=nxroot, schema=self.get_schema()))
return nxroot
def _get_metadata_provenance(self, data):
    """Read the FOXDEN metadata and provenance records from the
    pipeline data and add the map specific user metadata to them.

    :param data: Input data.
    :type data: list[PipelineData]
    :raises ValueError: If the technique is not tomography or the
        beamline is not one of the supported stations.
    :return: Experiment specific metadata and provenance.
    :rtype: dict, dict
    """
    # Local modules
    from CHAP.tomo.processor import (
        create_metadata_provenance,
        read_metadata_provenance,
    )

    # Read metadata and provenance from the pipeline data
    metadata, provenance = read_metadata_provenance(data, self.logger)
    if not metadata:
        # Nothing to update, hand back whatever was read
        return metadata, provenance

    # Only tomography records are supported at this point
    techniques = [entry.lower() for entry in metadata.get('technique')]
    if 'tomography' not in techniques:
        raise ValueError(
            f'Experiment type {techniques} not implemented yet')

    # The SPEC log file location is only known for these stations
    station = f'id{metadata.get("beamline")[0].lower()}'
    if station not in ('id1a3', 'id3a'):
        raise ValueError(f'Invalid beamline parameter ({station})')
    spec_file = os.path.join(metadata.get('data_location_raw'), 'spec.log')
    sample_name = metadata.get('sample_name')

    # FIX We could add full self.config and self.detector_config
    # That would allow us to create a full map from just metadata
    sample = {
        'name': sample_name,
        'description': metadata.get('description'),
    }
    map_config = {
        'experiment_type': 'TOMO',
        'sample': sample,
        'station': station,
        'title': sample_name,
        'spec_file': spec_file,
    }
    metadata, provenance = create_metadata_provenance(
        'map', metadata=metadata, provenance=provenance,
        user_metadata={'map_config': map_config}, logger=self.logger,
        update=False, read=False)

    # Re-insert the user metadata (an empty dictionary if absent) as
    # the last key of the metadata record
    metadata['user_metadata'] = metadata.pop('user_metadata', {})
    return metadata, provenance
def _get_nxroot(
        self, data, independent_dimensions, all_scalar_data,
        placeholder_data):
    """Use a `MapConfig` to construct a NeXus style
    `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
    object.

    Independent dimensions that hold a single unique value over the
    entire map are stored with the scalar data instead of with the
    independent dimensions, and `self.config` is updated in place
    accordingly (after its JSON dump was written to the entry).

    :param data: Map's raw data.
    :type data: numpy.ndarray
    :param independent_dimensions: Map's independent
        coordinates.
    :type independent_dimensions: numpy.ndarray
    :param all_scalar_data: Map's scalar data.
    :type all_scalar_data: numpy.ndarray
    :param placeholder_data: For SMB EDD maps only. Value to use
        for missing detector data frames, or `False` if missing
        data should raise an error.
    :type placeholder_data: object
    :return: Map's data and metadata.
    :rtype: nexusformat.nexus.NXroot
    """
    # Third party modules
    # pylint: disable=no-name-in-module
    from nexusformat.nexus import (
        NXcollection,
        NXdata,
        NXentry,
        NXfield,
        NXlinkfield,
        NXsample,
        NXroot,
    )
    # pylint: enable=no-name-in-module

    # Local modules:
    from CHAP.common.models.map import PointByPointScanData

    def linkdims(nxgroup, nxdata_source):
        """Link the dimensions for an
        `NXgroup <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__.
        """
        axes = list(nxdata_source.keys())
        for dim in axes:
            if isinstance(nxdata_source[dim], NXlinkfield):
                nxgroup[dim] = nxdata_source[dim]
            else:
                nxgroup.makelink(nxdata_source[dim])
            if f'{dim}_indices' in nxdata_source.attrs:
                nxgroup.attrs[f'{dim}_indices'] = \
                    nxdata_source.attrs[f'{dim}_indices']
        if len(axes) == 1:
            nxgroup.attrs['axes'] = axes
        else:
            nxgroup.attrs['unstructured_axes'] = axes

    def dim_attrs(dim):
        """Return the NXfield attributes shared by all dimension
        and scalar data fields."""
        return {'long_name': f'{dim.label} ({dim.units})',
                'data_type': dim.data_type,
                'local_name': dim.name}

    # Set up NXroot/NXentry and add CHESS-specific metadata
    nxroot = NXroot()
    nxentry = NXentry(name=self.config.title)
    nxroot[nxentry.nxname] = nxentry
    nxentry.map_config = self.config.model_dump_json()
    nxentry.detector_config = self.detector_config.model_dump_json()
    nxentry.attrs['station'] = self.config.station
    for k, v in self.config.attrs.items():
        nxentry.attrs[k] = v
    nxentry.spec_scans = NXcollection()
    for scans in self.config.spec_scans:
        # A 64-bit integer is used since an 8-bit one cannot hold
        # scan numbers above 127
        nxentry.spec_scans[scans.scanparsers[0].scan_name] = \
            NXfield(value=scans.scan_numbers,
                    dtype='int64',
                    attrs={'spec_file': str(scans.spec_file)})

    # Add sample metadata
    nxentry[self.config.sample.name] = NXsample(
        **self.config.sample.model_dump())

    # Set up independent dimensions NXdata group
    # (squeeze out constant dimensions)
    num_id = len(self.config.independent_dimensions)
    constant_dim = [
        i for i in range(num_id)
        if np.unique(independent_dimensions[i]).size == 1]
    nxentry.independent_dimensions = NXdata()
    if len(constant_dim) < num_id:
        for i, dim in enumerate(self.config.independent_dimensions):
            if i not in constant_dim:
                nxentry.independent_dimensions[dim.label] = NXfield(
                    independent_dimensions[i], dim.label,
                    attrs={'units': dim.units, **dim_attrs(dim)})
    else:
        # Every dimension is constant: fall back to a plain index
        nxentry.independent_dimensions.index = NXfield(
            np.arange(independent_dimensions[0].size), 'index')

    # Set up scalar data NXdata group
    # (add the constant independent dimensions)
    if all_scalar_data is not None:
        self.logger.debug(
            f'all_scalar_data.shape = {all_scalar_data.shape}\n\n')
    scalar_signals = []
    scalar_data = []
    for i, dim in enumerate(self.config.all_scalar_data):
        scalar_signals.append(dim.label)
        scalar_data.append(NXfield(
            value=all_scalar_data[i],
            units=dim.units,
            attrs=dim_attrs(dim)))
    if (self.config.experiment_type == 'EDD'
            and placeholder_data is not False):
        # The placeholder flags are stored as the last row of the
        # scalar data array
        scalar_signals.append('placeholder_data_used')
        scalar_data.append(NXfield(
            value=all_scalar_data[-1],
            attrs={'description':
                   'Indicates whether placeholder data may be present '
                   'for the corresponding frames of detector data.'}))
    # Iterate over a copy, since the constant dimensions are removed
    # from self.config.independent_dimensions inside the loop
    for i, dim in enumerate(deepcopy(self.config.independent_dimensions)):
        if i in constant_dim:
            scalar_signals.append(dim.label)
            scalar_data.append(NXfield(
                independent_dimensions[i], dim.label,
                attrs={'units': dim.units, **dim_attrs(dim)}))
            self.config.all_scalar_data.append(
                PointByPointScanData(**dim.model_dump()))
            self.config.independent_dimensions.remove(dim)
    if scalar_signals:
        nxentry.scalar_data = NXdata()
        for k, v in zip(scalar_signals, scalar_data):
            nxentry.scalar_data[k] = v
        if 'SCAN_N' in scalar_signals:
            nxentry.scalar_data.attrs['signal'] = 'SCAN_N'
        else:
            nxentry.scalar_data.attrs['signal'] = scalar_signals[0]
        # Every signal but the default one is an auxiliary signal
        scalar_signals.remove(nxentry.scalar_data.attrs['signal'])
        nxentry.scalar_data.attrs['auxiliary_signals'] = scalar_signals

    # Add detector data
    nxdata = NXdata()
    nxentry.data = nxdata
    nxentry.data.set_default()
    detector_ids = []
    for k, v in self.config.attrs.items():
        nxdata.attrs[k] = v
    if data is not None:
        # Per detector extrema, taken over all but the first axis
        reduce_axes = tuple(range(1, data.ndim))
        min_ = np.min(data, axis=reduce_axes)
        max_ = np.max(data, axis=reduce_axes)
        for i, detector in enumerate(self.detector_config.detectors):
            nxdata[detector.get_id()] = NXfield(
                value=data[i],
                attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]})
            detector_ids.append(detector.get_id())
    linkdims(nxdata, nxentry.independent_dimensions)
    if len(self.detector_config.detectors) == 1:
        nxdata.attrs['signal'] = self.detector_config.detectors[0].get_id()
    nxentry.detector_ids = detector_ids

    return nxroot
def _read_raw_data_edd(
        self, comm, num_scan, offset, placeholder_data):
    """Read the raw EDD data for a given map configuration.

    With a single processor the data is collected in regular numpy
    arrays. With multiple processors the arrays are views on MPI
    shared memory windows, and each processor fills in the entries
    of its own scans starting at scan index `offset`.

    :param comm: MPI communicator.
    :type comm: mpi4py.MPI.Comm, optional
    :param num_scan: Number of scans in the map.
    :type num_scan: int
    :param offset: Offset scan number of current processor.
    :type offset: int
    :param placeholder_data: Value to use for missing detector
        data frames, or `False` if missing data should raise an
        error.
    :type placeholder_data: object
    :return: Map's raw data, independent dimensions and scalar
        data.
    :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
    """
    # Third party modules
    # A failed mpi4py import is ignored here: MPI and dtlib are only
    # referenced in the multi-processor branch below
    try:
        from mpi4py import MPI
        from mpi4py.util import dtlib
    except ImportError:
        pass

    # Local modules
    from CHAP.utils.general import list_to_string

    if comm is None:
        self.num_proc = 1
        rank = 0
    else:
        self.num_proc = comm.Get_size()
        rank = comm.Get_rank()
    if not rank:
        self.logger.debug(f'Number of processors: {self.num_proc}')
        self.logger.debug(f'Number of scans: {num_scan}')

    # Create the shared data buffers
    # FIX: just one spec scan at this point
    assert len(self.config.spec_scans) == 1
    scan = self.config.spec_scans[0]
    scan_numbers = scan.scan_numbers
    # The first scan is read up front, its detector data sets the
    # shape and dtype of the buffers
    scanparser = scan.get_scanparser(scan_numbers[0])
    detector_ids = [
        int(d.get_id()) for d in self.detector_config.detectors]
    ddata, placeholder_used = scanparser.get_detector_data(
        detector_ids, placeholder_data=placeholder_data)
    spec_scan_shape = scanparser.spec_scan_shape
    # Number of points per scan
    num_dim = np.prod(spec_scan_shape)
    num_id = len(self.config.independent_dimensions)
    num_sd = len(self.config.all_scalar_data)
    if placeholder_data is not False:
        # One extra scalar data row for the placeholder flags
        num_sd += 1
    if self.num_proc == 1:
        assert num_scan == len(scan_numbers)
        data = np.empty((num_scan, *ddata.shape), dtype=ddata.dtype)
        independent_dimensions = np.empty(
            (num_id, num_scan*num_dim), dtype=np.float64)
        all_scalar_data = np.empty(
            (num_sd, num_scan*num_dim), dtype=np.float64)
    else:
        self.logger.debug(f'Scan offset on processor {rank}: {offset}')
        self.logger.debug(f'Scan numbers on processor {rank}: '
                          f'{list_to_string(scan_numbers)}')
        # Only rank 0 requests memory for a window, every rank then
        # builds its array on the buffer queried for rank 0
        datatype = dtlib.from_numpy_dtype(ddata.dtype)
        itemsize = datatype.Get_size()
        if not rank:
            nbytes = num_scan * np.prod(ddata.shape) * itemsize
        else:
            nbytes = 0
        win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf, itemsize = win.Shared_query(0)
        assert itemsize == datatype.Get_size()
        data = np.ndarray(
            buffer=buf, dtype=ddata.dtype, shape=(num_scan, *ddata.shape))
        datatype = dtlib.from_numpy_dtype(np.float64)
        itemsize = datatype.Get_size()
        # On the other ranks nbytes keeps the value 0 assigned above
        if not rank:
            nbytes = num_id * num_scan * num_dim * itemsize
        win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf_id, _ = win_id.Shared_query(0)
        independent_dimensions = np.ndarray(
            buffer=buf_id, dtype=np.float64,
            shape=(num_id, num_scan*num_dim))
        if not rank:
            nbytes = num_sd * num_scan * num_dim * itemsize
        win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf_sd, _ = win_sd.Shared_query(0)
        all_scalar_data = np.ndarray(
            buffer=buf_sd, dtype=np.float64,
            shape=(num_sd, num_scan*num_dim))
        # NOTE(review): no synchronization on comm is performed in
        # this method after the buffers are filled - presumably the
        # caller aligns the processors before rank 0 uses the data,
        # confirm

    # Read the raw data
    # (the detector data of the very first scan was read above and is
    # reused in the first iteration)
    init = True
    for scan in self.config.spec_scans:
        for scan_number in scan.scan_numbers:
            if init:
                init = False
            else:
                scanparser = scan.get_scanparser(scan_number)
                assert spec_scan_shape == scanparser.spec_scan_shape
                ddata, placeholder_used = scanparser.get_detector_data(
                    detector_ids, placeholder_data=placeholder_data)
            data[offset] = ddata
            # Range of the flattened map points of the current scan
            start_dim = offset * num_dim
            end_dim = start_dim + num_dim
            for i, dim in enumerate(self.config.independent_dimensions):
                independent_dimensions[i][start_dim:end_dim] = \
                    dim.get_value(
                        scan, scan_number, scan_step_index=-1,
                        relative=False)
            for i, dim in enumerate(self.config.all_scalar_data):
                all_scalar_data[i][start_dim:end_dim] = dim.get_value(
                    scan, scan_number, scan_step_index=-1,
                    relative=False)
            if placeholder_data is not False:
                # The placeholder flags go in the extra last row
                all_scalar_data[-1][start_dim:end_dim] = \
                    placeholder_used
            offset += 1

    # Merge the first two axes of the data array and swap the merged
    # axis with the one that follows it
    # (presumably putting the detector axis first - TODO confirm the
    # axis order of the detector data returned by the scanparser)
    return (
        np.swapaxes(
            data.reshape((np.prod(data.shape[:2]), *data.shape[2:])),
            0, 1),
        independent_dimensions, all_scalar_data)
# @profile
def _read_raw_data(self, comm, num_scan, offset):
    """Read the raw data for a given map configuration.

    With a single processor the detector data is collected in
    nested lists that are converted to a numpy array at the end.
    With multiple processors the arrays are views on MPI shared
    memory windows, and each processor fills in the entries of its
    own scans starting at scan index `offset`.

    :param comm: MPI communicator.
    :type comm: mpi4py.MPI.Comm, optional
    :param num_scan: Number of scans in the map.
    :type num_scan: int
    :param offset: Offset scan number of current processor.
    :type offset: int
    :raises ValueError: If the detector configuration does not
        contain exactly one detector.
    :return: Map's raw data, independent dimensions and scalar
        data (`None` without any scalar data).
    :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
    """
    # Third party modules
    # A failed mpi4py import is ignored here: MPI and dtlib are only
    # referenced in the multi-processor branch below
    try:
        from mpi4py import MPI
        from mpi4py.util import dtlib
    except ImportError:
        pass

    # Local modules
    from CHAP.utils.general import list_to_string

    if comm is None:
        self.num_proc = 1
        rank = 0
    else:
        self.num_proc = comm.Get_size()
        rank = comm.Get_rank()
    if not rank:
        self.logger.debug(f'Number of processors: {self.num_proc}')
        self.logger.debug(f'Number of scans: {num_scan}')

    # Create the shared data buffers
    assert len(self.config.spec_scans) == 1
    scans = self.config.spec_scans[0]
    scan_numbers = scans.scan_numbers
    scanparser = scans.get_scanparser(scan_numbers[0])
    #RV only correct for multiple detectors if the same image sizes
    detectors = self.detector_config.detectors
    if len(detectors) != 1:
        raise ValueError('Multiple detectors not tested yet')
    # FIX eliminate need for testing for self.config.experiment_type
    # in scanparser
    is_tomo = self.config.experiment_type == 'TOMO'
    if is_tomo:
        dtype = np.float32
        # The region of interest is the same for every scan
        if self.detector_config.roi is None:
            detector_roi = [slice(None), slice(None)]
        else:
            detector_roi = self.detector_config.roitoslice()
    else:
        dtype = None
        detector_roi = None

    def read_detector(parser, detector):
        """Read the data of one detector for one scan."""
        if is_tomo:
            return parser.get_detector_data(
                detector.get_id(), detector_roi=detector_roi, dtype=dtype)
        return parser.get_detector_data(detector.get_id())

    # The first scan is read up front, its detector data sets the
    # shape of the buffers
    ddata = read_detector(scanparser, detectors[0])
    num_det = len(detectors)
    num_dim = ddata.shape[0]
    num_id = len(self.config.independent_dimensions)
    num_sd = len(self.config.all_scalar_data)
    all_scalar_data = None
    if self.num_proc == 1:
        assert num_scan == len(scan_numbers)
        # One independent list per detector (multiplying the outer
        # list would make every detector share the same inner list)
        data = [[None]*num_scan for _ in range(num_det)]
        independent_dimensions = np.empty(
            (num_scan, num_id, num_dim), dtype=np.float64)
        if num_sd:
            all_scalar_data = np.empty(
                (num_scan, num_sd, num_dim), dtype=np.float64)
    else:
        self.logger.debug(f'Scan offset on processor {rank}: {offset}')
        self.logger.debug(f'Scan numbers on processor {rank}: '
                          f'{list_to_string(scan_numbers)}')
        # Only rank 0 requests memory for a window, every rank then
        # builds its array on the buffer queried for rank 0
        # NOTE(review): dtype is None for non-TOMO maps, in which
        # case numpy falls back on its default dtype instead of that
        # of the detector data - confirm this is intended
        datatype = dtlib.from_numpy_dtype(dtype)
        itemsize = datatype.Get_size()
        if not rank:
            nbytes = num_det * num_scan * np.prod(ddata.shape) * itemsize
        else:
            nbytes = 0
        win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf, _ = win.Shared_query(0)
        #RV improve memory requirements ala single processor case?
        data = np.ndarray(
            buffer=buf, dtype=dtype,
            shape=(num_det, num_scan, *ddata.shape))
        datatype = dtlib.from_numpy_dtype(np.float64)
        itemsize = datatype.Get_size()
        if not rank:
            nbytes = num_scan * num_id * num_dim * itemsize
        else:
            nbytes = 0
        win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
        buf_id, _ = win_id.Shared_query(0)
        independent_dimensions = np.ndarray(
            buffer=buf_id, dtype=np.float64,
            shape=(num_scan, num_id, num_dim))
        if num_sd:
            if not rank:
                nbytes = num_scan * num_sd * num_dim * itemsize
            else:
                nbytes = 0
            win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
            buf_sd, _ = win_sd.Shared_query(0)
            all_scalar_data = np.ndarray(
                buffer=buf_sd, dtype=np.float64,
                shape=(num_scan, num_sd, num_dim))

    # Read the raw data
    # (the detector data of the very first scan was read above and is
    # reused in the first iteration)
    init = True
    for scans in self.config.spec_scans:
        for scan_number in scans.scan_numbers:
            for i, detector in enumerate(detectors):
                if init:
                    init = False
                    data[i][offset] = ddata
                    del ddata
                else:
                    scanparser = scans.get_scanparser(scan_number)
                    data[i][offset] = read_detector(scanparser, detector)
            for i, dim in enumerate(self.config.independent_dimensions):
                if dim.data_type in ['scan_column',
                                     'detector_log_timestamps']:
                    # Truncated to the number of detector frames
                    independent_dimensions[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1,
                        relative=False)[:num_dim]
                elif dim.data_type in ['smb_par', 'spec_motor',
                                       'expression']:
                    independent_dimensions[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1,
                        relative=False,
                        scalar_data=self.config.scalar_data)
                else:
                    independent_dimensions[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1)
            for i, dim in enumerate(self.config.all_scalar_data):
                all_scalar_data[offset,i] = dim.get_value(
                    scans, scan_number, scan_step_index=-1,
                    relative=False)
            offset += 1

    if self.num_proc == 1:
        data = np.asarray(data)

    # Merge the scan and scan point axes into one map point axis
    data = data.reshape(
        (data.shape[0], np.prod(data.shape[1:3]), *data.shape[3:]))
    independent_dimensions = np.stack(tuple(
        independent_dimensions[:,i].flatten() for i in range(num_id)))
    if num_sd:
        all_scalar_data = np.stack(tuple(
            all_scalar_data[:,i].flatten() for i in range(num_sd)))
    else:
        all_scalar_data = None
    return data, independent_dimensions, all_scalar_data
[docs]
class MPICollectProcessor(Processor):
    """A Processor that gathers, on the root node, the data that was
    distributed over the workers by MPIMapProcessor.
    """
    def process(self, data, comm, root_as_worker=True):
        """Collect data on root node.

        :param data: Input data.
        :type data: list[PipelineData]
        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :param root_as_worker: Use the root node as a worker,
            defaults to `True`.
        :type root_as_worker: bool, optional
        :return: Distributed worker data on the root node.
        :rtype: list
        """
        num_proc = comm.Get_size()
        rank = comm.Get_rank()

        if root_as_worker:
            # The root contributes its own item to the gather
            item = self.get_pipelinedata_item(data)
            if num_proc > 1:
                return comm.gather(item, root=0)
            return item

        #FIX RV TODO Merge the list of data items in some generic fashion
        if num_proc <= 1:
            # No workers to collect from, pass the input through
            return data
        if rank:
            # Worker: ship the item to the root and return nothing
            comm.send(self.get_pipelinedata_item(data), dest=0)
            return None
        # Root: receive one item per worker, in order of worker rank
        return [comm.recv(source=worker) for worker in range(1, num_proc)]
[docs]
class MPIMapProcessor(Processor):
    """A Processor that applies a parallel generic sub-pipeline to
    a map configuration.
    """
    def process(self, data, config=None, sub_pipeline=None):
        """Run a parallel generic sub-pipeline.

        The scans of the map are divided over the processors in
        contiguous portions, with the first processors getting one
        scan more when the number of scans is not a multiple of the
        number of processors.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            :class:`~CHAP.common.models.map.MapConfig`.
        :type config: dict, optional
        :param sub_pipeline: Sub-pipeline.
        :type sub_pipeline: Pipeline, optional
        :raises NotImplementedError: Always, until this processor
            is updated and tested.
        :return: `data` field of the first item in the returned
            list of sub-pipeline items.
        :rtype: Any
        """
        # Third party modules
        from mpi4py import MPI

        # Local modules
        from CHAP.models import RunConfig
        from CHAP.runner import run
        from CHAP.common.models.map import SpecScans

        raise NotImplementedError('MPIMapProcessor needs updating and testing')

        # pylint: disable=unreachable
        # pylint: disable=c-extension-no-member
        comm = MPI.COMM_WORLD
        num_proc = comm.Get_size()
        rank = comm.Get_rank()

        # Get the validated map configuration
        map_config = self.get_config(
            data=data, config=config, schema='common.models.map.MapConfig')

        # Create the spec reader configuration for each processor
        # FIX: catered to EDD with one spec scan
        assert len(map_config.spec_scans) == 1
        spec_scans = map_config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        scans_per_proc = num_scan//num_proc
        # The first num_extra processors each take one extra scan,
        # which has to be accounted for in the start of the portion
        # of every processor that follows them
        num_extra = num_scan - scans_per_proc*num_proc
        n_scan = rank*scans_per_proc + min(rank, num_extra)
        num = scans_per_proc
        if rank < num_extra:
            num += 1
        scan_numbers = scan_numbers[n_scan:n_scan+num]
        spec_config = {
            'station': map_config.station,
            'experiment_type': map_config.experiment_type,
            'spec_scans': [SpecScans(
                spec_file=spec_scans.spec_file, scan_numbers=scan_numbers)]}

        # Get the run configuration to use for the sub-pipeline
        if sub_pipeline is None:
            sub_pipeline = {}
        run_config = {'inputdir': self.inputdir, 'outputdir': self.outputdir,
                      'interactive': self.interactive, 'log_level': self.log_level}
        # The sub-pipeline may come without a run configuration
        run_config.update(sub_pipeline.get('config') or {})
        run_config = RunConfig(**run_config, comm=comm)
        pipeline_config = []
        # Work on a copy to leave the sub-pipeline of the caller intact
        for item in deepcopy(sub_pipeline.get('pipeline', [])):
            if isinstance(item, dict):
                for k, v in item.items():
                    if k.endswith('Reader'):
                        # Each processor reads its own scans
                        v['config'] = spec_config
                    if num_proc > 1 and k.endswith('Writer'):
                        # Each processor writes to its own file
                        r, e = os.path.splitext(v['filename'])
                        v['filename'] = f'{r}_{rank}{e}'
            pipeline_config.append(item)

        # Run the sub-pipeline on each processor
        return run(run_config, pipeline_config, logger=self.logger, comm=comm)
[docs]
class MPISpawnMapProcessor(Processor):
    """A Processor that applies a parallel generic sub-pipeline to
    a map configuration by spawning workers processes.
    """
    def process(
            self, data, num_proc=1, root_as_worker=True, collect_on_root=False,
            sub_pipeline=None):
        """Spawn workers running a parallel generic sub-pipeline.

        :param data: Input data.
        :type data: list[PipelineData]
        :param num_proc: Number of spawned processors, defaults to `1`.
        :type num_proc: int, optional
        :param root_as_worker: Use the root node as a worker,
            defaults to `True`.
        :type root_as_worker: bool, optional
        :param collect_on_root: Collect the result of the spawned
            workers on the root node, defaults to `False`.
        :type collect_on_root: bool, optional
        :param sub_pipeline: Sub-pipeline.
        :type sub_pipeline: Pipeline, optional
        :raises NotImplementedError: Always, until this processor
            is updated and tested.
        :return: `data` field of the first item in the returned
            list of sub-pipeline items.
        """
        # Third party modules
        from mpi4py import MPI
        import yaml

        # Local modules
        from CHAP.models import RunConfig
        from CHAP.runner import runner
        from CHAP.common.models.map import SpecScans

        raise NotImplementedError(
            'MPISpawnMapProcessor needs updating and testing')

        # pylint: disable=unreachable
        # Get the map configuration from data
        map_config = self.get_config(
            data=data, schema='common.models.map.MapConfig')

        # Get the run configuration to use for the sub-pipeline
        # Optionally include the root node as a worker node
        if sub_pipeline is None:
            sub_pipeline = {}
        run_config = {'inputdir': self.inputdir, 'outputdir': self.outputdir,
                      'interactive': self.interactive, 'log_level': self.log_level}
        # The sub-pipeline may come without a run configuration
        run_config.update(sub_pipeline.get('config') or {})
        if root_as_worker:
            # The root runs the first sub-pipeline itself
            first_proc = 1
            spawn = 1
        else:
            first_proc = 0
            spawn = -1
        run_config = RunConfig(**run_config, logger=self.logger, spawn=spawn)

        # Create the sub-pipeline configuration for each processor
        spec_scans = map_config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        scans_per_proc = num_scan//num_proc
        n_scan = 0
        pipeline_config = []
        for n_proc in range(num_proc):
            # The first processors take one extra scan when the scans
            # cannot be divided evenly
            num = scans_per_proc
            if n_proc < num_scan - scans_per_proc*num_proc:
                num += 1
            spec_config = {
                'station': map_config.station,
                'experiment_type': map_config.experiment_type,
                'spec_scans': [SpecScans(
                    spec_file=spec_scans.spec_file,
                    scan_numbers=scan_numbers[n_scan:n_scan+num]).__dict__]}
            sub_pipeline_config = []
            for item in deepcopy(sub_pipeline.get('pipeline', [])):
                if isinstance(item, dict):
                    for k, v in deepcopy(item).items():
                        if k.endswith('Reader'):
                            # Each processor reads its own scans
                            v['config'] = spec_config
                            item[k] = v
                        if num_proc > 1 and k.endswith('Writer'):
                            # Each processor writes to its own file
                            r, e = os.path.splitext(v['filename'])
                            v['filename'] = f'{r}_{n_proc}{e}'
                            item[k] = v
                sub_pipeline_config.append(item)
            if collect_on_root and (not root_as_worker or num_proc > 1):
                sub_pipeline_config += [
                    {'common.MPICollectProcessor': {
                        'root_as_worker': root_as_worker}}]
            pipeline_config.append(sub_pipeline_config)
            n_scan += num

        # Spawn the workers to run the sub-pipeline
        if num_proc > first_proc:
            # System modules
            from tempfile import NamedTemporaryFile

            tmp_names = []
            with NamedTemporaryFile(delete=False) as fp:
                # pylint: disable=c-extension-no-member
                # The spawned program gets the name of the main file,
                # the file of each worker has its number appended
                fp_name = fp.name
                tmp_names.append(fp_name)
                with open(fp_name, 'w', encoding='utf-8') as f:
                    yaml.dump(
                        {'config': {'spawn': run_config.spawn}}, f,
                        sort_keys=False)
                for n_proc in range(first_proc, num_proc):
                    f_name = f'{fp_name}_{n_proc}'
                    tmp_names.append(f_name)
                    with open(f_name, 'w', encoding='utf-8') as f:
                        yaml.dump(
                            #FIX once comm is a field of RunConfig
                            #{'config': run_config.model_dump(exclude='comm'),
                            {'config': run_config.model_dump(),
                             'pipeline': pipeline_config[n_proc]},
                            f, sort_keys=False)
                # pylint: disable=used-before-assignment
                sub_comm = MPI.COMM_SELF.Spawn(
                    'CHAP', args=[fp_name], maxprocs=num_proc-first_proc)
                common_comm = sub_comm.Merge(False)
                if run_config.spawn > 0:
                    # Align with the barrier in RunConfig() on common_comm
                    # called from the spawned main()
                    common_comm.barrier()
        else:
            common_comm = None

        # Run the sub-pipeline on the root node
        if root_as_worker:
            data = runner(run_config, pipeline_config[0], comm=common_comm)
        elif collect_on_root:
            run_config.spawn = 0
            pipeline_config = [{'common.MPICollectProcessor': {
                'root_as_worker': root_as_worker}}]
            data = runner(run_config, pipeline_config, common_comm)
        else:
            # Align with the barrier in run() on common_comm
            # called from the spawned main()
            common_comm.barrier()
            data = None

        # Disconnect spawned workers and cleanup temporary files
        if num_proc > first_proc:
            # Align with the barrier in main() on common_comm
            # when disconnecting the spawned worker
            common_comm.barrier()
            # Disconnect spawned workers and cleanup temporary files
            sub_comm.Disconnect()
            for tmp_name in tmp_names:
                os.remove(tmp_name)

        return data
[docs]
class NexusToNumpyProcessor(Processor):
    """A Processor that extracts the default plottable data of a
    NeXus style
    `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
    object as a `numpy.ndarray`.
    """
    def process(self, data):
        """Return the default plottable data signal in a NeXus style
        `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
        object contained in `data` as an `numpy.ndarray`.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises ValueError: If `data` has no default plottable data
            signal.
        :return: Default plottable data signal.
        :rtype: numpy.ndarray
        """
        # Third party modules
        from nexusformat.nexus import NXdata

        nxobject = self.get_pipelinedata_item(data)

        # An NXdata group is used as is, for anything else look up
        # the default plottable data
        default_data = nxobject
        if not isinstance(nxobject, NXdata):
            default_data = nxobject.plottable_data
            if default_data is None:
                # Fall back on the path in the "default" attribute
                default_data = nxobject.get(nxobject.attrs.get('default'))
            if default_data is None:
                raise ValueError(
                    f'The structure of {nxobject} contains no default data')

        try:
            signal_name = default_data.attrs['signal']
        except Exception as exc:
            raise ValueError(
                f'The signal of {default_data} is unknown') from exc

        return default_data[signal_name].nxdata
[docs]
class NexusToXarrayProcessor(Processor):
    """A Processor that extracts the default plottable data of a
    NeXus style
    `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
    object as an `xarray.DataArray`.
    """
    def process(self, data):
        """Return the default plottable data signal in a NeXus style
        `NXobject <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__,
        object contained in `data` as an `xarray.DataArray`.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises ValueError: If metadata for `xarray` is absent from
            `data`
        :return: Default plottable data signal.
        :rtype: xarray.DataArray
        """
        # Third party modules
        from nexusformat.nexus import NXdata
        # pylint: disable=import-error
        from xarray import DataArray
        # pylint: enable=import-error

        nxobject = self.get_pipelinedata_item(data)

        # An NXdata group is used as is, for anything else look up
        # the default plottable data
        default_data = nxobject
        if not isinstance(nxobject, NXdata):
            default_data = nxobject.plottable_data
            if default_data is None:
                # Fall back on the path in the "default" attribute
                default_data = nxobject.get(nxobject.attrs.get('default'))
            if default_data is None:
                raise ValueError(
                    f'The structure of {nxobject} contains no default data')

        try:
            signal_name = default_data.attrs['signal']
        except Exception as exc:
            raise ValueError(
                f'The signal of {default_data} is unknown') from exc
        signal_field = default_data[signal_name]
        signal_data = signal_field.nxdata

        # A single axis may be given as a string
        axes = default_data.attrs['axes']
        if isinstance(axes, str):
            axes = [axes]

        # Each coordinate is a (dimension name, values, attributes)
        # tuple
        coords = {}
        for axis_name in axes:
            axis_field = default_data[axis_name]
            coords[axis_name] = (
                axis_name, axis_field.nxdata, axis_field.attrs)

        return DataArray(
            data=signal_data, coords=coords, dims=tuple(axes),
            name=signal_name, attrs=signal_field.attrs)
[docs]
class NexusToZarrProcessor(Processor):
    """Converter for `NeXus <https://www.nexusformat.org>`__ to
    `Zarr <https://zarr.readthedocs.io/en/stable/>`__
    format.
    """
    def process(self, data, chunks='auto'):
        """Copy and return a
        `Zarr group <https://zarr.readthedocs.io/en/stable/api/zarr/group/#zarr.Group>`__
        object from a NeXus style
        `NXgroup <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
        object.

        :param data: Input data.
        :type data: list[PipelineData]
        :param chunks: Chunk specification passed on when creating
            the Zarr arrays. An integer is treated as a one item
            list, and a list shorter than the number of dimensions
            of a dataset is padded with the sizes of its remaining
            dimensions, defaults to `'auto'`.
        :type chunks: int or list or str, optional
        :return: Zarr style group object.
        :rtype: zarr.Group
        """
        # Third party modules
        from nexusformat.nexus import (
            NXfield,
            NXgroup,
        )
        # pylint: disable=import-error
        import zarr
        from zarr.storage import MemoryStore
        # pylint: enable=import-error

        nexus_group = self.get_data(data)
        if isinstance(chunks, int):
            chunks = [chunks]
        zarr_group = zarr.create_group(store=MemoryStore({}))

        def resolve_chunks(shape):
            """Return the chunk specification to use for a dataset
            with the given shape.
            """
            if not isinstance(chunks, list):
                return chunks
            if len(chunks) < len(shape):
                # Use full size chunks along the unspecified axes
                return (*chunks, *shape[len(chunks):])
            if len(chunks) > len(shape):
                return 'auto'
            return chunks

        def copy_group(nexus_group, zarr_group):
            """Recursively copy a NeXus style
            `NXgroup <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
            object to a
            `Zarr group <https://zarr.readthedocs.io/en/stable/api/zarr/group/#zarr.Group>`__
            object.

            :param nexus_group: Source NeXus style `NXgroup`.
            :type nexus_group: nexusformat.nexus.NXgroup
            :param zarr_group: Destination Zarr group.
            :type zarr_group: zarr.Group
            """
            self.logger.info(f'Copying {nexus_group.nxpath}')

            # Copy attributes (arrays are converted to lists)
            for attr_key, attr_value in nexus_group.attrs.items():
                value = attr_value.nxvalue
                if isinstance(value, np.ndarray):
                    value = value.tolist()
                zarr_group.attrs[attr_key] = value

            # Copy datasets and sub-groups
            for key, item in nexus_group.items():
                if isinstance(item, NXgroup):
                    # Recursively copy subgroup
                    copy_group(item, zarr_group.create_group(key))
                    continue
                if not isinstance(item, NXfield):
                    continue
                array = item.nxdata
                if not isinstance(array, np.ndarray):
                    self.logger.warning(f'Ignoring {item.nxpath}')
                    continue
                # A dataset that fails to copy is logged and skipped
                try:
                    zarr_dset = zarr_group.create_array(
                        name=key,
                        shape=array.shape,
                        dtype=array.dtype,
                        attributes={
                            k: v.nxvalue for k, v in item.attrs.items()},
                        chunks=resolve_chunks(array.shape),
                    )
                    self.logger.info(f'Copying {item.nxpath}')
                    zarr_dset[:] = array
                except Exception as exc:
                    self.logger.error(f'{item.nxpath}: {exc}')

        copy_group(nexus_group, zarr_group)
        return zarr_group
[docs]
class NormalizeNexusProcessor(Processor):
    """Processor for scaling one or more
    `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
    objects in the input NeXus style
    `NXgroup <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
    object by the values of another NXfield in the same object.
    """

    def process(self, data, normalize_nxfields, normalize_by_nxfield):
        """Return copy of the original input NeXus style
        `NXgroup <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXgroup>`__
        object with additional fields containing the normalized data
        of each field in `normalize_nxfields`.

        For every field `<path>` that can be normalized, a new field
        `<path>_normalized` is added carrying the attributes of the
        original field plus a `normalized_by` attribute. Fields that
        are missing, are not an NXfield or have an incompatible shape
        are reported through the logger and skipped.

        :param data: Input data.
        :type data: list[PipelineData]
        :param normalize_nxfields: Paths in `data` to the fields to
            normalize. The shape of the normalization data must match
            the leading dimensions of each of these fields.
        :type normalize_nxfields: list[str]
        :param normalize_by_nxfield: Path in `data` to the
            `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            containing normalization data
        :type normalize_by_nxfield: str
        :raises TypeError: If the input data is not an NXgroup or if
            `normalize_by_nxfield` is not an NXfield.
        :raises ValueError: If `normalize_by_nxfield` is not present
            in the input data.
        :returns: Copy of input data with additional normalized fields
        :rtype: nexusformat.nexus.NXgroup
        """
        # Third party modules
        from nexusformat.nexus import (
            NXgroup,
            NXfield,
        )

        # Local modules
        from CHAP.utils.general import nxcopy

        # Check input data
        data = nxcopy(self.get_pipelinedata_item(data))
        if not isinstance(data, NXgroup):
            raise TypeError(f'Expected NXgroup, got {type(data)}')

        # Check normalize_by_nxfield
        if normalize_by_nxfield not in data:
            raise ValueError(
                f'{normalize_by_nxfield} not present in input data')
        if not isinstance(data[normalize_by_nxfield], NXfield):
            raise TypeError(
                f'{normalize_by_nxfield} is {type(data[normalize_by_nxfield])}'
                + ', expected NXfield')
        normalization_data = data[normalize_by_nxfield].nxdata

        # Process normalize_nxfields
        for nxfield in normalize_nxfields:
            if nxfield not in data:
                self.logger.error(f'{nxfield} not present in input data')
                continue
            field = data[nxfield]
            if not isinstance(field, NXfield):
                self.logger.error(
                    f'{nxfield} is {type(field)}, expected NXfield')
                continue
            field_data = field.nxdata
            if (normalization_data.shape
                    != field_data.shape[:normalization_data.ndim]):
                self.logger.error(
                    f'Incompatible dataset shapes: {normalize_by_nxfield} '
                    + f'is {normalization_data.shape}, '
                    + f'{nxfield} is {field_data.shape}'
                )
                continue
            self.logger.info(f'Normalizing {nxfield}')
            # Append singleton dimensions so that the normalization
            # data broadcasts over the trailing dimensions of the field
            _normalization_data = normalization_data.reshape(
                normalization_data.shape
                + (1,) * (field_data.ndim - normalization_data.ndim))
            data[f'{nxfield}_normalized'] = NXfield(
                value=field_data / _normalization_data,
                attrs={**field.attrs,
                       'normalized_by': normalize_by_nxfield}
            )
        return data
[docs]
class NormalizeMapProcessor(Processor):
    """Processor for calling
    :class:`~CHAP.common.processor.NormalizeNexusProcessor` for
    (usually all) detector data in a NeXus style
    `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
    object created by :class:`~CHAP.common.processor.MapProcessor`
    """

    def process(self, data, normalize_by_nxfield, detector_ids=None):
        """Return copy of the original input map with additional
        fields containing normalized detector data.

        The first NXentry found in the input is taken as the map
        entry, and the fields `<map title>/data/<detector id>` are
        normalized by
        :class:`~CHAP.common.processor.NormalizeNexusProcessor`.

        :param data: Input data.
        :type data: list[PipelineData]
        :param normalize_by_nxfield: Path in `data` to the
            `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            containing normalization data.
        :type normalize_by_nxfield: str
        :param detector_ids: Names of the fields in the `data` group
            of the map entry to normalize. Defaults to `None`, in
            which case every item of that group that is not an NXlink
            is used.
        :type detector_ids: list[str], optional
        :raises ValueError: If the input data contains no NXentry.
        :returns: Copy of input data with additional normalized fields.
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        from nexusformat.nexus import (
            NXentry,
            NXlink,
        )

        # Check input data
        data = self.get_pipelinedata_item(data)
        map_title = next(
            (k for k, v in data.items() if isinstance(v, NXentry)), None)
        if map_title is None:
            # Nothing can be normalized without a map entry, so fail
            # here instead of building paths that start with 'None'
            raise ValueError('Input data contains no NXentry')
        self.logger.info(f'Got map_title: {map_title}')

        # Check detector_ids
        if detector_ids is None:
            nxdata = data[map_title].data
            detector_ids = [k for k in nxdata.keys()
                            if not isinstance(nxdata[k], NXlink)]
        self.logger.info(f'Using detector_ids: {detector_ids}')
        normalize_nxfields = [f'{map_title}/data/{_id}'
                              for _id in detector_ids]

        # Normalize
        normalizer = NormalizeNexusProcessor()
        normalizer.logger = self.logger
        return normalizer.process(
            data, normalize_nxfields, normalize_by_nxfield)
[docs]
class PandasToXarrayProcessor(Processor):
    """Processor turning a `pandas.DataFrame` into its xarray
    counterpart (`xarray.DataArray` or `xarray.Dataset`).
    """

    def process(self, data):
        """Convert the dataframe found in the input data to xarray.

        :param data: Input data.
        :type data: list[PipelineData]
        :returns: The xarray representation of the input dataframe.
        :rtype: xarray.DataArray | xarray.Dataset
        """
        # The conversion itself is delegated to the dataframe
        return self.get_data(data).to_xarray()
[docs]
class PrintProcessor(Processor):
    """A pass-through Processor that writes its input to stdout and
    hands the very same input back, untouched.
    """

    def process(self, data):
        """Write the input data to stdout and return it.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: `data`
        :rtype: Any
        """
        tree_printer = getattr(data, '_str_tree', None)
        if not callable(tree_printer):
            # System modules
            from pprint import pprint

            pprint(data)
            return data
        # An object with a `_str_tree` method is likely a NXobject,
        # whose plain str representation is just its nxname, so show
        # the full tree instead
        print(tree_printer(attrs=True, recursive=True))
        return data
[docs]
class PyfaiAzimuthalIntegrationProcessor(Processor):
    """Processor to azimuthally integrate one or more frames of 2d
    detector data using the
    `pyFAI <https://pyfai.readthedocs.io/en/stable>`__ package.
    """

    def process(
            self, data, poni_file, npt, mask_file=None,
            integrate1d_kwargs=None):
        """Azimuthally integrate the detector data provided and return
        one pyFAI `integrate1d` result per input frame.

        :param data: Detector data to integrate.
        :type data: PipelineData | list[np.ndarray]
        :param poni_file: Name of the pyFAI PONI file containing the
            detector properties pyFAI needs to perform azimuthal
            integration. A relative name is resolved against
            `self.inputdir`.
        :type poni_file: str
        :param npt: Number of points in the output pattern.
        :type npt: int
        :param mask_file: File to use for masking the input data. A
            relative name is resolved against `self.inputdir`. When
            given, it takes precedence over any `'mask'` entry in
            `integrate1d_kwargs`.
        :type mask_file: str, optional
        :param integrate1d_kwargs: Optional dictionary of keywords
            passed on to `integrate1d`. The dictionary is not
            modified.
        :type integrate1d_kwargs: Optional[dict]
        :returns: Azimuthal integration results, one item per frame in
            the order of the input frames.
        :rtype: list
        """
        # Third party modules
        from pyFAI import load

        if not os.path.isabs(poni_file):
            poni_file = os.path.join(self.inputdir, poni_file)
        ai = load(poni_file)

        if mask_file is None:
            mask = None
        else:
            # Third party modules
            import fabio

            if not os.path.isabs(mask_file):
                mask_file = os.path.join(self.inputdir, mask_file)
            mask = fabio.open(mask_file).data

        # Fall back to the raw input if it is not pipeline data
        try:
            det_data = self.get_pipelinedata_item(data)
        except ValueError:
            det_data = data

        # Work on a copy so the caller's dictionary is left untouched
        kwargs = {} if integrate1d_kwargs is None else dict(integrate1d_kwargs)
        if mask is not None:
            kwargs['mask'] = mask
        return [ai.integrate1d(d, npt, **kwargs) for d in det_data]
[docs]
class RawDetectorDataMapProcessor(Processor):
    """A Processor that builds a map of raw detector data as a
    NeXus style
    `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
    object.
    """

    def process(self, data, detector_name, detector_shape):
        """Read the map configuration from the input and return the
        raw detector data collected over that map.

        :param data: Input data.
        :type data: list[PipelineData]
        :param detector_name: Detector prefix.
        :type detector_name: str
        :param detector_shape: Detector data shape for a single scan
            step.
        :type detector_shape: list
        :return: Map of raw detector data.
        :rtype: nexusformat.nexus.NXroot
        """
        return self.get_nxroot(
            self.get_config(data), detector_name, detector_shape)

    def get_config(self, data):
        """Build the map configuration object this `Processor` needs
        from the input data.

        :param data: Result of `Reader.read` where at least one item
            has the value `'common.models.map.MapConfig'` for the
            `'schema'` key. If several items match, the last one is
            used.
        :type data: list[PipelineData]
        :raises ValueError: If no map configuration is found in
            `data`.
        :return: Valid instance of the map configuration object with
            field values taken from `data`.
        :rtype: MapConfig
        """
        candidates = []
        if isinstance(data, list):
            candidates = [
                item.get('data') for item in data
                if isinstance(item, dict)
                and item.get('schema') == 'common.models.map.MapConfig']
        # The last matching item wins
        map_config = candidates[-1] if candidates else False
        if not map_config:
            raise ValueError('No map configuration found in input data')
        return MapConfig(**map_config)

    def get_nxroot(self, map_config, detector_name, detector_shape):
        """Return a map of the detector data collected by the scans
        in `map_config`, together with some relevant metadata, as a
        NeXus style
        `NXroot <https://manual.nexusformat.org/classes/base_classes/NXroot.html#index-0>`__
        structure.

        This method currently always raises a `RuntimeError`; the
        code after the `raise` statement is never executed.

        :param map_config: Map configuration.
        :type map_config: MapConfig
        :param detector_name: Detector prefix.
        :type detector_name: str
        :param detector_shape: Detector data shape for a single scan
            step.
        :type detector_shape: list
        :raises RuntimeError: Always.
        :return: Map of the raw detector data.
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        # pylint: disable=no-name-in-module
        from nexusformat.nexus import (
            NXdata,
            NXdetector,
            NXinstrument,
            NXroot,
        )
        # pylint: enable=no-name-in-module

        raise RuntimeError('Not updated for the new MapProcessor')

        # Unreachable until the method is updated
        nxroot = NXroot()
        nxroot[map_config.title] = MapProcessor.get_nxentry(map_config)
        nxentry = nxroot[map_config.title]

        nxentry.instrument = NXinstrument()
        nxentry.instrument.detector = NXdetector()
        nxentry.instrument.detector.data = NXdata()

        nxdata = nxentry.instrument.detector.data
        nxdata.raw = np.empty((*map_config.shape, *detector_shape))
        nxdata.raw.attrs['units'] = 'counts'
        for i, det_axis_size in enumerate(detector_shape):
            nxdata[f'detector_axis_{i}_index'] = np.arange(det_axis_size)

        for map_index in np.ndindex(map_config.shape):
            scans, scan_number, scan_step_index = \
                map_config.get_scan_step_index(map_index)
            scanparser = scans.get_scanparser(scan_number)
            self.logger.debug(
                f'Adding data to nxroot for map point {map_index}')
            nxdata.raw[map_index] = scanparser.get_detector_data(
                detector_name, scan_step_index)

        nxentry.data.makelink(nxdata.raw, name=detector_name)
        for i, det_axis_size in enumerate(detector_shape):
            axis_name = f'{detector_name}_axis_{i}_index'
            nxentry.data.makelink(
                nxdata[f'detector_axis_{i}_index'], name=axis_name)
            if isinstance(nxentry.data.attrs['axes'], str):
                nxentry.data.attrs['axes'] = [
                    nxentry.data.attrs['axes'], axis_name]
            else:
                nxentry.data.attrs['axes'] += [axis_name]

        nxentry.data.attrs['signal'] = detector_name
        return nxroot
[docs]
class SetupNXdataProcessor(Processor):
    """Processor to set up and return an "empty" representation of a
    structured dataset. This representation will be an instance of a
    NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object that has: A NeXus style
    `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
    entry for every coordinate/signal specified. `nxaxes` that are the
    `NXfield` entries for the coordinates and contain the values
    provided for each coordinate. NXfield entries of appropriate
    shape, but containing all zeros, for every signal. Attributes that
    define the axes, plus any additional attributes specified by the
    user.

    This `Processor` is most useful as a "setup" step for constructing
    a representation of / container for a complete dataset that will
    be filled out in pieces later by
    :class:`~CHAP.common.processor.UpdateNXdataProcessor`.
    """

    def process(
            self, data, nxname='data', coords=None, signals=None, attrs=None,
            data_points=None, extra_nxfields=None, duplicates='overwrite'):
        """Return a NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        object that has the requisite axes and
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        entries to represent a structured dataset with the
        properties provided. Properties may be provided either through
        the `data` argument (from an appropriate `PipelineItem` that
        immediately precedes this one in a `Pipeline`), or through the
        `coords`, `signals`, `attrs`, and/or `data_points` arguments.
        If any of the latter are used, their values will completely
        override any values for these parameters found from `data`.

        :param data: Input data.
        :type data: list[PipelineData]
        :param nxname: Name for the returned NXdata object,
            defaults to `'data'`.
        :type nxname: str, optional
        :param coords: List of dictionaries defining the coordinates
            of the dataset. Each dictionary must have the keys
            `'name'` and `'values'`, whose values are the name of the
            coordinate axis (a string) and all the unique values of
            that coordinate for the structured dataset (a list of
            numbers), respectively. A third item in the dictionary is
            optional, but highly recommended: `'attrs'` may provide a
            dictionary of attributes to attach to the coordinate axis
            that assist in interpreting the returned NXdata
            representation of the dataset. It is strongly recommended
            to provide the units of the values along an axis in the
            `attrs` dictionary. An optional `'dtype'` item sets the
            data type of the axis field (`'float64'` by default).
        :type coords: list[dict[str, object]], optional
        :param signals: List of dictionaries defining the signals of
            the dataset. Each dictionary must have the keys `'name'`
            and `'shape'`, whose values are the name of the signal
            field (a string) and the shape of the signal's value at
            each point in the dataset (a list of zero or more
            integers), respectively. A third item in the dictionary is
            optional, but highly recommended: `'attrs'` may provide a
            dictionary of attributes to attach to the signal field
            that assist in interpreting the returned NXdata
            representation of the dataset. It is strongly recommended
            to provide the units of the signal's values in the
            `attrs` dictionary. An optional `'dtype'` item sets the
            data type of the signal field (`'float64'` by default).
        :type signals: list[dict[str, object]], optional
        :param attrs: An arbitrary dictionary of attributes to assign
            to the returned NXdata object.
        :type attrs: dict[str, object], optional
        :param data_points: Data points to partially (or even entirely)
            fill out the "empty" signal NXfields before
            returning the NXdata object.
        :type data_points: list[dict[str, object]], optional
        :param extra_nxfields: List of "extra" NXfields to include
            that can be described neither as a signal of the dataset,
            nor a dedicated coordinate. This parameter is good for
            including "alternate" values for one of the coordinate
            dimensions -- the same coordinate axis expressed in
            different units, for instance. Each item in the list should
            be a dictionary of parameters for the
            `nexusformat.nexus.NXfield` constructor.
        :type extra_nxfields: list[dict[str, object]], optional
        :param duplicates: Behavior to use if any new data points occur
            at the same point in the dataset's coordinate space as an
            existing data point. Allowed values for `duplicates` are:
            `'overwrite'` and `'block'`. Defaults to `'overwrite'`.
            NOTE: the value is stored on the instance, but the methods
            of this class do not consult it; a later data point always
            overwrites an earlier one.
        :type duplicates: Literal['overwrite', 'block']
        :returns: Structured dataset as specified.
        :rtype: nexusformat.nexus.NXdata
        """
        self.nxname = nxname
        self.coords = [] if coords is None else coords
        self.signals = [] if signals is None else signals
        self.attrs = {} if attrs is None else attrs
        self.data_points = data_points
        self.extra_nxfields = [] if extra_nxfields is None else extra_nxfields
        self.duplicates = duplicates

        # Parameters from the pipeline are used only for those that
        # were not supplied (or were supplied empty) as arguments
        try:
            setup_params = self.get_pipelinedata_item(data)
        except Exception:
            setup_params = None
        if isinstance(setup_params, dict):
            for a in ('coords', 'signals', 'attrs', 'data_points'):
                setup_param = setup_params.get(a)
                if not getattr(self, a) and setup_param is not None:
                    self.logger.info(f'Using input data from pipeline for {a}')
                    setattr(self, a, setup_param)
                else:
                    self.logger.info(
                        f'Ignoring input data from pipeline for {a}')
        else:
            self.logger.warning('Ignoring all input data from pipeline')

        self.shape = tuple(len(c['values']) for c in self.coords)
        self.init_nxdata()
        if self.data_points is not None:
            for d in self.data_points:
                self.add_data_point(d)
        return self.nxdata

    def add_data_point(self, data_point):
        """Add a data point to this dataset.

        1. Validate `data_point`.
        2. Update signal NXfields in `self.nxdata`.

        An invalid data point is reported through the logger and
        skipped.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        """
        # The running index is informational only, so a data point
        # without it must not prevent the point from being added
        if 'dataset_point_index' in data_point and self.data_points is not None:
            self.logger.info(
                'Adding data point no. '
                f'{data_point["dataset_point_index"]+1} of '
                f'{len(self.data_points)}')
        else:
            self.logger.info('Adding data point')
        self.logger.debug(f'New data point: {data_point}')
        valid, msg = self.validate_data_point(data_point)
        if not valid:
            self.logger.error(f'Cannot add data point: {msg}')
        else:
            self.update_nxdata(data_point)

    def validate_data_point(self, data_point):
        """Return `True` if `data_point` occurs at a valid point in
        this structured dataset's coordinate space, `False` otherwise.
        Also validate the shapes of the signal values that are
        present. Signals missing from `data_point` are accepted; they
        are left unchanged in the dataset.

        :param data_point: Data point defining a point in the dataset's
            coordinate space and the new signal values at that point.
        :type data_point: dict[str, object]
        :returns: Validity of `data_point`, message
        :rtype: bool, str
        """
        # Ensure data_point defines a specific point in the dataset's
        # coordinate space
        if not all(c['name'] in data_point for c in self.coords):
            return False, 'Missing coordinate values'
        # Ensure that point is one of the known coordinate values
        try:
            self.get_index(data_point)
        except ValueError:
            return False, 'Coordinate values not found in the dataset'
        # Ensure the signal values that are present have the right
        # shape
        for s in self.signals:
            name = s['name']
            if (name in data_point
                    and np.asarray(data_point[name]).shape
                    != tuple(s['shape'])):
                return False, f'Shape mismatch for signal {s}'
        return True, ''

    def init_nxdata(self):
        """Initialize an empty NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html>`__
        representing this dataset to `self.nxdata`; values for axes
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        objects are filled out, values for signals' NXfields are all
        zero and can be filled out later.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        axes = tuple(NXfield(
            value=c['values'],
            name=c['name'],
            attrs=c.get('attrs'),
            dtype=c.get('dtype', 'float64')) for c in self.coords)
        entries = {s['name']: NXfield(
            value=np.full((*self.shape, *s['shape']), 0),
            name=s['name'],
            attrs=s.get('attrs'),
            dtype=s.get('dtype', 'float64')) for s in self.signals}
        # Extra fields are keyed by the name of the constructed field
        extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
        entries.update({f.nxname: f for f in extra_nxfields})
        self.nxdata = NXdata(
            name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)

    def update_nxdata(self, data_point):
        """Update `self.nxdata`'s
        `NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
        values. Only the signals present in `data_point` are written.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        """
        index = self.get_index(data_point)
        for s in self.signals:
            name = s['name']
            if name in data_point:
                self.nxdata[name][index] = data_point[name]

    def get_index(self, data_point):
        """Return a tuple representing the array index of `data_point`
        in the coordinate space of the dataset.

        :param data_point: Data point defining a point in the
            dataset's coordinate space.
        :type data_point: dict[str, object]
        :raises ValueError: If a coordinate value of `data_point` is
            not among the values of the corresponding coordinate.
        :returns: Multi-dimensional index of `data_point` in the
            dataset's coordinate space.
        :rtype: tuple
        """
        # Going through list() makes this work for coordinate values
        # given as a numpy array or tuple as well as a list
        return tuple(list(c['values']).index(data_point[c['name']])
                     for c in self.coords)
[docs]
class UnstructuredToStructuredProcessor(Processor):
"""Processor to reshape data in a NeXus style
`NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
object from an "unstructured" to a "structured" representation.
"""
[docs]
def process(self, data, nxpath=None):
"""Reshape the input data from an "unstructured" to a
"structured" representation.
:param data: Input data.
:type data: list[PipelineData]
:param nxname: Name for the returned
`NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
object, defaults to `'data'`.
:type nxname: str, optional
:return: Converted data.
:rtype: nexusformat.nexus.NXdata
"""
# NOTE(review): the docstring above documents a parameter `nxname`,
# but the signature takes `nxpath`. Below, `nxpath` is used to index
# into the input object when that object is not itself an NXdata.
# Third party modules
from nexusformat.nexus import NXdata
# Prefer get_data; fall back to get_pipelinedata_item if it raises.
try:
nxobject = self.get_data(data)
except Exception:
nxobject = self.get_pipelinedata_item(data)
# An NXdata input is converted directly.
if isinstance(nxobject, NXdata):
return self._convert_nxdata(nxobject)
# Otherwise `nxpath` has to select the object to convert; a missing
# `nxpath` or a failing lookup is reported as a ValueError.
if nxpath is not None:
try:
nxobject = nxobject[nxpath]
except Exception as exc:
raise ValueError(
f'Invalid parameter nxpath ({nxpath})') from exc
else:
raise ValueError(f'Invalid input data ({data})')
return self._convert_nxdata(nxobject)
def _convert_nxdata(self, nxdata):
"""Convert a NeXus style `NXdata` object from an "unstructured" to a
"structured" representation.
"""
# Outline: determine the unstructured axes and their common length,
# collect the unique values of each axis, create a new NXdata with
# one dimension per axis, then copy every unstructured point into
# the slot given by the position of its coordinates among the unique
# values. The new object is named '<original nxname>_structured'.
# Third party modules
from nexusformat.nexus import (
NXdata,
NXfield,
)
# Local modules
from CHAP.common.map_utils import get_axes
# Extract axes from the NXdata attributes
axes = get_axes(nxdata)
# Every axis returned by get_axes must have a field in nxdata.
for a in axes:
if a not in nxdata:
raise ValueError(f'Missing coordinates for {a}')
# Check the independent dimensions and axes
# The first 1-d axis found sets `unstructured_dim`; a later 1-d axis
# is accepted as unstructured only when its size matches. The
# ValueErrors for mismatching axes are raised only when the NXdata
# carries an 'unstructured_axes' attribute.
unstructured_axes = []
unstructured_dim = None
for a in axes:
if not isinstance(nxdata[a], NXfield):
raise ValueError(
f'Invalid axis field type ({type(nxdata[a])})')
if len(nxdata[a].shape) == 1:
if not unstructured_axes:
unstructured_axes.append(a)
unstructured_dim = nxdata[a].size
else:
if nxdata[a].size == unstructured_dim:
unstructured_axes.append(a)
elif 'unstructured_axes' in nxdata.attrs:
raise ValueError('Inconsistent axes dimensions')
elif 'unstructured_axes' in nxdata.attrs:
raise ValueError(
f'Invalid unstructered axis shape ({nxdata[a].shape})')
# Without any declared axes, the first dimension of the signal is
# taken as the unstructured dimension and every 1-d field of that
# length is treated as an unstructured axis.
if not axes and hasattr(nxdata, 'signal'):
if len(nxdata[nxdata.signal].shape) < 2:
raise ValueError(
f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
unstructured_dim = nxdata[nxdata.signal].shape[0]
for k, v in nxdata.items():
if (isinstance(v, NXfield) and len(v.shape) == 1
and v.shape[0] == unstructured_dim):
unstructured_axes.append(k)
if unstructured_dim is None:
raise ValueError('Unable to determine the unstructered axes')
# From here on `axes` holds only the unstructured axes.
axes = unstructured_axes
# Identify unique coordinate points for each axis
unique_coords = {}
coords = {}
axes_attrs = {}
for a in axes:
coords[a] = nxdata[a].nxdata
unique_coords[a] = np.sort(np.unique(nxdata[a].nxdata))
# The axis attributes are copied without any 'target' entry.
axes_attrs[a] = deepcopy(nxdata[a].attrs)
if 'target' in axes_attrs[a]:
del axes_attrs[a]['target']
# Calculate the total number of unique coordinate points
unique_npts = np.prod([len(v) for k, v in unique_coords.items()])
# A mismatch is only reported as a warning; conversion continues.
if unique_npts != unstructured_dim:
self.logger.warning('The unstructered grid does not fully map to '
'a structered one (there are missing points)')
# Identify the signals and the data point axes
signals = []
data_point_axes = []
data_point_shape = []
# A declared signal must be at least 2-d with the unstructured
# dimension first; its remaining dimensions give the data point shape.
if hasattr(nxdata, 'signal'):
if (len(nxdata[nxdata.signal].shape) < 2
or nxdata[nxdata.signal].shape[0] != unstructured_dim):
raise ValueError(
f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
signals = [nxdata.signal]
data_point_shape = [nxdata[nxdata.signal].shape[1:]]
# Any other non-axis field whose first dimension has the unstructured
# length is treated as an additional signal.
# NOTE(review): `v.shape[0]` looks like it would raise IndexError for
# a 0-d NXfield -- confirm whether scalar fields can occur here.
for k, v in nxdata.items():
if (isinstance(v, NXfield) and k not in axes and k not in signals
and v.shape[0] == unstructured_dim):
signals.append(k)
if not data_point_shape:
data_point_shape.append(v.shape[1:])
# Reduce the one-element list to the shape itself (or to an empty
# list when no data point shape was found).
if len(data_point_shape) == 1:
data_point_shape = data_point_shape[0]
else:
data_point_shape = []
# NOTE(review): the outer loop runs once per data point dimension
# while the test compares the complete shape, so with more than one
# data point dimension the same field name appears to be appended
# repeatedly -- confirm the intent.
for _ in data_point_shape:
for k, v in nxdata.items():
if (isinstance(v, NXfield) and k not in axes
and v.shape == data_point_shape):
data_point_axes.append(k)
# Create the structured NXdata object
# Axis fields hold the unique coordinate values; signal fields are
# created from dtype and shape only and are filled in further below.
structured_shape = tuple(len(unique_coords[a]) for a in axes)
attrs = deepcopy(nxdata.attrs)
if 'unstructured_axes' in attrs:
attrs.pop('unstructured_axes')
attrs['axes'] = axes
nxdata_structured = NXdata(
name=f'{nxdata.nxname}_structured',
**{a: NXfield(
value=unique_coords[a],
attrs=axes_attrs[a])
for a in axes},
**{s: NXfield(
# value=np.reshape( # FIX not always a sound way to reshape.
# nxdata[s], (*structured_shape, *nxdata[s].shape[1:])),
dtype=nxdata[s].dtype,
shape=(*structured_shape, *nxdata[s].shape[1:]),
attrs=nxdata[s].attrs)
for s in signals},
attrs=attrs)
# A single data point axis is appended to the 'axes' attribute.
# NOTE(review): `axes` is rebound here to the value read back from
# the attributes and is used again by the population loop below.
# Whether the `for a in data_point_axes` loop belongs inside this
# `if` cannot be told from this listing -- confirm the nesting.
if len(data_point_axes) == 1:
axes = nxdata_structured.attrs['axes']
if isinstance(axes, str):
axes = [axes]
nxdata_structured.attrs['axes'] = axes + data_point_axes
for a in data_point_axes:
nxdata_structured[a] = NXfield(
value=nxdata[a], attrs=nxdata[a].attrs)
# Populate the structured NXdata object with values
# For unstructured point i, each coordinate is located among the
# unique values of its axis (first match) and the resulting index
# tuple selects where the signal values of point i are stored.
for i, coord in enumerate(zip(*tuple(nxdata[a].nxdata for a in axes))):
structured_index = tuple(
np.asarray(
coord[ii] == unique_coords[axes[ii]]).nonzero()[0][0]
for ii in range(len(axes)))
for s in signals:
nxdata_structured[s][structured_index] = nxdata[s][i]
return nxdata_structured
[docs]
class UpdateNXvalueProcessor(Processor):
    """Processor to fill in part(s) of an object representing a
    structured dataset that's already been written to a
    `NeXus <https://www.nexusformat.org>`__ file.

    This Processor is most useful as an "update" step for a NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object created by
    :class:`~CHAP.common.processor.SetupNXdataProcessor`, and is
    most easy to use in a `Pipeline` immediately after another
    `PipelineItem` designed specifically to return a value that can be
    used as input to this `Processor`.
    """

    def process(self, data, nxfilename, data_points=None):
        """Write new data values to an existing object representing a
        dataset in a
        `NeXus <https://www.nexusformat.org>`__ file. Return the list
        of data points used to update the dataset.

        A data point that cannot be written is reported through the
        logger and skipped; the remaining data points are still
        processed. The file is closed again in all cases.

        :param data: Data from the previous item in a `Pipeline`. May
            contain a list of data points that will extend the list of
            data points optionally provided with the `data_points`
            argument.
        :type data: list[PipelineData]
        :param nxfilename: Name of the NeXus file containing the object
            to update. A relative name is resolved against
            `self.inputdir`.
        :type nxfilename: str
        :param data_points: List of data points, each one a dictionary
            with the keys `'nxpath'` (path of the field to update),
            `'index'` (index of the data point in that field) and
            `'value'` (the data value). The list passed in is not
            modified.
        :type data_points: Optional[list[dict[str, object]]]
        :returns: Complete list of data points used to update the
            dataset.
        :rtype: list[dict[str, object]]
        """
        # Third party modules
        from nexusformat.nexus import NXFile

        # Work on a copy so the caller's list is never extended in place
        data_points = [] if data_points is None else list(data_points)
        self.logger.debug(f'Got {len(data_points)} data points from keyword')
        ddata_points = self.get_pipelinedata_item(data)
        if isinstance(ddata_points, list):
            self.logger.debug(f'Got {len(ddata_points)} from pipeline data')
            data_points.extend(ddata_points)
        self.logger.info(f'Updating a total of {len(data_points)} data points')

        if not os.path.isabs(nxfilename):
            nxfilename = os.path.join(self.inputdir, nxfilename)
        nxfile = NXFile(nxfilename, 'rw')
        try:
            for data_point in data_points:
                try:
                    nxfile.writevalue(
                        data_point['nxpath'], np.asarray(data_point['value']),
                        data_point['index'])
                except Exception as exc:
                    # The failure may be a missing key, so the message
                    # must not index the data point again
                    if isinstance(data_point, dict):
                        nxpath = data_point.get('nxpath')
                        index = data_point.get('index')
                    else:
                        nxpath = index = None
                    self.logger.error(
                        f'Error updating {nxpath} for data point '
                        f'{index}: {exc}')
                else:
                    self.logger.debug(f'Updated data point {data_point}')
        finally:
            nxfile.close()
        return data_points
[docs]
class UpdateNXdataProcessor(Processor):
"""Processor to fill in part(s) of a NeXus style
`NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
representing a structured dataset that's already been written to a
`NeXus <https://www.nexusformat.org>`__ file.
This Processor is most useful as an "update" step for a
`NXdata` object created by
:class:`~CHAP.common.processor.SetupNXdataProcessor`, and is most
easy to use in a `Pipeline` immediately after another
`PipelineItem` designed specifically to return a value that can be
used as input to this `Processor`.
"""
[docs]
def process(
        self, data, nxfilename, nxdata_path, data_points=None,
        allow_approximate_coordinates=False):
    """Write new data points to the signal fields of an existing
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object representing a structured dataset in a
    `NeXus <https://www.nexusformat.org>`__ file.

    Return the list of data points used to update the dataset.

    :param data: Data from the previous item in a `Pipeline`. May
        contain a list of data points that will extend the list of
        data points optionally provided with the `data_points`
        argument.
    :type data: list[PipelineData]
    :param nxfilename: Name of the NeXus file containing the
        NXdata object to update. Relative names are resolved
        against `self.inputdir`.
    :type nxfilename: str
    :param nxdata_path: Path to the NXdata object to update in
        the file.
    :type nxdata_path: str
    :param data_points: List of data points, each one a dictionary
        whose keys are the names of the axes and signals, and
        whose values are the values of each coordinate / signal at
        a single point in the dataset. The list passed in is not
        modified. Defaults to None.
    :type data_points: Optional[list[dict[str, object]]]
    :param allow_approximate_coordinates: Parameter to allow the
        nearest existing match for the new data points'
        coordinates to be used if an exact match cannot be found
        (sometimes this is due simply to differences in rounding
        conventions). Defaults to False.
    :type allow_approximate_coordinates: bool, optional
    :returns: Complete list of data points used to update the
        dataset.
    :rtype: list[dict[str, object]]
    """
    # Third party modules
    from nexusformat.nexus import NXFile

    # Copy the keyword value so that extending it below never mutates
    # a list owned by the caller.
    data_points = [] if data_points is None else list(data_points)
    self.logger.debug(f'Got {len(data_points)} data points from keyword')
    _data_points = self.get_pipelinedata_item(data)
    if isinstance(_data_points, list):
        self.logger.debug(f'Got {len(_data_points)} from pipeline data')
        data_points.extend(_data_points)
    self.logger.info(f'Updating {len(data_points)} data points total')

    if not os.path.isabs(nxfilename):
        nxfilename = os.path.join(self.inputdir, nxfilename)
    nxfile = NXFile(nxfilename, 'rw')
    data_points_used = []
    # Release the file handle even if reading the NXdata object or
    # handling one of the data points raises unexpectedly.
    try:
        nxdata = nxfile.readfile()[nxdata_path]
        nxaxes = nxdata.nxaxes
        axes_names = [a.nxname for a in nxaxes]
        for i, d in enumerate(data_points):
            # Verify that the data point contains a value for all
            # coordinates in the dataset.
            if not all(a in d for a in axes_names):
                self.logger.error(
                    f'Data point {i} is missing a value for at least one '
                    f'axis. Skipping. Axes are: {", ".join(axes_names)}')
                continue
            self.logger.debug(
                f'Coordinates for data point {i}: '
                + ', '.join(f'{a}={d[a]}' for a in axes_names))
            # Get the index of the data point in the dataset based on
            # its values for each coordinate. An exact match is tried
            # first; `[0][0]` raises if no coordinate value matches.
            try:
                index = tuple(
                    np.where(a.nxdata == d[a.nxname])[0][0]
                    for a in nxaxes)
            except Exception:
                if not allow_approximate_coordinates:
                    self.logger.error(
                        f'Cannot get the index of data point {i}. Skipping.')
                    continue
                # Fall back on the nearest coordinate along every axis.
                try:
                    index = tuple(
                        np.argmin(np.abs(a.nxdata - d[a.nxname]))
                        for a in nxaxes)
                    # The explicit `+` matters: without it the two
                    # adjacent literals merge and the message becomes
                    # the separator passed to `join`.
                    self.logger.warning(
                        f'Nearest match for coordinates of data point {i}: '
                        + ', '.join(
                            f'{a.nxname}={a[_i]}'
                            for _i, a in zip(index, nxaxes)))
                except Exception:
                    self.logger.error(
                        f'Cannot get the index of data point {i}. '
                        'Skipping.')
                    continue
            self.logger.debug(f'Index of data point {i}: {index}')
            # Update the signals contained in this data point at the
            # proper index in the dataset's signal NXfields.
            for k, v in d.items():
                if k in axes_names:
                    continue
                try:
                    nxfile.writevalue(
                        os.path.join(nxdata_path, k), np.asarray(v), index)
                except Exception as exc:
                    self.logger.error(
                        f'Error updating signal {k} for new data point '
                        f'{i} (dataset index {index}): {exc}')
            # A point counts as used once its index was resolved, even
            # if writing one of its signals failed (logged above).
            data_points_used.append(d)
    finally:
        nxfile.close()
    return data_points_used
[docs]
class NumpyStackProcessor(Processor):
    """Processor for joining a sequence of arrays along a new axis.

    Uses
    `numpy.stack <https://numpy.org/doc/stable/reference/generated/numpy.stack.html>`__.

    :ivar stack_order: List of names of input data arrays to determine
        order of stacking. If not specified, data arrays are stacked
        in the exact order input to the Processor. Defaults to `None`.
    :vartype stack_order: list[str], optional
    :ivar kwargs: Dictionary of keyword arguments to
        `numpy.stack <https://numpy.org/doc/stable/reference/generated/numpy.stack.html>`__;
        defaults to `{}`.
    :vartype kwargs: dict, optional
    """
    stack_order: Optional[conlist(item_type=str, min_length=1)] = None
    kwargs: Optional[dict] = {}

    def process(self, data):
        """Stack the input data arrays along a new axis.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Result of `numpy.stack` applied to the selected
            arrays.
        :rtype: numpy.ndarray
        """
        arrays = []
        if self.stack_order is None:
            # No explicit order: take every input that converts to an
            # array, in the order received.
            for d in data:
                try:
                    arrays.append(np.asarray(d['data']))
                except Exception:
                    self.logger.warning(
                        f'Omitting input data {d["name"]} '
                        + f'(type: {type(d["data"])}).'
                    )
        else:
            # The input data list must be handed to `get_data` so the
            # named items can be looked up in it.
            for name in self.stack_order:
                arrays.append(self.get_data(data, name=name))
        # `kwargs` is Optional, so guard against an explicit None.
        return np.stack(arrays, **(self.kwargs or {}))
[docs]
class NumpySumProcessor(Processor):
    """Processor for summing an array of elements over a given axis.

    Uses
    `numpy.sum <https://numpy.org/doc/stable/reference/generated/numpy.sum.html>`__.

    :ivar kwargs: Dictionary of keyword arguments to
        `numpy.sum <https://numpy.org/doc/stable/reference/generated/numpy.sum.html>`__;
        defaults to `{}`.
    :vartype kwargs: dict, optional
    """
    kwargs: Optional[dict] = {}

    def process(self, data):
        """Sum the last array-like item of the input data.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises TypeError: If none of the input items can be
            converted to an array.
        :return: Result of `numpy.sum` applied to the selected array.
        :rtype: numpy.ndarray or scalar
        """
        _data = None
        # Search from the end so the most recent array-like item wins.
        for d in reversed(data):
            try:
                _data = np.asarray(d['data'])
            except Exception:
                # Not array-like (or no 'data' entry): try the
                # previous item.
                continue
            break
        if _data is None:
            err = 'No array-like input data found.'
            self.logger.error(err)
            raise TypeError(err)
        self.logger.debug(f'_data.shape = {_data.shape}')
        # `kwargs` is Optional, so guard against an explicit None.
        return np.sum(_data, **(self.kwargs or {}))
[docs]
class NumpyToNXfieldProcessor(Processor):
    """Processor for converting a numpy array into an `NXfield`.

    :ivar value: Name of input data array to use as the field's
        values. If unspecified, use the last array-like data object in
        the input data list. Defaults to `None`.
    :vartype value: str, optional
    :ivar kwargs: Dictionary of keyword arguments to
        `nexusformat.nexus.tree.NXfield <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__;
        defaults to `{}`.
    :vartype kwargs: dict, optional
    """
    value: Optional[str] = None
    kwargs: Optional[dict] = {}

    def process(self, data):
        """Wrap the selected input array in an `NXfield`.

        :param data: Input data.
        :type data: list[PipelineData]
        :raises TypeError: If `value` is unspecified and none of the
            input items can be converted to an array.
        :return: Field holding the selected values.
        :rtype: nexusformat.nexus.NXfield
        """
        # Third party modules
        from nexusformat.nexus import NXfield

        if self.value is None:
            _data = None
            # Search from the end so the most recent array-like item
            # wins.
            for d in reversed(data):
                try:
                    _data = np.asarray(d['data'])
                    self.logger.debug(f'Using {d["name"]}')
                except Exception:
                    # Not usable as field values: try the previous
                    # item.
                    _data = None
                    continue
                break
            if _data is None:
                err = 'No array-like input data found.'
                self.logger.error(err)
                raise TypeError(err)
        else:
            _data = self.get_data(data, name=self.value)
        # `np.shape` also copes with array-likes that carry no `.shape`
        # attribute, so the debug message can never abort processing.
        self.logger.debug(f'_data.shape = {np.shape(_data)}')
        # `kwargs` is Optional, so guard against an explicit None.
        return NXfield(value=_data, **(self.kwargs or {}))
[docs]
class NXdataToDataPointsProcessor(Processor):
    """Flatten a NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object into a list of dictionaries, one per point of the
    dataset's coordinate grid.

    Every dictionary is keyed by the names of the dataset's axes and
    signals. Axis entries hold that axis' value at the point; signal
    entries hold the signal's value at the point, which may itself be
    an array when the signal has more dimensions than the dataset has
    axes.
    """

    def process(self, data):
        """Build one dictionary of coordinate and signal values for
        each point in the input dataset.

        :param data: Input data.
        :type data: list[PipelineData]
        :returns: List of all data points in the dataset.
        :rtype: list[dict[str,object]]
        """
        nxdata = self.get_pipelinedata_item(data)

        axes_names = [axis.nxname for axis in nxdata.nxaxes]
        self.logger.info(f'Dataset axes: {axes_names}')
        dataset_shape = tuple(axis.size for axis in nxdata.nxaxes)
        self.logger.info(f'Dataset shape: {dataset_shape}')

        # Sort the non-axis fields into signals (leading dimensions
        # match the axes grid) and everything else.
        naxes = len(dataset_shape)
        signal_names = []
        other_fields = []
        for name, field in nxdata.entries.items():
            if name in axes_names:
                continue
            if field.shape[:naxes] == dataset_shape:
                signal_names.append(name)
            else:
                other_fields.append(name)
        self.logger.info(f'Dataset signals: {signal_names}')
        if other_fields:
            self.logger.warning(
                'Ignoring the following fields that cannot be interpreted as '
                f'either dataset coordinates or signals: {other_fields}')

        # Walk the coordinate grid; axes come first in each dictionary,
        # followed by the signals.
        data_points = []
        for index in np.ndindex(dataset_shape):
            point = {
                name: nxdata[name][pos]
                for name, pos in zip(axes_names, index)}
            for name in signal_names:
                point[name] = nxdata[name].nxdata[index]
            data_points.append(point)
        return data_points
[docs]
class XarrayToNexusProcessor(Processor):
    """A Processor that repackages the contents of an `xarray`
    structure as a NeXus style
    `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object.
    """

    def process(self, data):
        """Convert the input `xarray` structure to a NeXus style
        `NXdata <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        object.

        The values, name and attributes of the input become the signal
        field; every coordinate becomes an axis field.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Data and metadata in `data`.
        :rtype: nexusformat.nexus.NXdata
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        xarr = self.get_pipelinedata_item(data)
        signal = NXfield(value=xarr.data, name=xarr.name, attrs=xarr.attrs)
        # One axis field per coordinate, in the order they are stored.
        axes = tuple(
            NXfield(value=coord.data, name=coord_name, attrs=coord.attrs)
            for coord_name, coord in xarr.coords.items())
        return NXdata(signal=signal, axes=axes)
[docs]
class XarrayToNumpyProcessor(Processor):
    """A Processor that extracts the values held by an
    `xarray.DataArray` structure as a `numpy.ndarray`.
    """

    def process(self, data):
        """Return only the signal values of the input, dropping its
        coordinates and attributes.

        :param data: Input data.
        :type data: list[PipelineData]
        :return: Data in `data`.
        :rtype: numpy.ndarray
        """
        dataarray = self.get_pipelinedata_item(data)
        return dataarray.data
[docs]
class ZarrToNexusProcessor(Processor):
    """Processor for converting
    `Zarr <https://zarr.readthedocs.io/en/stable/>`__
    data to a
    `NeXus <https://www.nexusformat.org>`__ format file.
    """

    def process(self, data, zarr_filename, nexus_filename):
        """Copy every group, array and attribute of a Zarr file into a
        newly created NeXus (HDF5) file.

        :param data: Input data (not used by this Processor).
        :type data: list[PipelineData]
        :param zarr_filename: Zarr input file name. Relative names
            are resolved against `self.inputdir`.
        :type zarr_filename: str
        :param nexus_filename: NeXus output file name. Relative names
            are resolved against `self.inputdir`. An existing file is
            overwritten.
        :type nexus_filename: str
        """
        # Third party modules
        import h5py
        # pylint: disable=import-error
        import zarr
        # pylint: enable=import-error

        if not os.path.isabs(zarr_filename):
            zarr_filename = os.path.join(self.inputdir, zarr_filename)
        if not os.path.isabs(nexus_filename):
            # NOTE(review): the output file is also placed relative to
            # the input directory -- confirm this is intended.
            nexus_filename = os.path.join(self.inputdir, nexus_filename)

        def copy_group(zarr_group, nexus_group):
            """Recursively copy the attributes, arrays and sub-groups
            of a
            `Zarr group <https://zarr.readthedocs.io/en/latest/api/zarr/group/#zarr.Group>`__
            into an HDF5 group of the output file.

            :param zarr_group: Zarr style group to copy from.
            :type zarr_group: zarr.Group
            :param nexus_group: Group of the output file to copy to.
            :type nexus_group: h5py.Group
            """
            self.logger.info(f'Copying {zarr_group.path}')
            # Copy attributes
            for attr_key, attr_value in zarr_group.attrs.items():
                nexus_group.attrs[attr_key] = attr_value
            # Copy datasets and sub-groups
            for key, item in zarr_group.members():
                if isinstance(item, zarr.Array):
                    self.logger.info(f'Copying {zarr_group.path}/{key}')
                    values = np.asarray(item)
                    # h5py rejects chunk/filter options on scalar
                    # datasets, so only compress arrays that have at
                    # least one dimension.
                    filters = {}
                    if values.ndim > 0:
                        filters = {
                            'compression': 'gzip',
                            'compression_opts': 4,  # GZIP level
                        }
                    # FIXME: the Zarr chunking (item.chunks) is not
                    # carried over to the output dataset.
                    nexus_dset = nexus_group.create_dataset(
                        name=key, data=values, **filters)
                    # Copy dataset attributes
                    for attr_key, attr_value in item.attrs.items():
                        nexus_dset.attrs[attr_key] = attr_value
                elif isinstance(item, zarr.Group):
                    # Recursively copy subgroup
                    copy_group(item, nexus_group.create_group(key))

        # Open the Zarr file first so that a missing or unreadable
        # input does not leave an empty output file behind.
        zarr_file = zarr.open(zarr_filename, mode='r')
        # Create the NeXus file and copy everything, starting from the
        # root group.
        with h5py.File(nexus_filename, 'w') as nexus_file:
            copy_group(zarr_file, nexus_file)
# Script entry point: hand control to the `main` function provided by
# `CHAP.processor` when this module is executed directly.
if __name__ == '__main__':
    # Local modules
    from CHAP.processor import main

    main()