22This module allows you to merge xarray Datasets/DataArrays
33geospatially with the `rasterio.merge` module.
44"""
5+
56from collections .abc import Sequence
67from typing import Callable , Optional , Union
78
89import numpy
910from rasterio .crs import CRS
1011from rasterio .io import MemoryFile
1112from rasterio .merge import merge as _rio_merge
12- from xarray import DataArray , Dataset
13+ from xarray import DataArray , Dataset , IndexVariable
1314
14- from rioxarray ._io import open_rasterio
15- from rioxarray .rioxarray import _get_nonspatial_coords
15+ from rioxarray .rioxarray import _get_nonspatial_coords , _make_coords
1616
1717
1818class RasterioDatasetDuck :
@@ -31,29 +31,13 @@ def __init__(self, xds: DataArray):
3131 self .count = int (xds .rio .count )
3232 self .dtypes = [xds .dtype ]
3333 self .name = xds .name
34- if xds .rio .encoded_nodata is not None :
35- self .nodatavals = [xds .rio .encoded_nodata ]
36- else :
37- self .nodatavals = [xds .rio .nodata ]
34+ self .nodatavals = [xds .rio .nodata ]
3835 res = xds .rio .resolution (recalc = True )
3936 self .res = (abs (res [0 ]), abs (res [1 ]))
4037 self .transform = xds .rio .transform (recalc = True )
41- self .profile : dict = {
42- "crs" : self .crs ,
43- "nodata" : self .nodatavals [0 ],
44- }
45- valid_scale_factor = self ._xds .encoding .get ("scale_factor" , 1 ) != 1 or any (
46- scale != 1 for scale in self ._xds .encoding .get ("scales" , (1 ,))
47- )
48- valid_offset = self ._xds .encoding .get ("add_offset" , 0.0 ) != 0 or any (
49- offset != 0 for offset in self ._xds .encoding .get ("offsets" , (0 ,))
50- )
51- self ._mask_and_scale = (
52- self ._xds .rio .encoded_nodata is not None
53- or valid_scale_factor
54- or valid_offset
55- or self ._xds .encoding .get ("_Unsigned" ) is not None
56- )
38+ # profile is only used for writing to a file.
39+ # This never happens with rioxarray merge.
40+ self .profile : dict = {}
5741
5842 def colormap (self , * args , ** kwargs ) -> None :
5943 """
@@ -68,16 +52,22 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
6852 This method is meant to be used by the rasterio.merge.merge function.
6953 """
7054 with MemoryFile () as memfile :
71- self ._xds .rio .to_raster (memfile .name )
72- with memfile .open () as dataset :
73- if self ._mask_and_scale :
74- kwargs ["masked" ] = True
75- out = dataset .read (* args , ** kwargs )
76- if self ._mask_and_scale :
77- out = out .astype (self ._xds .dtype )
78- for iii in range (self .count ):
79- out [iii ] = out [iii ] * dataset .scales [iii ] + dataset .offsets [iii ]
80- return out
55+ with memfile .open (
56+ driver = "GTiff" ,
57+ height = int (self ._xds .rio .height ),
58+ width = int (self ._xds .rio .width ),
59+ count = self .count ,
60+ dtype = self .dtypes [0 ],
61+ crs = self .crs ,
62+ transform = self .transform ,
63+ nodata = self .nodatavals [0 ],
64+ ) as dataset :
65+ data = self ._xds .values
66+ if data .ndim == 2 :
67+ dataset .write (data , 1 )
68+ else :
69+ dataset .write (data )
70+ return dataset .read (* args , ** kwargs )
8171
8272
8373def merge_arrays (
@@ -155,47 +145,66 @@ def merge_arrays(
155145 rioduckarrays .append (RasterioDatasetDuck (dataarray ))
156146
157147 # use rasterio to merge
148+ merged_data , merged_transform = _rio_merge (
149+ rioduckarrays ,
150+ ** {key : val for key , val in input_kwargs .items () if val is not None },
151+ )
158152 # generate merged data array
159153 representative_array = rioduckarrays [0 ]._xds
160- with MemoryFile () as memfile :
161- _rio_merge (
162- rioduckarrays ,
163- ** {key : val for key , val in input_kwargs .items () if val is not None },
164- dst_path = memfile .name ,
154+ if parse_coordinates :
155+ coords = _make_coords (
156+ src_data_array = representative_array ,
157+ dst_affine = merged_transform ,
158+ dst_width = merged_data .shape [- 1 ],
159+ dst_height = merged_data .shape [- 2 ],
165160 )
166- with open_rasterio ( # type: ignore
167- memfile .name ,
168- parse_coordinates = parse_coordinates ,
169- mask_and_scale = rioduckarrays [0 ]._mask_and_scale ,
170- ) as merged_data :
171- merged_data = merged_data .load ()
172-
173- # make sure old & new coorinate names match & dimensions are correct
174- rename_map = {}
175- original_extra_dim = representative_array .rio ._check_dimensions ()
176- new_extra_dim = merged_data .rio ._check_dimensions ()
177- # make sure the output merged data shape is 2D if the
178- # original data was 2D. this can happen if the
179- # xarray datasarray was squeezed.
180- if len (merged_data .shape ) == 3 and len (representative_array .shape ) == 2 :
181- merged_data = merged_data .squeeze (
182- dim = new_extra_dim , drop = original_extra_dim is None
161+ if (
162+ representative_array .rio .x_dim != "x"
163+ and "x" in coords
164+ and coords ["x" ].ndim == 1
165+ ):
166+ coords [representative_array .rio .x_dim ] = IndexVariable (
167+ representative_array .rio .x_dim , coords .pop ("x" )
183168 )
184- new_extra_dim = merged_data .rio ._check_dimensions ()
185169 if (
186- original_extra_dim is not None
187- and new_extra_dim is not None
188- and original_extra_dim != new_extra_dim
170+ representative_array . rio . y_dim != "y"
171+ and "y" in coords
172+ and coords [ "y" ]. ndim == 1
189173 ):
190- rename_map [new_extra_dim ] = original_extra_dim
191- if representative_array .rio .x_dim != merged_data .rio .x_dim :
192- rename_map [merged_data .rio .x_dim ] = representative_array .rio .x_dim
193- if representative_array .rio .y_dim != merged_data .rio .y_dim :
194- rename_map [merged_data .rio .y_dim ] = representative_array .rio .y_dim
195- if rename_map :
196- merged_data = merged_data .rename (rename_map )
197- merged_data .coords .update (_get_nonspatial_coords (representative_array ))
198- return merged_data # type: ignore
174+ coords [representative_array .rio .y_dim ] = IndexVariable (
175+ representative_array .rio .y_dim , coords .pop ("y" )
176+ )
177+ else :
178+ coords = _get_nonspatial_coords (representative_array )
179+
180+ # make sure the output merged data shape is 2D if the
181+ # original data was 2D. this can happen if the
182+ # xarray datasarray was squeezed.
183+ if len (merged_data .shape ) == 3 and len (representative_array .shape ) == 2 :
184+ merged_data = merged_data .squeeze ()
185+
186+ xda = DataArray (
187+ name = representative_array .name ,
188+ data = merged_data ,
189+ coords = coords ,
190+ dims = tuple (representative_array .dims ),
191+ attrs = representative_array .attrs ,
192+ )
193+ xda .encoding = representative_array .encoding .copy ()
194+ xda .rio .write_nodata (
195+ nodata if nodata is not None else representative_array .rio .nodata , inplace = True
196+ )
197+ xda .rio .write_crs (
198+ representative_array .rio .crs ,
199+ grid_mapping_name = representative_array .rio .grid_mapping ,
200+ inplace = True ,
201+ )
202+ xda .rio .write_transform (
203+ merged_transform ,
204+ grid_mapping_name = representative_array .rio .grid_mapping ,
205+ inplace = True ,
206+ )
207+ return xda
199208
200209
201210def merge_datasets (
@@ -218,7 +227,7 @@ def merge_datasets(
218227 Parameters
219228 ----------
220229 datasets: list[xarray.Dataset]
221- List of multiple xarray.Dataset with all geo attributes.
230+ List of xarray.Dataset's with all geo attributes.
222231 The first one is assumed to have the same
223232 CRS, dtype, dimensions, and data_vars as the others in the array.
224233 bounds: tuple, optional
0 commit comments