Skip to content

Commit bb10ee9

Browse files
committed
BUG:merge: Revert rasterio.io.MemoryFile code added in corteva#765 & corteva#781
1 parent 3c8512c commit bb10ee9

4 files changed

Lines changed: 84 additions & 81 deletions

File tree

docs/history.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
History
22
=======
33

4-
Latest
4+
0.19.0
55
------
6+
- BUG:merge: Revert `rasterio.io.MemoryFile` code added in #765 & #781
7+
68

79
0.18.2
810
------

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ classifiers = [
3535
requires-python = ">=3.10"
3636
dependencies = [
3737
"packaging",
38-
"rasterio>=1.3.7",
38+
"rasterio>=1.4.3",
3939
"xarray>=2024.7.0",
4040
"pyproj>=3.3",
4141
"numpy>=1.23",

rioxarray/merge.py

Lines changed: 78 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
This module allows you to merge xarray Datasets/DataArrays
33
geospatially with the `rasterio.merge` module.
44
"""
5+
56
from collections.abc import Sequence
67
from typing import Callable, Optional, Union
78

89
import numpy
910
from rasterio.crs import CRS
1011
from rasterio.io import MemoryFile
1112
from rasterio.merge import merge as _rio_merge
12-
from xarray import DataArray, Dataset
13+
from xarray import DataArray, Dataset, IndexVariable
1314

14-
from rioxarray._io import open_rasterio
15-
from rioxarray.rioxarray import _get_nonspatial_coords
15+
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords
1616

1717

1818
class RasterioDatasetDuck:
@@ -31,29 +31,13 @@ def __init__(self, xds: DataArray):
3131
self.count = int(xds.rio.count)
3232
self.dtypes = [xds.dtype]
3333
self.name = xds.name
34-
if xds.rio.encoded_nodata is not None:
35-
self.nodatavals = [xds.rio.encoded_nodata]
36-
else:
37-
self.nodatavals = [xds.rio.nodata]
34+
self.nodatavals = [xds.rio.nodata]
3835
res = xds.rio.resolution(recalc=True)
3936
self.res = (abs(res[0]), abs(res[1]))
4037
self.transform = xds.rio.transform(recalc=True)
41-
self.profile: dict = {
42-
"crs": self.crs,
43-
"nodata": self.nodatavals[0],
44-
}
45-
valid_scale_factor = self._xds.encoding.get("scale_factor", 1) != 1 or any(
46-
scale != 1 for scale in self._xds.encoding.get("scales", (1,))
47-
)
48-
valid_offset = self._xds.encoding.get("add_offset", 0.0) != 0 or any(
49-
offset != 0 for offset in self._xds.encoding.get("offsets", (0,))
50-
)
51-
self._mask_and_scale = (
52-
self._xds.rio.encoded_nodata is not None
53-
or valid_scale_factor
54-
or valid_offset
55-
or self._xds.encoding.get("_Unsigned") is not None
56-
)
38+
# profile is only used for writing to a file.
39+
# This never happens with rioxarray merge.
40+
self.profile: dict = {}
5741

5842
def colormap(self, *args, **kwargs) -> None:
5943
"""
@@ -68,16 +52,22 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
6852
This method is meant to be used by the rasterio.merge.merge function.
6953
"""
7054
with MemoryFile() as memfile:
71-
self._xds.rio.to_raster(memfile.name)
72-
with memfile.open() as dataset:
73-
if self._mask_and_scale:
74-
kwargs["masked"] = True
75-
out = dataset.read(*args, **kwargs)
76-
if self._mask_and_scale:
77-
out = out.astype(self._xds.dtype)
78-
for iii in range(self.count):
79-
out[iii] = out[iii] * dataset.scales[iii] + dataset.offsets[iii]
80-
return out
55+
with memfile.open(
56+
driver="GTiff",
57+
height=int(self._xds.rio.height),
58+
width=int(self._xds.rio.width),
59+
count=self.count,
60+
dtype=self.dtypes[0],
61+
crs=self.crs,
62+
transform=self.transform,
63+
nodata=self.nodatavals[0],
64+
) as dataset:
65+
data = self._xds.values
66+
if data.ndim == 2:
67+
dataset.write(data, 1)
68+
else:
69+
dataset.write(data)
70+
return dataset.read(*args, **kwargs)
8171

8272

8373
def merge_arrays(
@@ -155,47 +145,66 @@ def merge_arrays(
155145
rioduckarrays.append(RasterioDatasetDuck(dataarray))
156146

157147
# use rasterio to merge
148+
merged_data, merged_transform = _rio_merge(
149+
rioduckarrays,
150+
**{key: val for key, val in input_kwargs.items() if val is not None},
151+
)
158152
# generate merged data array
159153
representative_array = rioduckarrays[0]._xds
160-
with MemoryFile() as memfile:
161-
_rio_merge(
162-
rioduckarrays,
163-
**{key: val for key, val in input_kwargs.items() if val is not None},
164-
dst_path=memfile.name,
154+
if parse_coordinates:
155+
coords = _make_coords(
156+
src_data_array=representative_array,
157+
dst_affine=merged_transform,
158+
dst_width=merged_data.shape[-1],
159+
dst_height=merged_data.shape[-2],
165160
)
166-
with open_rasterio( # type: ignore
167-
memfile.name,
168-
parse_coordinates=parse_coordinates,
169-
mask_and_scale=rioduckarrays[0]._mask_and_scale,
170-
) as merged_data:
171-
merged_data = merged_data.load()
172-
173-
# make sure old & new coordinate names match & dimensions are correct
174-
rename_map = {}
175-
original_extra_dim = representative_array.rio._check_dimensions()
176-
new_extra_dim = merged_data.rio._check_dimensions()
177-
# make sure the output merged data shape is 2D if the
178-
# original data was 2D. this can happen if the
179-
# xarray dataarray was squeezed.
180-
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
181-
merged_data = merged_data.squeeze(
182-
dim=new_extra_dim, drop=original_extra_dim is None
161+
if (
162+
representative_array.rio.x_dim != "x"
163+
and "x" in coords
164+
and coords["x"].ndim == 1
165+
):
166+
coords[representative_array.rio.x_dim] = IndexVariable(
167+
representative_array.rio.x_dim, coords.pop("x")
183168
)
184-
new_extra_dim = merged_data.rio._check_dimensions()
185169
if (
186-
original_extra_dim is not None
187-
and new_extra_dim is not None
188-
and original_extra_dim != new_extra_dim
170+
representative_array.rio.y_dim != "y"
171+
and "y" in coords
172+
and coords["y"].ndim == 1
189173
):
190-
rename_map[new_extra_dim] = original_extra_dim
191-
if representative_array.rio.x_dim != merged_data.rio.x_dim:
192-
rename_map[merged_data.rio.x_dim] = representative_array.rio.x_dim
193-
if representative_array.rio.y_dim != merged_data.rio.y_dim:
194-
rename_map[merged_data.rio.y_dim] = representative_array.rio.y_dim
195-
if rename_map:
196-
merged_data = merged_data.rename(rename_map)
197-
merged_data.coords.update(_get_nonspatial_coords(representative_array))
198-
return merged_data # type: ignore
174+
coords[representative_array.rio.y_dim] = IndexVariable(
175+
representative_array.rio.y_dim, coords.pop("y")
176+
)
177+
else:
178+
coords = _get_nonspatial_coords(representative_array)
179+
180+
# make sure the output merged data shape is 2D if the
181+
# original data was 2D. this can happen if the
182+
# xarray dataarray was squeezed.
183+
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
184+
merged_data = merged_data.squeeze()
185+
186+
xda = DataArray(
187+
name=representative_array.name,
188+
data=merged_data,
189+
coords=coords,
190+
dims=tuple(representative_array.dims),
191+
attrs=representative_array.attrs,
192+
)
193+
xda.encoding = representative_array.encoding.copy()
194+
xda.rio.write_nodata(
195+
nodata if nodata is not None else representative_array.rio.nodata, inplace=True
196+
)
197+
xda.rio.write_crs(
198+
representative_array.rio.crs,
199+
grid_mapping_name=representative_array.rio.grid_mapping,
200+
inplace=True,
201+
)
202+
xda.rio.write_transform(
203+
merged_transform,
204+
grid_mapping_name=representative_array.rio.grid_mapping,
205+
inplace=True,
206+
)
207+
return xda
199208

200209

201210
def merge_datasets(
@@ -218,7 +227,7 @@ def merge_datasets(
218227
Parameters
219228
----------
220229
datasets: list[xarray.Dataset]
221-
List of multiple xarray.Dataset with all geo attributes.
230+
List of xarray.Datasets with all geo attributes.
222231
The first one is assumed to have the same
223232
CRS, dtype, dimensions, and data_vars as the others in the array.
224233
bounds: tuple, optional

test/integration/test_integration_merge.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,7 @@ def test_merge_arrays(squeeze):
5252
assert sorted(merged.coords) == sorted(rds.coords)
5353
assert merged.coords["band"].values == [1]
5454
assert merged.rio.crs == rds.rio.crs
55-
assert merged.attrs == {
56-
"AREA_OR_POINT": "Area",
57-
"add_offset": 0.0,
58-
"scale_factor": 1.0,
59-
**rds.attrs,
60-
}
55+
assert merged.attrs == rds.attrs
6156
assert merged.encoding["grid_mapping"] == "spatial_ref"
6257

6358

@@ -113,10 +108,7 @@ def test_merge__different_crs(dataset):
113108
assert merged.rio.crs == rds.rio.crs
114109
if not dataset:
115110
assert merged.attrs == {
116-
"AREA_OR_POINT": "Area",
117111
"_FillValue": -28672,
118-
"add_offset": 0.0,
119-
"scale_factor": 1.0,
120112
}
121113
assert merged.encoding["grid_mapping"] == "spatial_ref"
122114

@@ -284,7 +276,7 @@ def test_merge_datasets__mask_and_scale(mask_and_scale):
284276
rds.isel(x=slice(100, None), y=slice(100)),
285277
]
286278
merged = merge_datasets(datasets)
287-
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
279+
assert sorted(merged.coords) == sorted(list(rds.coords))
288280
total = merged.air_temperature.sum()
289281
if mask_and_scale:
290282
assert_almost_equal(total, 133376696)

0 commit comments

Comments
 (0)