22This module allows you to merge xarray Datasets/DataArrays
33geospatially with the `rasterio.merge` module.
44"""
5+
56from collections .abc import Sequence
67from typing import Callable , Optional , Union
78
89import numpy
910from rasterio .crs import CRS
10- from rasterio .io import MemoryFile
1111from rasterio .merge import merge as _rio_merge
12- from xarray import DataArray , Dataset
12+ from xarray import DataArray , Dataset , IndexVariable
1313
14- from rioxarray ._io import open_rasterio
15- from rioxarray .rioxarray import _get_nonspatial_coords
14+ from rioxarray .rioxarray import _get_nonspatial_coords , _make_coords
1615
1716
1817class RasterioDatasetDuck :
@@ -31,29 +30,13 @@ def __init__(self, xds: DataArray):
3130 self .count = int (xds .rio .count )
3231 self .dtypes = [xds .dtype ]
3332 self .name = xds .name
34- if xds .rio .encoded_nodata is not None :
35- self .nodatavals = [xds .rio .encoded_nodata ]
36- else :
37- self .nodatavals = [xds .rio .nodata ]
33+ self .nodatavals = [xds .rio .nodata ]
3834 res = xds .rio .resolution (recalc = True )
3935 self .res = (abs (res [0 ]), abs (res [1 ]))
4036 self .transform = xds .rio .transform (recalc = True )
41- self .profile : dict = {
42- "crs" : self .crs ,
43- "nodata" : self .nodatavals [0 ],
44- }
45- valid_scale_factor = self ._xds .encoding .get ("scale_factor" , 1 ) != 1 or any (
46- scale != 1 for scale in self ._xds .encoding .get ("scales" , (1 ,))
47- )
48- valid_offset = self ._xds .encoding .get ("add_offset" , 0.0 ) != 0 or any (
49- offset != 0 for offset in self ._xds .encoding .get ("offsets" , (0 ,))
50- )
51- self ._mask_and_scale = (
52- self ._xds .rio .encoded_nodata is not None
53- or valid_scale_factor
54- or valid_offset
55- or self ._xds .encoding .get ("_Unsigned" ) is not None
56- )
37+ # profile is only used for writing to a file.
38+ # This never happens with rioxarray merge.
39+ self .profile : dict = {}
5740
5841 def colormap (self , * args , ** kwargs ) -> None :
5942 """
@@ -63,21 +46,44 @@ def colormap(self, *args, **kwargs) -> None:
6346 # pylint: disable=unused-argument
6447 return None
6548
66- def read (self , * args , ** kwargs ) -> numpy .ma .MaskedArray :
49+ def read (self , window , out_shape , * args , ** kwargs ) -> numpy .ma .MaskedArray :
50+ # pylint: disable=unused-argument
6751 """
6852 This method is meant to be used by the rasterio.merge.merge function.
6953 """
70- with MemoryFile () as memfile :
71- self ._xds .rio .to_raster (memfile .name )
72- with memfile .open () as dataset :
73- if self ._mask_and_scale :
74- kwargs ["masked" ] = True
75- out = dataset .read (* args , ** kwargs )
76- if self ._mask_and_scale :
77- out = out .astype (self ._xds .dtype )
78- for iii in range (self .count ):
79- out [iii ] = out [iii ] * dataset .scales [iii ] + dataset .offsets [iii ]
80- return out
54+ data_window = self ._xds .rio .isel_window (window )
55+ if data_window .shape != out_shape :
56+ # in this section, the data is geographically the same
57+ # however it is not the same dimensions as requested
58+ # so need to resample to the requested shape
59+ if len (out_shape ) == 3 :
60+ _ , out_height , out_width = out_shape
61+ else :
62+ out_height , out_width = out_shape
63+ data_window = self ._xds .rio .reproject (
64+ self ._xds .rio .crs ,
65+ transform = self .transform ,
66+ shape = (out_height , out_width ),
67+ )
68+
69+ nodata = self .nodatavals [0 ]
70+ mask = False
71+ fill_value = None
72+ if nodata is not None and numpy .isnan (nodata ):
73+ mask = numpy .isnan (data_window )
74+ elif nodata is not None :
75+ mask = data_window == nodata
76+ fill_value = nodata
77+
78+ # make sure the returned shape matches
79+ # the expected shape. This can be the case
80+ # when the xarray dataset was squeezed to 2D beforehand
81+ if len (out_shape ) == 3 and len (data_window .shape ) == 2 :
82+ data_window = data_window .values .reshape ((1 , out_height , out_width ))
83+
84+ return numpy .ma .array (
85+ data_window , mask = mask , fill_value = fill_value , dtype = self .dtypes [0 ]
86+ )
8187
8288
8389def merge_arrays (
@@ -155,47 +161,66 @@ def merge_arrays(
155161 rioduckarrays .append (RasterioDatasetDuck (dataarray ))
156162
157163 # use rasterio to merge
164+ merged_data , merged_transform = _rio_merge (
165+ rioduckarrays ,
166+ ** {key : val for key , val in input_kwargs .items () if val is not None },
167+ )
158168 # generate merged data array
159169 representative_array = rioduckarrays [0 ]._xds
160- with MemoryFile () as memfile :
161- _rio_merge (
162- rioduckarrays ,
163- ** {key : val for key , val in input_kwargs .items () if val is not None },
164- dst_path = memfile .name ,
170+ if parse_coordinates :
171+ coords = _make_coords (
172+ src_data_array = representative_array ,
173+ dst_affine = merged_transform ,
174+ dst_width = merged_data .shape [- 1 ],
175+ dst_height = merged_data .shape [- 2 ],
165176 )
166- with open_rasterio ( # type: ignore
167- memfile .name ,
168- parse_coordinates = parse_coordinates ,
169- mask_and_scale = rioduckarrays [0 ]._mask_and_scale ,
170- ) as merged_data :
171- merged_data = merged_data .load ()
172-
173- # make sure old & new coorinate names match & dimensions are correct
174- rename_map = {}
175- original_extra_dim = representative_array .rio ._check_dimensions ()
176- new_extra_dim = merged_data .rio ._check_dimensions ()
177- # make sure the output merged data shape is 2D if the
178- # original data was 2D. this can happen if the
179- # xarray datasarray was squeezed.
180- if len (merged_data .shape ) == 3 and len (representative_array .shape ) == 2 :
181- merged_data = merged_data .squeeze (
182- dim = new_extra_dim , drop = original_extra_dim is None
177+ if (
178+ representative_array .rio .x_dim != "x"
179+ and "x" in coords
180+ and coords ["x" ].ndim == 1
181+ ):
182+ coords [representative_array .rio .x_dim ] = IndexVariable (
183+ representative_array .rio .x_dim , coords .pop ("x" )
183184 )
184- new_extra_dim = merged_data .rio ._check_dimensions ()
185185 if (
186- original_extra_dim is not None
187- and new_extra_dim is not None
188- and original_extra_dim != new_extra_dim
186+ representative_array . rio . y_dim != "y"
187+ and "y" in coords
188+ and coords [ "y" ]. ndim == 1
189189 ):
190- rename_map [new_extra_dim ] = original_extra_dim
191- if representative_array .rio .x_dim != merged_data .rio .x_dim :
192- rename_map [merged_data .rio .x_dim ] = representative_array .rio .x_dim
193- if representative_array .rio .y_dim != merged_data .rio .y_dim :
194- rename_map [merged_data .rio .y_dim ] = representative_array .rio .y_dim
195- if rename_map :
196- merged_data = merged_data .rename (rename_map )
197- merged_data .coords .update (_get_nonspatial_coords (representative_array ))
198- return merged_data # type: ignore
190+ coords [representative_array .rio .y_dim ] = IndexVariable (
191+ representative_array .rio .y_dim , coords .pop ("y" )
192+ )
193+ else :
194+ coords = _get_nonspatial_coords (representative_array )
195+
196+ # make sure the output merged data shape is 2D if the
197+ # original data was 2D. this can happen if the
198+ # xarray datasarray was squeezed.
199+ if len (merged_data .shape ) == 3 and len (representative_array .shape ) == 2 :
200+ merged_data = merged_data .squeeze ()
201+
202+ xda = DataArray (
203+ name = representative_array .name ,
204+ data = merged_data ,
205+ coords = coords ,
206+ dims = tuple (representative_array .dims ),
207+ attrs = representative_array .attrs ,
208+ )
209+ xda .encoding = representative_array .encoding .copy ()
210+ xda .rio .write_nodata (
211+ nodata if nodata is not None else representative_array .rio .nodata , inplace = True
212+ )
213+ xda .rio .write_crs (
214+ representative_array .rio .crs ,
215+ grid_mapping_name = representative_array .rio .grid_mapping ,
216+ inplace = True ,
217+ )
218+ xda .rio .write_transform (
219+ merged_transform ,
220+ grid_mapping_name = representative_array .rio .grid_mapping ,
221+ inplace = True ,
222+ )
223+ return xda
199224
200225
201226def merge_datasets (
@@ -218,7 +243,7 @@ def merge_datasets(
218243 Parameters
219244 ----------
220245 datasets: list[xarray.Dataset]
221- List of multiple xarray.Dataset with all geo attributes.
246+ List of xarray.Dataset's with all geo attributes.
222247 The first one is assumed to have the same
223248 CRS, dtype, dimensions, and data_vars as the others in the array.
224249 bounds: tuple, optional
0 commit comments