"""
General-purpose helper functions for the riogrande package.
This module collects utility functions that are used across the package but do
not belong to the I/O layer or the parallelization machinery. It covers:
- **Compatibility checks**: CRS, spatial resolution, and unit validation across
multiple raster sources (:func:`check_compatibility`, :func:`check_crs`,
:func:`check_resolution`, :func:`check_units`).
- **Dtype conversion**: Converting array values between numeric types with
optional range rescaling (:func:`convert_to_dtype`, :func:`dtype_range`).
- **Tag utilities**: Serializing, deserializing, sanitizing, and matching
metadata tag dictionaries (:func:`serialize`, :func:`deserialize`,
:func:`sanitize`, :func:`match_all`, :func:`match_any`).
- **Mask aggregation**: Combining boolean selector arrays with logical AND/OR
(:func:`aggregated_selector`, :func:`reduced_mask`).
- **Multiprocessing setup**: Obtaining a multiprocessing context and determining
the number of worker processes (:func:`get_or_set_context`,
:func:`get_nbr_workers`).
- **Miscellaneous**: Output filename generation, window-to-view conversion, and
pixel contribution counting.
"""
from __future__ import annotations
import os
import json
import warnings
import numpy as np
from numpy.typing import NDArray
import rasterio as rio
from rasterio.windows import Window
from decimal import Decimal
from typing import Any, Union, Tuple, Optional
from collections.abc import Collection
import multiprocessing as mpc
from multiprocessing import context as _context_module
[docs]
# Start methods accepted by ``get_or_set_context``. The first entry ('spawn')
# is the one that function falls back to when no method is requested.
MPC_STARTER_METHODS = ['spawn', 'fork', 'forkserver']
[docs]
def get_nbr_workers(number: Optional[int] = None) -> int:
    """Determine the number of worker processes to use in multiprocessing.

    Parameters
    ----------
    number : int or None, optional
        Desired number of workers. If ``None``, the number of CPUs reported
        by :func:`multiprocessing.cpu_count` is used, but never less than 2.

    Returns
    -------
    int
        Number of workers to use (always ``>= 2``).

    Notes
    -----
    A ``RuntimeWarning`` is emitted when a requested ``number`` is lower than
    2; the request is then ignored and 2 workers are used. Requesting exactly
    2 workers is valid and does not trigger the warning.

    See Also
    --------
    :func:`~riogrande.helper.get_or_set_context` : Return a multiprocessing context.
    """
    _min_count = 2  # Hardcoded: some parallelization routines fail when < 2
    if number is None:
        return max(_min_count, mpc.cpu_count())
    # Strictly less than: the minimum itself is a perfectly valid request.
    if number < _min_count:
        warnings.warn(
            message=f"For this routine to work properly at least {_min_count} "
                    f"workers are required - the requested {number} are not "
                    "enough and thus the request will be ignored.",
            category=RuntimeWarning
        )
        return _min_count
    return int(number)
[docs]
def get_or_set_context(method: Optional[str] = None) -> _context_module.BaseContext:
    """
    Return a multiprocessing context and set the global start method if unset.

    The function is conservative about changing global interpreter state:

    - If `method` is ``None`` and a global start method is configured, a
      context for that global method is returned.
    - If `method` is ``None`` and no global start method is configured, a
      ``RuntimeWarning`` is emitted and a context for the default method
      (the first entry of ``MPC_STARTER_METHODS``, i.e. ``'spawn'``) is
      returned; the global start method is left unset.
    - If `method` is provided and no global start method is set, the global
      start method is set to `method`. If that call raises ``RuntimeError``
      (e.g. because something else set it in the meantime), a
      ``RuntimeWarning`` is emitted and a context for `method` is returned
      without touching the global state.
    - If `method` is provided and a *different* global start method is
      already set, the global method is not changed; a ``RuntimeWarning`` is
      emitted and a context for the requested `method` is returned.

    Parameters
    ----------
    method : {None, 'fork', 'spawn', 'forkserver'}, optional
        Desired multiprocessing start method for the returned context.
        Availability of the individual methods depends on the platform.

    Returns
    -------
    multiprocessing.context.BaseContext
        Context obtained from :func:`multiprocessing.get_context` for the
        start method selected by the rules above.

    Raises
    ------
    ValueError
        If ``method`` is neither ``None`` nor one of the entries of
        ``MPC_STARTER_METHODS``.

    Notes
    -----
    - A ``RuntimeError`` raised by :func:`multiprocessing.set_start_method`
      is caught and turned into a ``RuntimeWarning``; it is never propagated.
    - An already-set global start method is never overwritten by this
      function.

    See Also
    --------
    :func:`~riogrande.helper.get_nbr_workers` : Determine the number of worker processes.

    Examples
    --------
    >>> ctx = get_or_set_context('spawn')
    >>> p = ctx.Process(target=worker)
    >>> p.start()
    >>> p.join()
    """
    if method is not None and method not in MPC_STARTER_METHODS:
        raise ValueError(f"Unsupported start method: {method!r}")
    default_method = MPC_STARTER_METHODS[0]  # default is 'spawn'

    # ``None`` here means nobody has fixed the global start method yet.
    global_method = mpc.get_start_method(allow_none=True)

    if global_method is None:
        if method is None:
            # Nothing requested and nothing configured: stay local, do not
            # touch the global state.
            warnings.warn(
                "No multiprocessing start method set and no global either "
                f"— defaulting to local context only with '{default_method}'.",
                RuntimeWarning
            )
            return mpc.get_context(default_method)
        try:
            mpc.set_start_method(method)
        except RuntimeError:
            # Someone else set the start method concurrently; the requested
            # context is still usable on its own.
            warnings.warn(
                "Race when setting start method; returning requested context.",
                RuntimeWarning)
        return mpc.get_context(method)

    # A global start method is already configured from here on.
    if method is None or method == global_method:
        return mpc.get_context(global_method)

    warnings.warn(
        f"Global multiprocessing start method is '{global_method}'"
        f" but requested context is '{method}'"
        f" — using local context only with '{method}',"
        " keeping the global unchanged.",
        RuntimeWarning
    )
    return mpc.get_context(method)
[docs]
def serialize(tags: dict[str, Any]) -> dict[str, str]:
    """Encode every value of a tag dictionary as a JSON string.

    The keys are kept as they are; each value is passed through
    :func:`json.dumps`.

    Parameters
    ----------
    tags : dict[str, Any]
        Mapping of tag names to JSON-serializable values.

    Returns
    -------
    dict
        Mapping of the same tag names to their JSON-encoded values.

    See Also
    --------
    :func:`~riogrande.helper.deserialize` : Inverse operation; parse JSON back to Python objects.
    :func:`~riogrande.helper.sanitize` : Serialize then deserialize in one step.
    """
    encoded = {}
    for name, raw in tags.items():
        encoded[name] = json.dumps(raw)
    return encoded
[docs]
def deserialize(tags: dict[str, str]) -> dict[str, Any]:
    """Decode the JSON-encoded values of a tag dictionary.

    The keys are kept as they are; each value is parsed with
    :func:`json.loads`.

    Parameters
    ----------
    tags : dict[str, str]
        Mapping of tag names to JSON strings.

    Returns
    -------
    dict
        Mapping of the same tag names to the decoded Python objects.

    Notes
    -----
    Inverse operation of :func:`~riogrande.helper.serialize`.

    See Also
    --------
    :func:`~riogrande.helper.serialize` : Convert dict values to JSON strings.
    :func:`~riogrande.helper.sanitize` : Serialize then deserialize in one step.
    """
    decoded = {}
    for name, text in tags.items():
        decoded[name] = json.loads(text)
    return decoded
[docs]
def sanitize(tags: dict[str, Any]) -> Any:
    """Round-trip the values of a tag dictionary through JSON.

    The dictionary is first passed to :func:`~riogrande.helper.serialize`
    and the result is handed to :func:`~riogrande.helper.deserialize`, so
    the returned values look the way they would after being written to and
    read back from a ``.tif`` tag.

    Parameters
    ----------
    tags : dict[str, Any]
        Mapping of tag names to JSON-serializable values.

    Returns
    -------
    dict
        Mapping of the same tag names to the round-tripped values.

    See Also
    --------
    :func:`~riogrande.helper.serialize` : Convert dict values to JSON strings.
    :func:`~riogrande.helper.deserialize` : Parse JSON strings back to Python objects.
    """
    as_json = serialize(tags)
    return deserialize(as_json)
[docs]
def match_all(targets: dict, tags: dict) -> bool:
    """Tell whether every item of `targets` is found in `tags`.

    An item counts as found when its key exists in `tags` and the value
    stored there compares equal to the target value.

    Parameters
    ----------
    targets : dict
        Key/value pairs that are required.
    tags : dict
        Dictionary that is searched for the required pairs.

    Returns
    -------
    bool
        ``True`` if all pairs of `targets` are found in `tags` (also when
        `targets` is empty), otherwise ``False``.

    See Also
    --------
    :func:`~riogrande.helper.match_any` : Return True if *any* tag matches.
    """
    return all(
        key in tags and tags[key] == wanted
        for key, wanted in targets.items()
    )
[docs]
def match_any(targets: dict, tags: dict) -> bool:
    """Tell whether at least one item of `targets` is found in `tags`.

    An item counts as found when its key exists in `tags` and the value
    stored there compares equal to the target value.

    Parameters
    ----------
    targets : dict
        Key/value pairs of which one has to be present.
    tags : dict
        Dictionary that is searched for the pairs.

    Returns
    -------
    bool
        ``True`` if any pair of `targets` is found in `tags`, otherwise
        ``False`` (also when `targets` is empty).

    See Also
    --------
    :func:`~riogrande.helper.match_all` : Return True only if *all* tags match.
    """
    return any(
        key in tags and tags[key] == wanted
        for key, wanted in targets.items()
    )
[docs]
def view_to_window(view: None | tuple[int, int, int, int]) -> Window | None:
    """Convert a view into a rasterio Window.

    Parameters
    ----------
    view : tuple[int, int, int, int] or None
        Tuple ``(x, y, width, height)`` defining the view of the data array.
        The four entries are passed positionally, in this order, to
        :class:`rasterio.windows.Window`.

    Returns
    -------
    :class:`rasterio.windows.Window` or None
        Rasterio window object, or ``None`` if `view` is ``None``.
    """
    if view is None:
        return None
    return Window(view[0], view[1], view[2], view[3])
[docs]
def check_units(*sources: str) -> list:
    """Assert that all sources have the same linear units in the coordinate reference system (crs)

    Parameters
    ----------
    *sources : str
        Sources (paths to files) whose linear units are compared to each
        other. At least one source is required.

    Returns
    -------
    list
        The linear units (lower-cased) of each source, in the order the
        sources were given. A source without a crs contributes ``None``.

    Raises
    ------
    ValueError
        If no source is provided.
    TypeError
        If a source has linear units that differ from those of the first
        source; the message names the first such source.

    See Also
    --------
    :func:`~riogrande.helper.check_crs` : Check that sources share the same CRS.
    :func:`~riogrande.helper.check_resolution` : Check that sources share the same resolution.
    :func:`~riogrande.helper.check_compatibility` : Run all three checks at once.
    """
    if not sources:
        raise ValueError("At least one source is required to check the linear units.")
    units = []
    for source in sources:
        with rio.open(source) as src:
            crs = src.profile['crs']
            units.append(crs.linear_units.lower() if crs is not None else None)
    # Compare everything against the first source and report the source that
    # actually deviates, not simply the last one that was read.
    reference = units[0]
    for source, unit in zip(sources, units):
        if unit != reference:
            raise TypeError(f"{source=} has linear units {unit}, "
                            "which is different from the other(s) "
                            f"({reference})")
    return units
[docs]
def check_crs(*sources: str) -> list:
    """Assert that all the sources have the same coordinate reference system (crs)

    Parameters
    ----------
    *sources : str
        Sources (paths to files) whose crs are compared to each other.
        At least one source is required.

    Returns
    -------
    list
        The crs of each source converted with :class:`str`, in the order the
        sources were given (a missing crs yields the string ``'None'``).

    Raises
    ------
    ValueError
        If no source is provided.
    TypeError
        If a source has a crs that differs from that of the first source;
        the message names the first such source.

    See Also
    --------
    :func:`~riogrande.helper.check_units` : Check that sources share the same linear units.
    :func:`~riogrande.helper.check_resolution` : Check that sources share the same resolution.
    :func:`~riogrande.helper.check_compatibility` : Run all three checks at once.
    """
    if not sources:
        raise ValueError("At least one source is required to check the crs.")
    crss = []
    for source in sources:
        with rio.open(source) as src:
            crss.append(str(src.profile.get('crs', None)))
    # Compare everything against the first source and report the source that
    # actually deviates, not simply the last one that was read.
    reference = crss[0]
    for source, crs in zip(sources, crss):
        if crs != reference:
            raise TypeError(f"{source=} has crs {crs}, which is "
                            f"different from the other(s) ({reference})")
    return crss
[docs]
def check_resolution(*sources: str) -> list:
    """Assert that all the sources have the same spatial resolution

    Parameters
    ----------
    *sources : str
        Sources (paths to files) whose resolutions are compared to each
        other. At least one source is required.

    Returns
    -------
    list
        The resolution of each source as a tuple rounded to 8 decimal
        places, in the order the sources were given.

    Raises
    ------
    ValueError
        If no source is provided.
    TypeError
        If a source has a (rounded) resolution that differs from that of the
        first source; the message names the first such source.

    See Also
    --------
    :func:`~riogrande.helper.check_units` : Check that sources share the same linear units.
    :func:`~riogrande.helper.check_crs` : Check that sources share the same CRS.
    :func:`~riogrande.helper.check_compatibility` : Run all three checks at once.
    """
    if not sources:
        raise ValueError("At least one source is required to check the resolution.")
    ress = []
    for source in sources:
        with rio.open(source) as src:
            # NOTE: values are rounded to the 8th decimal place so that tiny
            # floating point differences do not count as a mismatch
            ress.append(tuple(round(value, 8) for value in src.res))
    # Compare everything against the first source and report the source that
    # actually deviates, not simply the last one that was read.
    reference = ress[0]
    for source, res in zip(sources, ress):
        if res != reference:
            raise TypeError(f"{source=} has resolution {res}, which "
                            f"is different from the other(s) ({reference})")
    return ress
[docs]
def check_compatibility(*sources: str) -> Tuple[list, list, list]:
    """Assert that all the sources are compatible with each other.

    Three checks are run one after the other, in this order:

    - units (via :func:`~riogrande.helper.check_units`)
    - crs (via :func:`~riogrande.helper.check_crs`)
    - resolution (via :func:`~riogrande.helper.check_resolution`)

    Any exception raised by one of these checks propagates to the caller.

    Parameters
    ----------
    *sources : str
        Sources (paths to files) that are compared to each other.

    Returns
    -------
    crss : list
        Result of :func:`~riogrande.helper.check_crs`.
    units : list
        Result of :func:`~riogrande.helper.check_units`.
    ress : list
        Result of :func:`~riogrande.helper.check_resolution`.

    See Also
    --------
    :func:`~riogrande.helper.check_crs` : Check that sources share the same CRS.
    :func:`~riogrande.helper.check_units` : Check that sources share the same linear units.
    :func:`~riogrande.helper.check_resolution` : Check that sources share the same resolution.
    """
    unit_list = check_units(*sources)
    crs_list = check_crs(*sources)
    res_list = check_resolution(*sources)
    # Note the return order differs from the order the checks are run in.
    return (crs_list, unit_list, res_list)
[docs]
def output_filename(base_name: str, out_type: str, blur_params: None | dict = None) -> str:
"""Construct the filename for the specific output type.
Parameters
----------
base_name : str
The basic output name in the form <name>.tif
out_type : str
The type of output that will be saved.
This should be either 'blur' or 'entropy' but any string is accepted
blur_params : dict or None
Output of `get_blur_params`, so 'sigma', 'truncate' and 'diameter'
are expected keys.
Returns
-------
str
The resulting filename of the form
'<name>_<out_type>_sig_<{sigma}>_diam_<{diameter}>_trunc_<{truncate}>.tif'
"""
_base_name, _ext = os.path.splitext(base_name)
_blur_string = ""
if blur_params is not None:
for name, value in blur_params.items():
_blur_string += f"_{name}_{round(value)}"
return f"{_base_name}_{out_type}{_blur_string}{_ext}"
[docs]
def dtype_range(dtype: type | str) -> Tuple[int | float, int | float]:
"""Get the range of the specified dtype
Uses :func:`numpy.iinfo` for integer types and :func:`numpy.finfo`
for floating-point types.
.. warning::
This functions returns min or max as either `int` or `floats`.
Be sure to convert them back into `dtype` if needed!
Parameters
----------
dtype : type or str
A NumPy dtype (e.g. ``np.uint8``, ``np.float32``) or a string
representation thereof (e.g. ``'uint8'``).
Returns
-------
tuple
``(max, min)`` of the dtype's representable range as Python
``int`` or ``float``.
Raises
------
ValueError
If `dtype` has no defined min/max values.
See Also
--------
:func:`~riogrande.helper.convert_to_dtype` : Convert and optionally rescale an array.
"""
if isinstance(dtype, str):
dtype = np.dtype(dtype)
# avoid issues of object not callable from rasterio
elif hasattr(dtype, 'type'):
dtype = dtype.type
try:
_max = int(np.iinfo(dtype).max)
_min = int(np.iinfo(dtype).min)
except ValueError:
try:
_max = float(np.finfo(dtype).max)
_min = float(np.finfo(dtype).min)
except ValueError:
raise ValueError(f"{dtype=} has no min-/maximal values.")
return _max, _min
[docs]
def _range_limits(spec, fallback_dtype):
    """Resolve a range specification into a ``(max, min)`` pair.

    ``None`` means "full range of `fallback_dtype`", a dtype-like
    specification (string, type or :class:`numpy.dtype`) means "full range
    of that dtype", anything else is treated as a collection of values whose
    NaN-ignoring max/min are used.
    """
    if spec is None:
        return dtype_range(fallback_dtype)
    if isinstance(spec, (str, type, np.dtype)):
        return dtype_range(spec)
    return float(np.nanmax(spec)), float(np.nanmin(spec))


def convert_to_dtype(data: NDArray, as_dtype: None | type | np.dtype | str = None,
                     in_range: None | NDArray | Collection = None,
                     out_range: None | NDArray | Collection | str | type = None) -> NDArray:
    """Converts data to `as_dtype` and optionally rescales it.

    Rescaling is done only if at least one of the ranges is explicitly set.

    If only `in_range` is set then the input range is scaled to the full range
    of the output data type, `as_dtype`.
    This behaviour is typically wanted when converting some floating typed data
    in a limited range, e.g. [0, 1] to unsigned integer, e.g. `uint8`, thus
    mapping the range [0,1] to [0, 255].

    In case only `out_range` is set, the full data type range of the input
    data is mapped to the provided `out_range`.
    This is typically used if converting from a "limited" range, like `uint8`
    to a floating data type.

    .. note::
        If rescaling is requested, the output data type is a floating type
        and no `out_range` is given, the data is mapped to the *full* range
        of that floating type and a warning is emitted, since this is rarely
        intended. Specify `out_range` (e.g. ``[0, 1]``) to avoid it.

    Parameters
    ----------
    data : NDArray
        Input numpy NDArray
    as_dtype : type or str or numpy.dtype or None
        Desired data type to convert to (e.g. np.float64).
        If not provided the data type of `data` is kept; in combination with
        a range this rescales the data without changing its type.
    in_range : NDArray or Collection or str or type or None
        An array or list from which min and max will be used as input range.
        Min and max are read with :func:`numpy.nanmin` / :func:`numpy.nanmax`.
        Alternatively a data type, in which case its full range is used.
        If ``None`` while `out_range` is set, the full range of the data type
        of `data` is used.

        .. note::
            You might simply provide the same value as for `data` in order to
            use its min and max for scaling
    out_range : NDArray or Collection or str or type or None
        An array or list from which min and max will be used as limits for the
        output.
        Alternatively, a data type can be specified, in which case the data
        will be scaled to the full range of the specified data type
        (see :func:`~riogrande.helper.dtype_range`).

    Returns
    -------
    NDArray
        Converted numpy NDArray with desired data type.

    Warns
    -----
    UserWarning
        If the data is rescaled to the full range of a floating type, or if
        the rescaled data falls outside the determined output range (above
        its maximum or below its minimum).

    See Also
    --------
    :func:`~riogrande.helper.dtype_range` : Get the min/max of a NumPy dtype.

    Examples
    --------
    >>> # simple conversion, no rescaling
    >>> my_data = np.array([0, 0.5, 1.], dtype=np.float64)
    >>> convert_to_dtype(my_data, as_dtype='uint8')
    array([0, 0, 1], dtype=uint8)
    >>> # conversion with rescaling specifying in_range only
    >>> new_data = convert_to_dtype(my_data, as_dtype='uint8', in_range=(0,1))
    >>> new_data
    array([  0, 127, 255], dtype=uint8)
    >>> # convert with scaling specifying out_range only
    >>> convert_to_dtype(data=new_data, as_dtype='float64', out_range=[-1, 1])
    array([-1.        , -0.00392157,  1.        ])
    >>> # only scaling, keeping data type
    >>> convert_to_dtype(data=my_data, in_range=[0,1], out_range=[-1, 1])
    array([-1.,  0.,  1.])
    >>> # scaling with data type as range
    >>> convert_to_dtype(data=my_data, in_range=[0,1], as_dtype='uint16', out_range='uint8')
    array([  0, 127, 255], dtype=uint16)
    """
    in_dtype = data.dtype
    if as_dtype is None:
        # keep the data type; ``astype(None)`` would silently yield float64
        as_dtype = in_dtype
    elif isinstance(as_dtype, str):
        as_dtype = np.dtype(as_dtype)

    if in_range is None and out_range is None:
        # plain type conversion - no range lookup needed, so dtypes without a
        # numeric range (e.g. bool) can be converted as well
        return data.astype(as_dtype)

    # a missing range falls back to the full range of the respective dtype
    _inmax, _inmin = _range_limits(in_range, in_dtype)
    _outmax, _outmin = _range_limits(out_range, as_dtype)

    if out_range is None and np.issubdtype(as_dtype, np.floating):
        # we are about to map something to the full float range, this is rather
        # unlikely done on purpose
        warnings.warn(
            f"You are about to rescale data of type '{in_dtype}' in range "
            f"[{_inmin}, {_inmax}] to the full range of '{as_dtype}'. "
            "Consider specifying `out_range` to avoid this.",
            stacklevel=2
        )
    # Decimal keeps the range differences exact even for the extreme limits
    # of wide integer/float types
    scale = (Decimal(_outmax) - Decimal(_outmin)) / \
        (Decimal(_inmax) - Decimal(_inmin))
    out_data = np.array(_outmin).astype(as_dtype) \
        + ((data - _inmin) * float(scale)).astype(as_dtype)
    outmax = float(np.nanmax(out_data))
    outmin = float(np.nanmin(out_data))
    # compare against the *target* limits on both ends
    if outmax > _outmax or outmin < _outmin:
        warnings.warn(
            f"The rescaled data (range [{outmin}, {outmax}]), exceeds the "
            f"determined output range [{_outmin}, {_outmax}]. "
            "If this is unwanted make sure that the input data does not "
            "exceed the `in_range`.",
            stacklevel=2
        )
    return out_data
[docs]
def aggregated_selector(masks: list[NDArray], logic: str = 'all') -> NDArray:
    """Combine several rasterio masks into one boolean selector array.

    Every mask is first turned into a boolean array by comparing it to 0
    (non-zero cells are ``True``); the boolean arrays are then combined
    cell by cell.

    Parameters
    ----------
    masks : list[NDArray]
        One or more numpy arrays, e.g. as returned by
        :meth:`rasterio.io.DatasetReader.dataset_mask` or
        :meth:`rasterio.io.DatasetReader.read_masks`.
    logic : str
        How the masks are combined.
        With ``'any'`` a cell is selected if **at least one** mask is
        non-zero there (:func:`numpy.logical_or`).
        Any other value, including the default ``'all'``, selects a cell
        only if **all** masks are non-zero there (:func:`numpy.logical_and`).

    Returns
    -------
    NDArray
        Boolean numpy array with the combined selection.

    See Also
    --------
    :func:`~riogrande.helper.reduced_mask` : Compute a mask from nodata values across bands.
    """
    combine = np.logical_or if logic == 'any' else np.logical_and
    # the comparison creates a fresh boolean array that is updated in place
    selector = masks[0] != 0
    for further in masks[1:]:
        combine(selector, further != 0, out=selector)
    return selector
[docs]
def reduced_mask(array: NDArray, nodata: float | int | np.nan = 0, logic: str = 'all') -> NDArray:
"""Computes a mask based on the value of several bands
Parameters
----------
array : NDArray
3D array holding multiple bands of map data
nodata : float or int or None
Nodata value to use. Defaults to 0.
Pass :data:`numpy.nan` to mask NaN cells (detected via :func:`numpy.isnan`).
logic : str
Allowed strings are:
- ``"all"`` : Masked will be each cell for which **all** bands match the nodata value
(aggregated via :func:`numpy.logical_or` across bands).
- ``"any"`` : Masked will be each cell for which **any** band matches the nodata value
(aggregated via :func:`numpy.logical_and` across bands).
Returns
----------
NDArray
Boolean numpy array resulting from applied logic.
See Also
--------
:func:`~riogrande.helper.aggregated_selector` : Aggregate rasterio band masks into a selector.
Examples
--------
>>> mydata = np.array([[[2, 4], [0, 1]], [[5, 5], [1, 0]]])
>>> # only mask if all are nodata
>>> reduced_mask(mydata)
array([[1, 1],
[1, 1]], dtype=uint8)
>>> # mask if any are nodata
>>> reduced_mask(mydata, logic='any')
array([[1, 1],
[0, 0]], dtype=uint8)
"""
if logic == 'any':
_logic = np.logical_and
else:
_logic = np.logical_or
if np.isnan(nodata):
return _logic.reduce(array=~np.isnan(array), axis=0).astype(np.uint8)
else:
return _logic.reduce(array=array != nodata, axis=0).astype(np.uint8)
[docs]
def count_contribution(data: NDArray, selector: NDArray[np.bool_], no_data: Union[int, float] = 0) -> int:
    """Count the selected cells of `data` that hold valid data.

    The cells picked by `selector` are compared against `no_data` and the
    number of cells that differ from it is counted with
    :func:`numpy.count_nonzero`.

    Parameters
    ----------
    data : NDArray
        The data to count the contribution in
    selector : NDArray
        A boolean array in the shape of `data` selecting the single cells that
        should be considered
    no_data : int or float
        The value that should be considered as invalid.

        .. note::
            You might also provide :data:`numpy.nan` as no data value
            (detected via :func:`numpy.isnan`).

    Returns
    -------
    int
        Count of valid cells (pixels in rasterfile); ``0`` if there is none.

    See Also
    --------
    :func:`~riogrande.helper.aggregated_selector` : Build a selector from rasterio band masks.
    :func:`~riogrande.helper.reduced_mask` : Build a mask from nodata values across bands.
    """
    nan_is_nodata = np.isnan(no_data)
    selected = data[selector]
    if nan_is_nodata:
        is_valid = ~np.isnan(selected)
    else:
        is_valid = selected != no_data
    return int(np.count_nonzero(is_valid))
[docs]
def rasterio_to_numpy_dtype(rasterio_dtype: str) -> np.dtype | None:
    """Translate a rasterio data type into the matching NumPy type.

    Only the types listed in the lookup table below are known:
    ``int16``, ``int32``, ``uint8``, ``uint16``, ``uint32``, ``float32`` and
    ``float64`` from ``rasterio.dtypes``.

    Parameters
    ----------
    rasterio_dtype : str
        Data type as found in ``rasterio.open(source).profile['dtype']``.

    Returns
    -------
    type or None
        The corresponding NumPy scalar type (e.g. ``numpy.int16``), or
        ``None`` if `rasterio_dtype` is not in the lookup table.
    """
    known_types = {
        rio.dtypes.int16: np.int16,
        rio.dtypes.int32: np.int32,
        rio.dtypes.uint8: np.uint8,
        rio.dtypes.uint16: np.uint16,
        rio.dtypes.uint32: np.uint32,
        rio.dtypes.float32: np.float32,
        rio.dtypes.float64: np.float64,
    }
    if rasterio_dtype in known_types:
        return known_types[rasterio_dtype]
    return None