Skip to content

Commit 0ca6b18

Browse files
committed
BUG:merge: Revert rasterio.io.MemoryFile code added in corteva#765 & corteva#781
1 parent 3c8512c commit 0ca6b18

5 files changed

Lines changed: 104 additions & 85 deletions

File tree

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
gdal-version: ['3.10.0']
4343
include:
4444
- python-version: '3.10'
45-
rasterio-version: '==1.3.7'
45+
rasterio-version: ''
4646
xarray-version: '==2024.7.0'
4747
numpy-version: '<2'
4848
run-with-scipy: 'YES'

docs/history.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
History
22
=======
33

4-
Latest
4+
0.19.0
55
------
6+
- BUG:merge: Revert `rasterio.io.MemoryFile` code added in #765 & #781
7+
68

79
0.18.2
810
------

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ classifiers = [
3535
requires-python = ">=3.10"
3636
dependencies = [
3737
"packaging",
38-
"rasterio>=1.3.7",
38+
"rasterio>=1.4.3",
3939
"xarray>=2024.7.0",
4040
"pyproj>=3.3",
4141
"numpy>=1.23",

rioxarray/merge.py

Lines changed: 97 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
This module allows you to merge xarray Datasets/DataArrays
33
geospatially with the `rasterio.merge` module.
44
"""
5+
56
from collections.abc import Sequence
67
from typing import Callable, Optional, Union
78

89
import numpy
910
from rasterio.crs import CRS
10-
from rasterio.io import MemoryFile
1111
from rasterio.merge import merge as _rio_merge
12-
from xarray import DataArray, Dataset
12+
from xarray import DataArray, Dataset, IndexVariable
1313

14-
from rioxarray._io import open_rasterio
15-
from rioxarray.rioxarray import _get_nonspatial_coords
14+
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords
1615

1716

1817
class RasterioDatasetDuck:
@@ -31,29 +30,13 @@ def __init__(self, xds: DataArray):
3130
self.count = int(xds.rio.count)
3231
self.dtypes = [xds.dtype]
3332
self.name = xds.name
34-
if xds.rio.encoded_nodata is not None:
35-
self.nodatavals = [xds.rio.encoded_nodata]
36-
else:
37-
self.nodatavals = [xds.rio.nodata]
33+
self.nodatavals = [xds.rio.nodata]
3834
res = xds.rio.resolution(recalc=True)
3935
self.res = (abs(res[0]), abs(res[1]))
4036
self.transform = xds.rio.transform(recalc=True)
41-
self.profile: dict = {
42-
"crs": self.crs,
43-
"nodata": self.nodatavals[0],
44-
}
45-
valid_scale_factor = self._xds.encoding.get("scale_factor", 1) != 1 or any(
46-
scale != 1 for scale in self._xds.encoding.get("scales", (1,))
47-
)
48-
valid_offset = self._xds.encoding.get("add_offset", 0.0) != 0 or any(
49-
offset != 0 for offset in self._xds.encoding.get("offsets", (0,))
50-
)
51-
self._mask_and_scale = (
52-
self._xds.rio.encoded_nodata is not None
53-
or valid_scale_factor
54-
or valid_offset
55-
or self._xds.encoding.get("_Unsigned") is not None
56-
)
37+
# profile is only used for writing to a file.
38+
# This never happens with rioxarray merge.
39+
self.profile: dict = {}
5740

5841
def colormap(self, *args, **kwargs) -> None:
5942
"""
@@ -63,21 +46,44 @@ def colormap(self, *args, **kwargs) -> None:
6346
# pylint: disable=unused-argument
6447
return None
6548

66-
def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
49+
def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.MaskedArray:
50+
# pylint: disable=unused-argument
6751
"""
6852
This method is meant to be used by the rasterio.merge.merge function.
6953
"""
70-
with MemoryFile() as memfile:
71-
self._xds.rio.to_raster(memfile.name)
72-
with memfile.open() as dataset:
73-
if self._mask_and_scale:
74-
kwargs["masked"] = True
75-
out = dataset.read(*args, **kwargs)
76-
if self._mask_and_scale:
77-
out = out.astype(self._xds.dtype)
78-
for iii in range(self.count):
79-
out[iii] = out[iii] * dataset.scales[iii] + dataset.offsets[iii]
80-
return out
54+
data_window = self._xds.rio.isel_window(window)
55+
if data_window.shape != out_shape:
56+
# in this section, the data is geographically the same
57+
# however it is not the same dimensions as requested
58+
# so need to resample to the requested shape
59+
if len(out_shape) == 3:
60+
_, out_height, out_width = out_shape
61+
else:
62+
out_height, out_width = out_shape
63+
data_window = self._xds.rio.reproject(
64+
self._xds.rio.crs,
65+
transform=self.transform,
66+
shape=(out_height, out_width),
67+
)
68+
69+
nodata = self.nodatavals[0]
70+
mask = False
71+
fill_value = None
72+
if nodata is not None and numpy.isnan(nodata):
73+
mask = numpy.isnan(data_window)
74+
elif nodata is not None:
75+
mask = data_window == nodata
76+
fill_value = nodata
77+
78+
# make sure the returned shape matches
79+
# the expected shape. This can be the case
80+
# when the xarray dataset was squeezed to 2D beforehand
81+
if len(out_shape) == 3 and len(data_window.shape) == 2:
82+
data_window = data_window.values.reshape((1, out_height, out_width))
83+
84+
return numpy.ma.array(
85+
data_window, mask=mask, fill_value=fill_value, dtype=self.dtypes[0]
86+
)
8187

8288

8389
def merge_arrays(
@@ -155,47 +161,66 @@ def merge_arrays(
155161
rioduckarrays.append(RasterioDatasetDuck(dataarray))
156162

157163
# use rasterio to merge
164+
merged_data, merged_transform = _rio_merge(
165+
rioduckarrays,
166+
**{key: val for key, val in input_kwargs.items() if val is not None},
167+
)
158168
# generate merged data array
159169
representative_array = rioduckarrays[0]._xds
160-
with MemoryFile() as memfile:
161-
_rio_merge(
162-
rioduckarrays,
163-
**{key: val for key, val in input_kwargs.items() if val is not None},
164-
dst_path=memfile.name,
170+
if parse_coordinates:
171+
coords = _make_coords(
172+
src_data_array=representative_array,
173+
dst_affine=merged_transform,
174+
dst_width=merged_data.shape[-1],
175+
dst_height=merged_data.shape[-2],
165176
)
166-
with open_rasterio( # type: ignore
167-
memfile.name,
168-
parse_coordinates=parse_coordinates,
169-
mask_and_scale=rioduckarrays[0]._mask_and_scale,
170-
) as merged_data:
171-
merged_data = merged_data.load()
172-
173-
# make sure old & new coorinate names match & dimensions are correct
174-
rename_map = {}
175-
original_extra_dim = representative_array.rio._check_dimensions()
176-
new_extra_dim = merged_data.rio._check_dimensions()
177-
# make sure the output merged data shape is 2D if the
178-
# original data was 2D. this can happen if the
179-
# xarray datasarray was squeezed.
180-
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
181-
merged_data = merged_data.squeeze(
182-
dim=new_extra_dim, drop=original_extra_dim is None
177+
if (
178+
representative_array.rio.x_dim != "x"
179+
and "x" in coords
180+
and coords["x"].ndim == 1
181+
):
182+
coords[representative_array.rio.x_dim] = IndexVariable(
183+
representative_array.rio.x_dim, coords.pop("x")
183184
)
184-
new_extra_dim = merged_data.rio._check_dimensions()
185185
if (
186-
original_extra_dim is not None
187-
and new_extra_dim is not None
188-
and original_extra_dim != new_extra_dim
186+
representative_array.rio.y_dim != "y"
187+
and "y" in coords
188+
and coords["y"].ndim == 1
189189
):
190-
rename_map[new_extra_dim] = original_extra_dim
191-
if representative_array.rio.x_dim != merged_data.rio.x_dim:
192-
rename_map[merged_data.rio.x_dim] = representative_array.rio.x_dim
193-
if representative_array.rio.y_dim != merged_data.rio.y_dim:
194-
rename_map[merged_data.rio.y_dim] = representative_array.rio.y_dim
195-
if rename_map:
196-
merged_data = merged_data.rename(rename_map)
197-
merged_data.coords.update(_get_nonspatial_coords(representative_array))
198-
return merged_data # type: ignore
190+
coords[representative_array.rio.y_dim] = IndexVariable(
191+
representative_array.rio.y_dim, coords.pop("y")
192+
)
193+
else:
194+
coords = _get_nonspatial_coords(representative_array)
195+
196+
# make sure the output merged data shape is 2D if the
197+
# original data was 2D. this can happen if the
198+
# xarray DataArray was squeezed.
199+
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
200+
merged_data = merged_data.squeeze()
201+
202+
xda = DataArray(
203+
name=representative_array.name,
204+
data=merged_data,
205+
coords=coords,
206+
dims=tuple(representative_array.dims),
207+
attrs=representative_array.attrs,
208+
)
209+
xda.encoding = representative_array.encoding.copy()
210+
xda.rio.write_nodata(
211+
nodata if nodata is not None else representative_array.rio.nodata, inplace=True
212+
)
213+
xda.rio.write_crs(
214+
representative_array.rio.crs,
215+
grid_mapping_name=representative_array.rio.grid_mapping,
216+
inplace=True,
217+
)
218+
xda.rio.write_transform(
219+
merged_transform,
220+
grid_mapping_name=representative_array.rio.grid_mapping,
221+
inplace=True,
222+
)
223+
return xda
199224

200225

201226
def merge_datasets(
@@ -218,7 +243,7 @@ def merge_datasets(
218243
Parameters
219244
----------
220245
datasets: list[xarray.Dataset]
221-
List of multiple xarray.Dataset with all geo attributes.
246+
List of xarray.Dataset objects with all geo attributes.
222247
The first one is assumed to have the same
223248
CRS, dtype, dimensions, and data_vars as the others in the array.
224249
bounds: tuple, optional

test/integration/test_integration_merge.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,7 @@ def test_merge_arrays(squeeze):
5252
assert sorted(merged.coords) == sorted(rds.coords)
5353
assert merged.coords["band"].values == [1]
5454
assert merged.rio.crs == rds.rio.crs
55-
assert merged.attrs == {
56-
"AREA_OR_POINT": "Area",
57-
"add_offset": 0.0,
58-
"scale_factor": 1.0,
59-
**rds.attrs,
60-
}
55+
assert merged.attrs == rds.attrs
6156
assert merged.encoding["grid_mapping"] == "spatial_ref"
6257

6358

@@ -113,10 +108,7 @@ def test_merge__different_crs(dataset):
113108
assert merged.rio.crs == rds.rio.crs
114109
if not dataset:
115110
assert merged.attrs == {
116-
"AREA_OR_POINT": "Area",
117111
"_FillValue": -28672,
118-
"add_offset": 0.0,
119-
"scale_factor": 1.0,
120112
}
121113
assert merged.encoding["grid_mapping"] == "spatial_ref"
122114

@@ -284,7 +276,7 @@ def test_merge_datasets__mask_and_scale(mask_and_scale):
284276
rds.isel(x=slice(100, None), y=slice(100)),
285277
]
286278
merged = merge_datasets(datasets)
287-
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
279+
assert sorted(merged.coords) == sorted(list(rds.coords))
288280
total = merged.air_temperature.sum()
289281
if mask_and_scale:
290282
assert_almost_equal(total, 133376696)

0 commit comments

Comments
 (0)