# Source code for CHAP.common.writer

#!/usr/bin/env python
#-*- coding: utf-8 -*-
"""Module for generic Writers used in multiple experiment-specific
workflows.
"""

# System modules
import os
from typing import Optional

# Third party modules
import numpy as np
from pydantic import (
    conint,
    constr,
    model_validator,
)

# Local modules
from CHAP.pipeline import PipelineItem
from CHAP.writer import (
    Writer,
    validate_writer_model,
)


def validate_model(model):
    """Run the writer validation on `model` when it names a file.

    Models whose `filename` is `None` are passed through untouched;
    all others are handed to `validate_writer_model` first.

    :param model: The model configuration to validate.
    :type model: Any
    :return: The same `model` object that was passed in.
    :rtype: Any
    """
    # Nothing to validate without a file name
    if model.filename is None:
        return model
    validate_writer_model(model)
    return model
def write_matplotlibfigure(
        data, filename, savefig_kw=None, force_overwrite=False):
    """Write a `Matplotlib <https://matplotlib.org>`__ figure to file.

    :param data: The figure to write to file
    :type data: matplotlib.figure.Figure
    :param filename: File name.
    :type filename: str
    :param savefig_kw: Keyword args to pass to
        matplotlib.figure.Figure.savefig, defaults to `None`
        (no extra keyword args).
    :type savefig_kw: dict, optional
    :param force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises TypeError: If `data` is not a matplotlib Figure.
    :raises FileExistsError: If `filename` already exists and
        `force_overwrite` is `False`.
    """
    # Third party modules
    from matplotlib.figure import Figure

    if not isinstance(data, Figure):
        # Trailing space keeps the type from running into the text
        raise TypeError('Cannot write object of type '
                        f'{type(data)} as a matplotlib Figure.')

    if os.path.isfile(filename) and not force_overwrite:
        raise FileExistsError(f'{filename} already exists')

    # A missing keyword dict is the same as an empty one
    data.savefig(filename, **(savefig_kw or {}))
def write_nexus(data, filename, force_overwrite=False):
    """Write a NeXus style `NXobject
    <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
    object to file.

    :param data: The data to write to file
    :type data: nexusformat.nexus.NXobject
    :param filename: File name.
    :type filename: str
    :param force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises TypeError: If `data` is not an NXobject.
    """
    # Third party modules
    from nexusformat.nexus import NXobject

    if not isinstance(data, NXobject):
        # Trailing space keeps the type name from running into the text
        raise TypeError('Cannot write object of type '
                        f'{type(data).__name__} as a NeXus file.')

    # 'w-' is passed to NXobject.save when overwriting is not allowed,
    # 'w' when it is
    mode = 'w' if force_overwrite else 'w-'
    data.save(filename, mode=mode)
def write_tif(data, filename, force_overwrite=False):
    """Write a tif image to file.

    :param data: The data to write to file, must be convertible to a
        two-dimensional array.
    :type data: numpy.ndarray
    :param filename: File name.
    :type filename: str
    :param force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises TypeError: If `data` is not two-dimensional.
    :raises FileExistsError: If `filename` already exists and
        `force_overwrite` is `False`.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        # After np.asarray the type is always ndarray, so report the
        # property that actually failed the check
        raise TypeError(
            f'Cannot write {data.ndim}-dimensional data as a tif file '
            '(a 2D array is required).')

    if os.path.isfile(filename) and not force_overwrite:
        raise FileExistsError(f'{filename} already exists')

    # Third party modules
    # Imported only once the input is known to be writable
    from imageio import imwrite

    imwrite(filename, data)
def write_txt(data, filename, force_overwrite=False, append=False):
    """Write plain text to file.

    A single string is written as is; a series of strings is joined
    with newline characters before being written.

    :param data: The data to write to file
    :type data: str | list[str]
    :param filename: File name.
    :type filename: str
    :param force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :param append: Flag to allow data to be appended to the file if it
        already exists, defaults to `False`.
    :type append: bool, optional
    :raises TypeError: If `data` is neither a string nor a series of
        strings.
    :raises FileExistsError: If `filename` already exists and neither
        `force_overwrite` nor `append` is set.
    """
    # Local modules
    from CHAP.utils.general import is_str_series

    single_string = isinstance(data, str)
    if not (single_string or is_str_series(data, log=False)):
        raise TypeError('input data must be a str or a tuple or list of str '
                        f'instead of {type(data)} ({data})')

    if not (force_overwrite or append) and os.path.isfile(filename):
        raise FileExistsError(f'{filename} already exists')

    # Appending and (over)writing differ only in the open mode
    open_mode = 'a' if append else 'w'
    with open(filename, open_mode, encoding='utf-8') as stream:
        stream.write(data if single_string else '\n'.join(data))
def write_yaml(data, filename, force_overwrite=False):
    """Dump a dictionary or list to a YAML file.

    Keys are written in their existing order (no sorting).

    :param data: The data to write to file
    :type data: dict | list
    :param filename: File name.
    :type filename: str
    :param force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises TypeError: If `data` is neither a `dict` nor a `list`.
    :raises FileExistsError: If `filename` already exists and
        `force_overwrite` is `False`.
    """
    # Third party modules
    import yaml

    if not isinstance(data, (dict, list)):
        raise TypeError('input data must be a dict or list.')

    already_there = os.path.isfile(filename)
    if already_there and not force_overwrite:
        raise FileExistsError(f'{filename} already exists')

    with open(filename, 'w', encoding='utf-8') as stream:
        yaml.dump(data, stream, sort_keys=False)
def write_filetree(data, outputdir='.', force_overwrite=False):
    """Write data to a file tree.

    Every NXsubentry child of `data` that carries a `schema` attribute
    is written to the file named by its `filename` attribute inside
    `outputdir`, using the writer matching the schema (`txt`, `json`,
    `yml`/`yaml`, `tif`/`tiff` or `h5`). Every other NXgroup child is
    treated as a subdirectory and processed recursively.

    :param data: The data to write to files
    :type data: nexusformat.nexus.NXobject
    :param outputdir: Output directory, defaults to `'.'`.
    :type outputdir: str, optional
    :param force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises TypeError: If `data` is not an NXobject or if a schema is
        not supported.
    """
    # System modules
    from json import loads

    # Third party modules
    from nexusformat.nexus import (
        NXentry,
        NXgroup,
        NXobject,
        NXroot,
        NXsubentry,
    )

    if not isinstance(data, NXobject):
        # Trailing space keeps the type name from running into the text
        raise TypeError('Cannot write object of type '
                        f'{type(data).__name__} as a file tree to disk.')

    # exist_ok closes the window between an "is it there?" check and
    # the creation, which could fail when several processes (e.g. MPI
    # ranks) create the same directory at the same time
    os.makedirs(outputdir, exist_ok=True)

    for k, v in data.items():
        if isinstance(v, NXsubentry) and 'schema' in v.attrs:
            schema = v.attrs['schema']
            filename = os.path.join(outputdir, v.attrs['filename'])
            if schema == 'txt':
                write_txt(list(v.data), filename, force_overwrite)
            elif schema == 'json':
                write_txt(str(v.data), filename, force_overwrite)
            elif schema in ('yml', 'yaml'):
                # The YAML content is stored as a JSON string
                write_yaml(loads(v.data.nxdata), filename, force_overwrite)
            elif schema in ('tif', 'tiff'):
                write_tif(v.data, filename, force_overwrite)
            elif schema == 'h5':
                # Nested subentries become entries of a new root,
                # otherwise the content goes into a single entry
                if any(isinstance(vv, NXsubentry) for vv in v.values()):
                    nxbase = NXroot()
                else:
                    nxbase = NXentry()
                # Copy all attributes except the bookkeeping ones
                for kk, vv in v.attrs.items():
                    if kk not in ('schema', 'filename'):
                        nxbase.attrs[kk] = vv
                for kk, vv in v.items():
                    if isinstance(vv, NXsubentry):
                        nxentry = NXentry()
                        nxbase[vv.nxname] = nxentry
                        for kkk, vvv in vv.items():
                            nxentry[kkk] = vvv
                    else:
                        nxbase[kk] = vv
                write_nexus(nxbase, filename, force_overwrite)
            else:
                raise TypeError(f'Files of type {schema} not yet implemented')
        elif isinstance(v, NXgroup):
            write_filetree(v, os.path.join(outputdir, k), force_overwrite)
class ExtractArchiveWriter(Writer):
    """Writer that extracts a tar archive, supplied as bytes, to
    disk."""

    def write(self, data):
        """Take a tar archive represented as bytes contained in `data`
        and write the extracted archive to files.

        The archive is extracted into the location given by
        `self.filename`.

        :param data: The data to write to archive.
        :type data: list[PipelineData]
        """
        # System modules
        from io import BytesIO
        import tarfile

        data = self.get_pipelinedata_item(data, remove=self.remove)

        with tarfile.open(fileobj=BytesIO(data)) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters are available: refuse members
                # that would be written outside of the destination
                tar.extractall(path=self.filename, filter='data')
            else:
                # Older interpreters without extraction filters
                tar.extractall(path=self.filename)
class FileTreeWriter(PipelineItem):
    """Writer for a file tree in `NeXus <https://www.nexusformat.org>`__
    format.

    :ivar force_overwrite: Flag to allow data to be overwritten if it
        already exists, defaults to `False`. Note that checking for
        the existence of files prior to executing the pipeline is not
        possible since the filename(s) of the data are unknown during
        pipeline validation.
    :vartype force_overwrite: bool, optional
    :ivar remove: Flag to remove the dictionary from `data`,
        defaults to `True`.
    :vartype remove: bool, optional
    """
    force_overwrite: Optional[bool] = False
    remove: Optional[bool] = True

    def write(self, data):
        """Write a NeXus format object contained in `data` to a
        directory tree structured like the NeXus tree.

        The default NXentry of the selected pipeline item is handed
        to `write_filetree` together with the output directory and
        the overwrite flag.

        :param data: The data to write to disk.
        :type data: list[PipelineData]
        """
        selected = self.get_data(data, remove=self.remove)
        nxentry = self.get_default_nxentry(selected)
        write_filetree(
            nxentry,
            outputdir=self.outputdir,
            force_overwrite=self.force_overwrite)
class H5Writer(Writer):
    """Writer for `HDF5 <https://www.hdfgroup.org/solutions/hdf5/>`__
    files from a NeXus style `NXdata
    <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
    object."""

    def write(self, data):
        """Write the NeXus style `NXdata
        <https://manual.nexusformat.org/classes/base_classes/NXdata.html#index-0>`__
        object contained in `data` to hdf5 file.

        The signal is written as a dataset and every axis is written
        as a dimension scale attached to the matching dimension of the
        signal dataset.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :raises ValueError: If the selected pipeline item is not an
            NXdata object.
        """
        # Third party modules
        from h5py import File
        from nexusformat.nexus import NXdata

        data = self.get_pipelinedata_item(data, remove=self.remove)
        if not isinstance(data, NXdata):
            # f-string so that the offending object is actually shown
            raise ValueError(f'Invalid data parameter {data}')

        axes_names = data.attrs['axes']
        if isinstance(axes_names, str):
            # A lone axis name must not be iterated character by
            # character
            axes_names = [axes_names]

        # 'w-' is the h5py mode used when overwriting is not allowed
        mode = 'w' if self.force_overwrite else 'w-'
        with File(self.filename, mode) as f:
            f[data.signal] = data.nxsignal
            for i, axis in enumerate(axes_names):
                f[axis] = data[axis]
                # Label the dimension, adding the units when present
                f[data.signal].dims[i].label = \
                    f'{axis} ({data[axis].units})' \
                    if 'units' in data[axis].attrs else axis
                f[axis].make_scale(axis)
                f[data.signal].dims[i].attach_scale(f[axis])
class ImageWriter(PipelineItem):
    """Writer for saving image files.

    :ivar filename: Name of file to write to.
    :vartype filename: str, optional
    :ivar force_overwrite: Flag to allow data in `filename` to be
        overwritten if it already exists, defaults to `False`.
    :vartype force_overwrite: bool, optional
    :ivar remove: Flag to remove the dictionary from `data`,
        defaults to `True`.
    :vartype remove: bool, optional
    """
    filename: Optional[str] = None
    force_overwrite: Optional[bool] = False
    remove: Optional[bool] = True

    _validate_filename = model_validator(mode='after')(validate_model)

    def write(self, data):
        """Write the image(s) contained in `data` to file.

        The selected pipeline item may be a list of
        `((buffer, fileformat), basename)` items (one file is written
        per item), a dictionary with the keys `'image_data'` and
        `'fileformat'`, a `(image_data, fileformat)` tuple, or a bare
        image object. In the last case the file format is taken from
        the extension of `filename`.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :raises FileExistsError: If the file already exists and
            `force_overwrite` is `False`.
        :raises ValueError: If the image type or dimension is not
            supported or if the file format cannot be determined.
        """
        # System modules
        from io import BytesIO

        # Third party modules
        from matplotlib.animation import (
            ArtistAnimation,
            FuncAnimation,
        )

        # Local modules
        from CHAP.utils.general import save_iobuf_fig

        try:
            data = self.get_data(
                data, schema='common.write.ImageWriter', remove=self.remove)
        except ValueError:
            self.logger.warning(
                'Unable to find match with schema `common.write.ImageWriter`: '
                'return without writing')
            return

        if isinstance(data, list):
            # One output file per item, named after the item itself
            for (buf, fileformat), basename in data:
                self.filename = f'{basename}.{fileformat}'
                if not os.path.isabs(self.filename):
                    self.filename = os.path.join(self.outputdir, self.filename)
                if isinstance(buf, (ArtistAnimation, FuncAnimation)):
                    buf.save(self.filename)
                else:
                    save_iobuf_fig(
                        buf, self.filename,
                        force_overwrite=self.force_overwrite)
            return

        # A bare image object carries no file format of its own
        fileformat = None
        if isinstance(data, dict):
            image_data = data['image_data']
            fileformat = data['fileformat']
        elif isinstance(data, tuple) and len(data) == 2:
            image_data, fileformat = data
        else:
            image_data = data

        if self.filename is None:
            self.filename = 'image'
        ext = os.path.splitext(self.filename)[1]
        if fileformat is None:
            # Without an explicit format, the extension must supply it
            if not ext:
                raise ValueError(
                    'Unable to determine the image file format: no file '
                    f'format in the data and no extension in {self.filename}')
        elif ext[1:] != fileformat:
            self.filename = f'{self.filename}.{fileformat}'
        if not os.path.isabs(self.filename):
            self.filename = os.path.join(self.outputdir, self.filename)
        if os.path.isfile(self.filename) and not self.force_overwrite:
            raise FileExistsError(f'{self.filename} already exists')

        if isinstance(image_data, BytesIO):
            save_iobuf_fig(
                image_data, self.filename,
                force_overwrite=self.force_overwrite)
        elif isinstance(image_data, np.ndarray):
            if image_data.ndim == 2:
                # Third party modules
                from imageio import imwrite

                imwrite(self.filename, image_data)
            elif image_data.ndim == 3:
                # Third party modules
                from tifffile import imwrite

                kwargs = {'bigtiff': True}
                imwrite(self.filename, image_data, **kwargs)
            else:
                # Fail loudly instead of silently writing nothing
                raise ValueError(
                    f'Invalid image dimension {image_data.ndim} '
                    '(only 2D and 3D arrays are supported)')
        elif isinstance(image_data, (ArtistAnimation, FuncAnimation)):
            image_data.save(self.filename)
        else:
            raise ValueError(f'Invalid image input type {type(image_data)}')
class MatplotlibAnimationWriter(Writer):
    """Writer for saving `Matplotlib <https://matplotlib.org>`__
    animations.

    :ivar fps: Movie frame rate (frames per second), defaults to `1`.
    :vartype fps: int, optional
    """
    fps: Optional[conint(gt=0)] = 1

    def write(self, data):
        """Write the matplotlib.animation.ArtistAnimation object
        contained in `data` to file.

        A `filename` without extension gets `.gif` appended. The
        supported extensions are `.gif` and `.mp4` (the latter is
        written with the `ffmpeg` writer).

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :raises ValueError: If the extension of `filename` is not
            supported.
        """
        data = self.get_pipelinedata_item(data, remove=self.remove)
        # Compare case-insensitively so that e.g. '.GIF' is accepted
        extension = os.path.splitext(self.filename)[1].lower()
        if not extension:
            data.save(f'{self.filename}.gif', fps=self.fps)
        elif extension == '.gif':
            data.save(self.filename, fps=self.fps)
        elif extension == '.mp4':
            data.save(self.filename, writer='ffmpeg', fps=self.fps)
        else:
            # Fail loudly instead of silently writing nothing
            raise ValueError(
                f'Unsupported animation file extension "{extension}" '
                f'in {self.filename} (use .gif or .mp4)')
class MatplotlibFigureWriter(Writer):
    """Writer for saving `Matplotlib <https://matplotlib.org>`__
    figures to image files.

    :ivar savefig_kw: Keyword args to pass to
        matplotlib.figure.Figure.savefig.
    :vartype savefig_kw: dict, optional
    """
    savefig_kw: Optional[dict] = None

    def write(self, data):
        """Write the matplotlib.figure.Figure contained in `data` to
        file.

        The selected pipeline item is passed to
        `write_matplotlibfigure` together with the configured file
        name, savefig keyword args and overwrite flag.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        """
        figure = self.get_pipelinedata_item(data, remove=self.remove)
        write_matplotlibfigure(
            figure,
            self.filename,
            savefig_kw=self.savefig_kw,
            force_overwrite=self.force_overwrite)
class NexusWriter(Writer):
    """Writer for `NeXus <https://www.nexusformat.org>`__ files from
    NeXus style `NXobject
    <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
    objects.

    :ivar nxpath: Path to a specific location in the NeXus file tree
        to write to (ignored if `filename` does not yet exist).
    :vartype nxpath: str, optional
    """
    nxpath: Optional[constr(strip_whitespace=True, min_length=1)] = None

    def write(self, data):
        """Write the NeXus style `NXobject
        <https://manual.nexusformat.org/classes/base_classes/NXobject.html#index-0>`__
        contained in `data` to file.

        Without `nxpath` the object is wrapped as needed into an
        NXroot and written as a new file. With `nxpath` the object is
        inserted into the existing file at that location. Note that
        `self.nxpath` is updated in place when the requested path has
        to be adjusted.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :return: The result of `self._update_provenance(data)`.
        """
        # System modules
        # NeXus-internal paths always use '/' as separator, whatever
        # the separator of the local file system is
        import posixpath

        # Third party modules
        from nexusformat.nexus import (
            NXFile,
            NXentry,
            NXroot,
        )

        nxobject = self.get_data(data, remove=self.remove)

        nxname = nxobject.nxname
        if not os.path.isfile(self.filename) and self.nxpath is not None:
            self.logger.warning(
                f'{self.filename} does not yet exist, ignoring nxpath '
                f'argument ({self.nxpath})')
            self.nxpath = None
        if self.nxpath is None:
            # Wrap the object so that a complete NXroot is written
            nxclass = nxobject.nxclass
            if nxclass == 'NXroot':
                nxroot = nxobject
            elif nxclass == 'NXentry':
                nxroot = NXroot(nxobject)
                nxroot[nxname].set_default()
            else:
                nxroot = NXroot(NXentry(nxobject))
                if nxclass == 'NXdata':
                    nxroot.entry[nxname].set_default()
                nxroot.entry.set_default()
            write_nexus(nxroot, self.filename, self.force_overwrite)
        else:
            with NXFile(self.filename, 'rw') as nxfile:
                self.logger.debug(f'nxfile.mode = {nxfile.mode}')
                root = nxfile.readfile()
                if nxfile.get(self.nxpath) is None:
                    if nxfile.get(posixpath.dirname(self.nxpath)) is not None:
                        # Only the parent exists: the last path
                        # component becomes the name of the new object
                        self.nxpath, nxname = posixpath.split(self.nxpath)
                    else:
                        self.logger.warning(
                            f'Path "{self.nxpath}" not present in '
                            f'{self.filename}. '
                            f'Using {root.NXentry[0].nxpath} instead.')
                        self.nxpath = root.NXentry[0].nxpath
                full_nxpath = posixpath.join(self.nxpath, nxname)
                self.logger.debug(
                    f'Full path for object to write: {full_nxpath}')
                if nxfile.get(full_nxpath) is not None:
                    self.logger.debug(
                        f'{full_nxpath} already exists in {self.filename}')
                    if self.force_overwrite:
                        self.logger.warning(
                            'Deleting existing NXobject at '
                            f'{full_nxpath} in {self.filename}')
                        del root[full_nxpath]
                # Any error raised by the assignment propagates to
                # the caller unchanged
                root[full_nxpath] = nxobject

        # Return provenance with the output file name added
        return self._update_provenance(data)
class NexusValuesWriter(Writer):
    """Writer for updating values in an existing `NeXus
    <https://www.nexusformat.org>`__ file."""

    def write(self, data, filename, path_prefix=''):
        """Write new values specified in `data` to the existing `NeXus
        <https://www.nexusformat.org>`__ file `filename`.

        The file is opened in append mode once for every item to
        write.

        :param data: List of dictionaries with the following entries
            -- `'path'` identifying the location of the NeXus style
            `NXfield
            <https://nexpy.github.io/nexpy/treeapi.html#nexusformat.nexus.tree.NXfield>`__
            object to which values will be written, `'data'`
            identifying the data to be written, and `'idx'`
            identifying the index / indices of the NXfield to which
            the data will be written.
        :type data: list[PipelineData]
        :param filename: Name of an existing NeXus file to update.
        :type filename: str
        :param path_prefix: Prefix to use for all paths in input
            `data`, defaults to `''`.
        :type path_prefix: str, optional
        """
        # Third party modules
        from nexusformat.nexus import NXFile

        updates = self.get_pipelinedata_item(data, remove=self.remove)

        for update in updates:
            with NXFile(filename, 'a') as nxroot:
                target_path = os.path.join(path_prefix, update['path'])
                self.nxs_writer(
                    nxroot=nxroot,
                    path=target_path,
                    idx=update['idx'],
                    data=update['data'],
                )

    def nxs_writer(self, nxroot, path, idx, data):
        """Write `data` into the slice `idx` of the dataset at `path`.

        The dataset has to exist already and the shape of `data`
        (after conversion to a numpy array) has to equal the shape of
        the target slice.

        :param nxroot: Opened NeXus file object that is indexed by
            `path`.
        :type nxroot: nexusformat.nexus.NXroot
        :param path: Path to the dataset inside the NeXus file.
        :type path: str
        :param idx: Index or slice where the data should be written.
        :type idx: tuple or int
        :param data: Data to be written to the specified slice in the
            dataset.
        :type data: numpy.ndarray or compatible array-like object
        :raises ValueError: If the specified dataset does not exist or
            if the shape of `data` does not match the target slice.
        """
        self.logger.info(f'Writing to {path} at {idx}')

        # The dataset must be present before anything is written
        if path not in nxroot:
            raise ValueError(
                f'Dataset "{path}" does not exist in the NeXus file.')
        target = nxroot[path]

        # Only write values that fit the target slice exactly
        values = np.asarray(data)
        if target[idx].shape != values.shape:
            raise ValueError(
                f'Data shape {values.shape} does not match the target slice '
                f'shape {target[idx].shape}.')

        target[idx] = values
        self.logger.info(f'Data written to "{path}" at slice {idx}.')
class PyfaiResultsWriter(Writer):
    """Writer for results of one or more `pyFAI
    <https://pyfai.readthedocs.io/en/stable>`__ integrations. The
    output format is selected from the file extension: `.npz` is
    supported, `.nxs` is recognized but not yet implemented.
    """

    def write(self, data):
        """Save `pyFAI <https://pyfai.readthedocs.io/en/stable>`__
        integration results to a file. Format is determined
        automatically from the extension of `filename`.

        :param data: The data to write to file.
        :type data: list[PipelineData] |
            list[pyFAI.containers.IntegrateResult]
        :raises ValueError: If there are no results or if the results
            are not all of the same (1d or 2d) type.
        :raises RuntimeError: If `filename` already exists and
            `force_overwrite` is `False`, or if the file format is not
            supported.
        """
        # Third party modules
        from pyFAI.containers import Integrate1dResult, Integrate2dResult

        try:
            results = self.get_pipelinedata_item(data, remove=self.remove)
        except ValueError:
            # Fall back to treating the input as the results themselves
            results = data
        if not isinstance(results, list):
            results = [results]
        if not results:
            # An empty list would pass the type checks below and then
            # fail with an IndexError while writing
            raise ValueError(
                'Bad input data: no pyFAI integration results to write.')
        if not (all(isinstance(r, Integrate1dResult) for r in results)
                or all(isinstance(r, Integrate2dResult) for r in results)):
            raise ValueError(
                'Bad input data: all items must have the same type -- either '
                'all pyFAI.containers.Integrate1dResult, or all '
                'pyFAI.containers.Integrate2dResult.')

        if os.path.isfile(self.filename):
            if self.force_overwrite:
                self.logger.warning(f'Removing existing file {self.filename}')
                os.remove(self.filename)
            else:
                raise RuntimeError(f'{self.filename} already exists.')
        _, ext = os.path.splitext(self.filename)
        if ext.lower() == '.npz':
            self.write_npz(results, self.filename)
        elif ext.lower() == '.nxs':
            self.write_nxs(results, self.filename)
        else:
            raise RuntimeError(f'Unsupported file format: {ext}')
        self.logger.info(f'Wrote to {self.filename}')

    def write_npz(self, results, filename):
        """Save `results` to the .npz file, `filename`.

        The radial (and, for 2d results, azimuthal) coordinates are
        taken from the first result; the intensities of all results
        are stored together. Errors are stored only when every result
        provides them.

        :param results: The integration results to save.
        :type results: list[pyFAI.containers.IntegrateResult]
        :param filename: Name of the .npz file to write.
        :type filename: str
        """
        data = {'radial': results[0].radial,
                'intensity': [r.intensity for r in results]}
        if hasattr(results[0], 'azimuthal'):
            # 2d results
            data['azimuthal'] = results[0].azimuthal
        # Compare against None: the truth value of an array with more
        # than one element is ambiguous and raises a ValueError
        if all(r.sigma is not None for r in results):
            # errors were included
            data['sigma'] = [r.sigma for r in results]
        np.savez(filename, **data)

    def write_nxs(self, results, filename):
        """Save `results` to the .nxs file, `filename`.

        :raises NotImplementedError: Always, this format is not yet
            implemented.
        """
        raise NotImplementedError
class TXTWriter(Writer):
    """Writer for plain text files from string or tuples or lists of
    strings.

    :ivar append: Flag to allow data in `filename` to be appended,
        defaults to `False`.
    :vartype append: bool, optional
    """
    append: Optional[bool] = False

    def write(self, data):
        """Write a string or tuple or list of strings contained in
        `data` to file.

        The selected pipeline item is passed to `write_txt` together
        with the configured file name, overwrite flag and append flag.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :raises TypeError: If the object contained in `data` is not a
            `str`, `tuple[str]` or `list[str]`.
        """
        text = self.get_pipelinedata_item(data, remove=self.remove)
        write_txt(
            text,
            self.filename,
            force_overwrite=self.force_overwrite,
            append=self.append)
class YAMLWriter(Writer):
    """Writer for YAML files from `dict`-s."""

    def write(self, data):
        """Write the last matching dictionary contained in `data` to
        file (the schema must match if a schema is provided).

        Items are searched from last to first. When a schema is set
        but no item with that schema yields a dictionary, the search
        is repeated without looking at the schema. Pydantic models are
        converted to dictionaries with `model_dump`.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :raises TypeError: If the object contained in `data` is not a
            `dict`.
        :return: The result of `self._update_provenance(data)`.
        """
        # Third party modules
        from pydantic import BaseModel

        # Local modules
        from CHAP.models import CHAPBaseModel

        def as_dict(candidate):
            """Return `candidate` as a dictionary, or `None`."""
            if isinstance(candidate, dict):
                return candidate
            if not isinstance(candidate, (BaseModel, CHAPBaseModel)):
                return None
            try:
                return candidate.model_dump()
            except Exception:
                return None

        def take_last_dict(required_schema=None):
            """Return the last item that yields a dictionary,
            removing it from `data` if requested."""
            for index in range(len(data) - 1, -1, -1):
                item = data[index]
                if (required_schema is not None
                        and not required_schema == item['schema']):
                    continue
                found = as_dict(item['data'])
                if found is None:
                    continue
                if self.remove:
                    data.pop(index)
                return found
            return None

        schema = self.get_schema()
        yaml_dict = None if schema is None else take_last_dict(schema)
        if yaml_dict is None:
            if schema is not None:
                self.logger.warning(
                    f'Unable to find match with schema {schema}: '
                    'try finding a dictionary without matching schema')
            yaml_dict = take_last_dict()
        write_yaml(yaml_dict, self.filename, self.force_overwrite)
        self.status = 'written'  # Right now does nothing yet, but could
                                 # add a sort of modification flag later

        # Return provenance with the output file name added
        return self._update_provenance(data)
class ZarrValuesWriter(Writer):
    """Writer for updating values in arrays of an existing `Zarr
    <https://zarr.readthedocs.io/en/stable/>`__ file.

    :ivar path_prefix: Prefix to prepend to all "path" fields in
        `data` before writing. Defaults to `""`.
    :vartype path_prefix: str, optional
    """
    path_prefix: Optional[str] = ''

    def write(self, data):
        """Write values to specific paths and slices in an existing
        zarr file.

        :param data: Data whose last item contains a list of
            dictionaries that each have three keys: `"data"`,
            `"path"`, `"idx"`.
        :type data: list[PipelineData]
        """
        # System modules
        # Paths inside a zarr hierarchy always use '/' as separator
        import posixpath

        # Third party modules
        # pylint: disable=import-error
        import zarr

        # Open file in append mode to allow modifications
        zarrfile = zarr.open(self.filename, mode='a')

        # Write every item of the selected pipeline data
        for d in self.get_pipelinedata_item(data, remove=self.remove):
            self.zarr_writer(
                zarrfile=zarrfile,
                path=posixpath.join(self.path_prefix, d['path']),
                idx=d['idx'],
                data=d['data'])

    def zarr_writer(self, zarrfile, path, idx, data):
        """Write data to a specific dataset.

        This method writes `data` to a specified dataset within a
        `Zarr <https://zarr.readthedocs.io/en/stable/>`__ file at the
        given index (`idx`). If the dataset does not exist, an error
        is raised. A leading axis of length one is dropped from `data`
        when the shapes do not match otherwise; after that the shape
        of `data` has to equal the shape of the target slice.

        :param zarrfile: The opened Zarr group to write to.
        :type zarrfile: zarr.core.group.Group
        :param path: Path to the dataset inside the Zarr file.
        :type path: str
        :param idx: Index or slice where the data should be written.
        :type idx: tuple or int
        :param data: Data to be written to the specified slice in the
            dataset.
        :type data: numpy.ndarray or compatible array-like object
        :raises ValueError: If the specified dataset does not exist or
            if the shape of `data` does not match the target slice.
        """
        self.logger.info(f'Writing to {path} at {idx}')

        # Check if the dataset exists
        if path not in zarrfile:
            raise ValueError(
                f'Dataset "{path}" does not exist in the Zarr file.')

        # Access the specified dataset
        dataset = zarrfile[path]

        # Accept any array-like input, e.g. plain lists
        data = np.asarray(data)

        # Read the target slice only once to get its shape
        target_shape = dataset[idx].shape

        # Drop a leading singleton axis if that makes the shapes agree
        # (the ndim test protects against 0-d input)
        if (target_shape != data.shape and data.ndim > 0
                and data.shape[0] == 1):
            data = np.squeeze(data, axis=0)
        if target_shape != data.shape:
            raise ValueError(
                f'Data shape {data.shape} does not match the target slice '
                f'shape {target_shape}.')

        # Write the data to the specified slice
        dataset[idx] = data

        self.logger.info(f'Data written to "{path}" at slice {idx}.')
class ZarrWriter(Writer):
    """Writer for `Zarr <https://zarr.readthedocs.io/en/stable/>`__
    groups."""

    def write(self, data):
        """Copy the zarr store contained in `data` to a local store.

        The selected pipeline item is either a zarr store or a zarr
        group (whose store is used). Every key listed by that store is
        copied to a `LocalStore` created from `filename`.

        :param data: The data to write to file.
        :type data: list[PipelineData]
        :raises TypeError: If the selected pipeline item is neither a
            zarr store nor a zarr group.
        """
        # System modules
        import asyncio

        # Third party modules
        # pylint: disable=import-error
        from zarr.core.buffer import default_buffer_prototype
        from zarr.storage import LocalStore
        from zarr.abc.store import Store
        from zarr.core.group import AsyncGroup, Group

        async def _mirror(source_store, target_store):
            """Copy all keys of `source_store` to `target_store`."""
            async for key in source_store.list():
                self.logger.info(f'Copying {key}')
                chunk = await source_store.get(
                    key, prototype=default_buffer_prototype())
                await target_store.set(key, chunk)

        zarr_obj = self.get_pipelinedata_item(data, remove=self.remove)

        # Find the store that holds the data to copy
        if isinstance(zarr_obj, Store):
            source = zarr_obj
        elif isinstance(zarr_obj, (AsyncGroup, Group)):
            source = zarr_obj.store
        else:
            raise TypeError(
                'Expected zarr.abc.store.Store, zarr.core.group.AsyncGroup, '
                f'or zarr.core.group.Group, got {type(zarr_obj)}'
            )

        target = LocalStore(self.filename)
        asyncio.run(_mirror(source, target))
# When this module is executed as a script, hand over to the `main`
# function of `CHAP.writer`
if __name__ == '__main__':
    # Local modules
    from CHAP.writer import main

    main()