Skip to content

Commit f7fac79

Browse files
authored
Complex testing for correct plotting after applying transformation (#198)
* Added test I'm breaking and incorrect plot * fixed point transformation * Added changelog
1 parent b82e7f5 commit f7fac79

4 files changed

Lines changed: 105 additions & 8 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ and this project adheres to [Semantic Versioning][].
1515
- Multiscale image handling: user can specify a scale, else the best scale is selected automatically given the figure size and dpi (#164)
1616
- Large images are automatically rasterized to speed up performance (#164)
1717
- Added better error message for mismatch in cs and ax number (#185)
18+
- Beter test coverage for correct plotting of elements after transformation (#198)
1819
- Can now stack render commands (#190, #192)
1920

2021
### Fixed
2122

2223
- Now dropping index when plotting shapes after spatial query (#177)
24+
- Points are now being correctly rotated (#198)
2325
- User can now pass Colormap objects to the cmap argument in render_images. When only one cmap is given for 3 channels, it is now applied to each channel (#188, #194)
2426

2527
## [0.0.6] - 2023-11-06

src/spatialdata_plot/pl/render.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def _render_points(
220220
coords.extend(color)
221221

222222
points = points[coords].compute()
223-
# points[color[0]].cat.set_categories(render_params.groups, inplace=True)
224223
if render_params.groups is not None:
225224
points = points[points[color].isin(render_params.groups).values]
226225
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
@@ -260,6 +259,10 @@ def _render_points(
260259
if color_source_vector is None and render_params.transfunc is not None:
261260
color_vector = render_params.transfunc(color_vector)
262261

262+
trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system]
263+
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
264+
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
265+
263266
norm = copy(render_params.cmap_params.norm)
264267
_cax = ax.scatter(
265268
adata[:, 0].X.flatten(),
@@ -270,17 +273,11 @@ def _render_points(
270273
cmap=render_params.cmap_params.cmap,
271274
norm=norm,
272275
alpha=render_params.alpha,
276+
transform=trans
273277
# **kwargs,
274278
)
275279
cax = ax.add_collection(_cax)
276280

277-
trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system]
278-
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
279-
trans = mtransforms.Affine2D(matrix=affine_trans)
280-
281-
for path in _cax.get_paths():
282-
path.vertices = trans.transform(path.vertices)
283-
284281
if not (
285282
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
286283
):
23.4 KB
Loading

tests/pl/test_get_extent.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
import math
2+
13
import matplotlib
4+
import matplotlib.pyplot as plt
5+
import numpy as np
26
import scanpy as sc
37
import spatialdata_plot # noqa: F401
8+
from geopandas import GeoDataFrame
9+
from shapely.geometry import MultiPolygon, Point, Polygon
410
from spatialdata import SpatialData
11+
from spatialdata.models import PointsModel, ShapesModel
12+
from spatialdata.transformations import Affine, set_transformation
513

614
from tests.conftest import PlotTester, PlotTesterMeta
715

@@ -42,3 +50,93 @@ def test_plot_extent_of_img_is_correct_after_spatial_query(self, sdata_blobs: Sp
4250
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
4351
)
4452
cropped_blobs.pl.render_images().pl.show()
53+
54+
def test_plot_correct_plot_after_transformations(self):
55+
# inspired by https://github.com/scverse/spatialdata/blob/ef0a2dc7f9af8d4c84f15eec503177f1d08c3d46/tests/core/test_data_extent.py#L125
56+
57+
circles = [Point(p) for p in [[0.5, 0.1], [0.9, 0.5], [0.5, 0.9], [0.1, 0.5]]]
58+
circles_gdf = GeoDataFrame(geometry=circles)
59+
circles_gdf["radius"] = 0.1
60+
circles_gdf = ShapesModel.parse(circles_gdf)
61+
62+
polygons = [Polygon([(0.5, 0.5), (0.5, 0), (0.6, 0.1), (0.5, 0.5)])]
63+
polygons.append(Polygon([(0.5, 0.5), (1, 0.5), (0.9, 0.6), (0.5, 0.5)]))
64+
polygons.append(Polygon([(0.5, 0.5), (0.5, 1), (0.4, 0.9), (0.5, 0.5)]))
65+
polygons.append(Polygon([(0.5, 0.5), (0, 0.5), (0.1, 0.4), (0.5, 0.5)]))
66+
polygons_gdf = GeoDataFrame(geometry=polygons)
67+
polygons_gdf = ShapesModel.parse(polygons_gdf)
68+
69+
multipolygons = [
70+
MultiPolygon(
71+
[
72+
polygons[0],
73+
Polygon([(0.7, 0.1), (0.9, 0.1), (0.9, 0.3), (0.7, 0.1)]),
74+
]
75+
)
76+
]
77+
multipolygons.append(MultiPolygon([polygons[1], Polygon([(0.9, 0.7), (0.9, 0.9), (0.7, 0.9), (0.9, 0.7)])]))
78+
multipolygons.append(MultiPolygon([polygons[2], Polygon([(0.3, 0.9), (0.1, 0.9), (0.1, 0.7), (0.3, 0.9)])]))
79+
multipolygons.append(MultiPolygon([polygons[3], Polygon([(0.1, 0.3), (0.1, 0.1), (0.3, 0.1), (0.1, 0.3)])]))
80+
multipolygons_gdf = GeoDataFrame(geometry=multipolygons)
81+
multipolygons_gdf = ShapesModel.parse(multipolygons_gdf)
82+
83+
points_df = PointsModel.parse(np.array([[0.5, 0], [1, 0.5], [0.5, 1], [0, 0.5]]))
84+
85+
sdata = SpatialData(
86+
shapes={
87+
"circles": circles_gdf,
88+
"polygons": polygons_gdf,
89+
"multipolygons": multipolygons_gdf,
90+
"circles_pi3": circles_gdf,
91+
"polygons_pi3": polygons_gdf,
92+
"multipolygons_pi3": multipolygons_gdf,
93+
"circles_pi4": circles_gdf,
94+
"polygons_pi4": polygons_gdf,
95+
"multipolygons_pi4": multipolygons_gdf,
96+
},
97+
points={"points": points_df, "points_pi3": points_df, "points_pi4": points_df},
98+
)
99+
100+
for i in [3, 4]:
101+
theta = math.pi / i
102+
rotation = Affine(
103+
[
104+
[math.cos(theta), -math.sin(theta), 0],
105+
[math.sin(theta), math.cos(theta), 0],
106+
[0, 0, 1],
107+
],
108+
input_axes=("x", "y"),
109+
output_axes=("x", "y"),
110+
)
111+
for element_name in [f"circles_pi{i}", f"polygons_pi{i}", f"multipolygons_pi{i}", f"points_pi{i}"]:
112+
set_transformation(element=sdata[element_name], transformation=rotation, to_coordinate_system=f"pi{i}")
113+
114+
_, axs = plt.subplots(ncols=3, nrows=4, figsize=(7, 9))
115+
116+
for cs_idx, cs in enumerate(["global", "pi3", "pi4"]):
117+
if cs == "global":
118+
circles_name = "circles"
119+
polygons_name = "polygons"
120+
multipolygons_name = "multipolygons"
121+
points_name = "points"
122+
elif cs == "pi3":
123+
circles_name = "circles_pi3"
124+
polygons_name = "polygons_pi3"
125+
multipolygons_name = "multipolygons_pi3"
126+
points_name = "points_pi3"
127+
else:
128+
circles_name = "circles_pi4"
129+
polygons_name = "polygons_pi4"
130+
multipolygons_name = "multipolygons_pi4"
131+
points_name = "points_pi4"
132+
133+
sdata.pl.render_shapes(elements=circles_name).pl.show(coordinate_systems=cs, ax=axs[0, cs_idx], title="")
134+
sdata.pl.render_shapes(elements=polygons_name).pl.show(coordinate_systems=cs, ax=axs[1, cs_idx], title="")
135+
sdata.pl.render_shapes(elements=multipolygons_name).pl.show(
136+
coordinate_systems=cs, ax=axs[2, cs_idx], title=""
137+
)
138+
sdata.pl.render_points(elements=points_name, size=10).pl.show(
139+
coordinate_systems=cs, ax=axs[3, cs_idx], title="", pad_extent=0.02
140+
)
141+
142+
plt.tight_layout()

0 commit comments

Comments
 (0)