44from copy import copy
55from typing import Union
66
7+ import dask
78import geopandas as gpd
89import matplotlib
910import numpy as np
1819from spatialdata .models import (
1920 Image2DModel ,
2021 Labels2DModel ,
22+ PointsModel ,
2123)
2224
2325from spatialdata_plot ._logging import logger
@@ -57,6 +59,12 @@ def _render_shapes(
5759) -> None :
5860 elements = render_params .elements
5961
62+ if render_params .groups is not None :
63+ if isinstance (render_params .groups , str ):
64+ render_params .groups = [render_params .groups ]
65+ if not all (isinstance (g , str ) for g in render_params .groups ):
66+ raise TypeError ("All groups must be strings." )
67+
6068 sdata_filt = sdata .filter_by_coordinate_system (
6169 coordinate_system = coordinate_system ,
6270 filter_table = sdata .table is not None ,
@@ -68,7 +76,6 @@ def _render_shapes(
6876 elements = list (sdata_filt .shapes .keys ())
6977
7078 for e in elements :
71- # shapes = [sdata.shapes[e] for e in elements]
7279 shapes = sdata .shapes [e ]
7380 n_shapes = sum ([len (s ) for s in shapes ])
7481
@@ -88,6 +95,7 @@ def _render_shapes(
8895 palette = render_params .palette ,
8996 na_color = render_params .cmap_params .na_color ,
9097 alpha = render_params .fill_alpha ,
98+ cmap_params = render_params .cmap_params ,
9199 )
92100
93101 values_are_categorical = color_source_vector is not None
@@ -101,7 +109,15 @@ def _render_shapes(
101109 if len (color_vector ) == 0 :
102110 color_vector = [render_params .cmap_params .na_color ]
103111
112+ # filter by `groups`
113+ if render_params .groups is not None and color_source_vector is not None :
114+ mask = color_source_vector .isin (render_params .groups )
115+ shapes = shapes [mask ]
116+ shapes = shapes .reset_index ()
117+ color_source_vector = color_source_vector [mask ]
118+ color_vector = color_vector [mask ]
104119 shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
120+
105121 _cax = _get_collection_shape (
106122 shapes = shapes ,
107123 s = render_params .scale ,
@@ -122,9 +138,12 @@ def _render_shapes(
122138 cax = ax .add_collection (_cax )
123139
124140 # Using dict.fromkeys here since set returns in arbitrary order
125- palette = (
126- ListedColormap (dict .fromkeys (color_vector )) if render_params .palette is None else render_params .palette
127- )
141+ # remove the color of NaN values, else it might be assigned to a category
142+ # order of color in the palette should agree to order of occurence
143+ if color_source_vector is None :
144+ palette = ListedColormap (dict .fromkeys (color_vector ))
145+ else :
146+ palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
128147
129148 if not (
130149 len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
@@ -159,6 +178,12 @@ def _render_points(
159178 scalebar_params : ScalebarParams ,
160179 legend_params : LegendParams ,
161180) -> None :
181+ if render_params .groups is not None :
182+ if isinstance (render_params .groups , str ):
183+ render_params .groups = [render_params .groups ]
184+ if not all (isinstance (g , str ) for g in render_params .groups ):
185+ raise TypeError ("All groups must be strings." )
186+
162187 elements = render_params .elements
163188
164189 sdata_filt = sdata .filter_by_coordinate_system (
@@ -178,6 +203,14 @@ def _render_points(
178203 color = [render_params .color ] if isinstance (render_params .color , str ) else render_params .color
179204 coords .extend (color )
180205
206+ points = points [coords ].compute ()
207+ # points[color[0]].cat.set_categories(render_params.groups, inplace=True)
208+ if render_params .groups is not None :
209+ points = points [points [color ].isin (render_params .groups ).values ]
210+ points [color [0 ]] = points [color [0 ]].cat .set_categories (render_params .groups )
211+ points = dask .dataframe .from_pandas (points , npartitions = 1 )
212+ sdata_filt .points [e ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
213+
181214 point_df = points [coords ].compute ()
182215
183216 # we construct an anndata to hack the plotting functions
@@ -204,6 +237,7 @@ def _render_points(
204237 palette = render_params .palette ,
205238 na_color = render_params .cmap_params .na_color ,
206239 alpha = render_params .alpha ,
240+ cmap_params = render_params .cmap_params ,
207241 )
208242
209243 # color_source_vector is None when the values aren't categorical
@@ -226,14 +260,19 @@ def _render_points(
226260 if not (
227261 len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
228262 ):
263+ if color_source_vector is None :
264+ palette = ListedColormap (dict .fromkeys (color_vector ))
265+ else :
266+ palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
267+
229268 _ = _decorate_axs (
230269 ax = ax ,
231270 cax = cax ,
232271 fig_params = fig_params ,
233272 adata = adata ,
234273 value_to_plot = render_params .color ,
235274 color_source_vector = color_source_vector ,
236- palette = render_params . palette ,
275+ palette = palette ,
237276 alpha = render_params .alpha ,
238277 na_color = render_params .cmap_params .na_color ,
239278 legend_fontsize = legend_params .legend_fontsize ,
@@ -415,6 +454,12 @@ def _render_labels(
415454) -> None :
416455 elements = render_params .elements
417456
457+ if render_params .groups is not None :
458+ if isinstance (render_params .groups , str ):
459+ render_params .groups = [render_params .groups ]
460+ if not all (isinstance (g , str ) for g in render_params .groups ):
461+ raise TypeError ("All groups must be strings." )
462+
418463 sdata_filt = sdata .filter_by_coordinate_system (
419464 coordinate_system = coordinate_system ,
420465 filter_table = sdata .table is not None ,
@@ -441,7 +486,7 @@ def _render_labels(
441486
442487 table = sdata .table [sdata .table .obs [region_key ].isin ([label_key ])]
443488
444- # get isntance id based on subsetted table
489+ # get instance id based on subsetted table
445490 instance_id = table .obs [instance_key ].values
446491
447492 # get color vector (categorical or continuous)
@@ -455,6 +500,7 @@ def _render_labels(
455500 palette = render_params .palette ,
456501 na_color = render_params .cmap_params .na_color ,
457502 alpha = render_params .fill_alpha ,
503+ cmap_params = render_params .cmap_params ,
458504 )
459505
460506 if (render_params .fill_alpha != render_params .outline_alpha ) and render_params .contour_px is not None :
0 commit comments