@@ -220,7 +220,6 @@ def _render_points(
220220 coords .extend (color )
221221
222222 points = points [coords ].compute ()
223- # points[color[0]].cat.set_categories(render_params.groups, inplace=True)
224223 if render_params .groups is not None :
225224 points = points [points [color ].isin (render_params .groups ).values ]
226225 points [color [0 ]] = points [color [0 ]].cat .set_categories (render_params .groups )
@@ -260,6 +259,10 @@ def _render_points(
260259 if color_source_vector is None and render_params .transfunc is not None :
261260 color_vector = render_params .transfunc (color_vector )
262261
262+ trans = get_transformation (sdata .points [e ], get_all = True )[coordinate_system ]
263+ affine_trans = trans .to_affine_matrix (input_axes = ("x" , "y" ), output_axes = ("x" , "y" ))
264+ trans = mtransforms .Affine2D (matrix = affine_trans ) + ax .transData
265+
263266 norm = copy (render_params .cmap_params .norm )
264267 _cax = ax .scatter (
265268 adata [:, 0 ].X .flatten (),
@@ -270,17 +273,11 @@ def _render_points(
270273 cmap = render_params .cmap_params .cmap ,
271274 norm = norm ,
272275 alpha = render_params .alpha ,
276+ transform = trans
273277 # **kwargs,
274278 )
275279 cax = ax .add_collection (_cax )
276280
277- trans = get_transformation (sdata .points [e ], get_all = True )[coordinate_system ]
278- affine_trans = trans .to_affine_matrix (input_axes = ("x" , "y" ), output_axes = ("x" , "y" ))
279- trans = mtransforms .Affine2D (matrix = affine_trans )
280-
281- for path in _cax .get_paths ():
282- path .vertices = trans .transform (path .vertices )
283-
284281 if not (
285282 len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
286283 ):
0 commit comments