Skip to content

Commit 0a5845a

Browse files
committed
fixed point transformation
1 parent ea0511d commit 0a5845a

3 files changed

Lines changed: 7 additions & 10 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def _render_points(
220220
coords.extend(color)
221221

222222
points = points[coords].compute()
223-
# points[color[0]].cat.set_categories(render_params.groups, inplace=True)
224223
if render_params.groups is not None:
225224
points = points[points[color].isin(render_params.groups).values]
226225
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
@@ -260,6 +259,10 @@ def _render_points(
260259
if color_source_vector is None and render_params.transfunc is not None:
261260
color_vector = render_params.transfunc(color_vector)
262261

262+
trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system]
263+
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
264+
trans = mtransforms.Affine2D(matrix=affine_trans) + ax.transData
265+
263266
norm = copy(render_params.cmap_params.norm)
264267
_cax = ax.scatter(
265268
adata[:, 0].X.flatten(),
@@ -270,17 +273,11 @@ def _render_points(
270273
cmap=render_params.cmap_params.cmap,
271274
norm=norm,
272275
alpha=render_params.alpha,
276+
transform=trans
273277
# **kwargs,
274278
)
275279
cax = ax.add_collection(_cax)
276280

277-
trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system]
278-
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
279-
trans = mtransforms.Affine2D(matrix=affine_trans)
280-
281-
for path in _cax.get_paths():
282-
path.vertices = trans.transform(path.vertices)
283-
284281
if not (
285282
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
286283
):
645 Bytes
Loading

tests/pl/test_get_extent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_plot_correct_plot_after_transformations(self):
111111
for element_name in [f"circles_pi{i}", f"polygons_pi{i}", f"multipolygons_pi{i}", f"points_pi{i}"]:
112112
set_transformation(element=sdata[element_name], transformation=rotation, to_coordinate_system=f"pi{i}")
113113

114-
fig, axs = plt.subplots(ncols=3, nrows=4, figsize=(7, 9))
114+
_, axs = plt.subplots(ncols=3, nrows=4, figsize=(7, 9))
115115

116116
for cs_idx, cs in enumerate(["global", "pi3", "pi4"]):
117117
if cs == "global":
@@ -136,7 +136,7 @@ def test_plot_correct_plot_after_transformations(self):
136136
coordinate_systems=cs, ax=axs[2, cs_idx], title=""
137137
)
138138
sdata.pl.render_points(elements=points_name, size=10).pl.show(
139-
coordinate_systems=cs, ax=axs[3, cs_idx], title="", pad_extent=0.05
139+
coordinate_systems=cs, ax=axs[3, cs_idx], title="", pad_extent=0.02
140140
)
141141

142142
plt.tight_layout()

0 commit comments

Comments
 (0)