1919from multiscale_spatial_image .multiscale_spatial_image import MultiscaleSpatialImage
2020from pandas .api .types import is_categorical_dtype
2121from spatial_image import SpatialImage
22- from spatialdata ._logging import logger as logg
22+ from spatialdata ._core .data_extent import get_extent
23+ from spatialdata .transformations .operations import get_transformation
2324
2425from spatialdata_plot ._accessor import register_spatial_data_accessor
2526from spatialdata_plot .pl .render import (
4041)
4142from spatialdata_plot .pl .utils import (
4243 _get_cs_contents ,
43- _get_extent ,
4444 _maybe_set_colors ,
4545 _mpl_ax_contains_elements ,
4646 _prepare_cmap_norm ,
4747 _prepare_params_plot ,
48- _robust_transform ,
4948 _set_outline ,
5049 save_fig ,
5150)
@@ -216,6 +215,8 @@ def render_shapes(
216215 na_color = na_color , # type: ignore[arg-type]
217216 ** kwargs ,
218217 )
218+ if isinstance (elements , str ):
219+ elements = [elements ]
219220 outline_params = _set_outline (outline , outline_width , outline_color )
220221 sdata .plotting_tree [f"{ n_steps + 1 } _render_shapes" ] = ShapesRenderParams (
221222 elements = elements ,
@@ -285,12 +286,15 @@ def render_points(
285286 sdata = self ._copy ()
286287 sdata = _verify_plotting_tree (sdata )
287288 n_steps = len (sdata .plotting_tree .keys ())
289+
288290 cmap_params = _prepare_cmap_norm (
289291 cmap = cmap ,
290292 norm = norm ,
291293 na_color = na_color , # type: ignore[arg-type]
292294 ** kwargs ,
293295 )
296+ if isinstance (elements , str ):
297+ elements = [elements ]
294298 sdata .plotting_tree [f"{ n_steps + 1 } _render_points" ] = PointsRenderParams (
295299 elements = elements ,
296300 color = color ,
@@ -370,6 +374,8 @@ def render_images(
370374 ** kwargs ,
371375 )
372376
377+ if isinstance (elements , str ):
378+ elements = [elements ]
373379 sdata .plotting_tree [f"{ n_steps + 1 } _render_images" ] = ImageRenderParams (
374380 elements = elements ,
375381 channel = channel ,
@@ -450,6 +456,8 @@ def render_labels(
450456 na_color = na_color , # type: ignore[arg-type]
451457 ** kwargs ,
452458 )
459+ if isinstance (elements , str ):
460+ elements = [elements ]
453461 sdata .plotting_tree [f"{ n_steps + 1 } _render_labels" ] = LabelsRenderParams (
454462 elements = elements ,
455463 color = color ,
@@ -552,12 +560,12 @@ def show(
552560 raise TypeError ("All titles must be strings." )
553561
554562 # get original axis extent for later comparison
555- x_min_orig , x_max_orig = (np .inf , - np .inf )
556- y_min_orig , y_max_orig = (np .inf , - np .inf )
563+ ax_x_min , ax_x_max = (np .inf , - np .inf )
564+ ax_y_min , ax_y_max = (np .inf , - np .inf )
557565
558566 if isinstance (ax , Axes ) and _mpl_ax_contains_elements (ax ):
559- x_min_orig , x_max_orig = ax .get_xlim ()
560- y_max_orig , y_min_orig = ax .get_ylim () # (0, 0) is top-left
567+ ax_x_min , ax_x_max = ax .get_xlim ()
568+ ax_y_max , ax_y_min = ax .get_ylim () # (0, 0) is top-left
561569
562570 # handle coordinate system
563571 coordinate_systems = sdata .coordinate_systems if coordinate_systems is None else coordinate_systems
@@ -568,50 +576,6 @@ def show(
568576 if cs not in sdata .coordinate_systems :
569577 raise ValueError (f"Unknown coordinate system '{ cs } ', valid choices are: { sdata .coordinate_systems } " )
570578
571- # Check if user specified only certain elements to be plotted
572- cs_contents = _get_cs_contents (sdata )
573- elements_to_be_rendered = []
574- for cmd , params in render_cmds .items ():
575- if cmd == "render_images" and cs_contents .query (f"cs == '{ cs } '" )["has_images" ][0 ]: # noqa: SIM114
576- if params .elements is not None :
577- elements_to_be_rendered += (
578- [params .elements ] if isinstance (params .elements , str ) else params .elements
579- )
580- elif cmd == "render_shapes" and cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ]: # noqa: SIM114
581- if params .elements is not None :
582- elements_to_be_rendered += (
583- [params .elements ] if isinstance (params .elements , str ) else params .elements
584- )
585- elif cmd == "render_points" and cs_contents .query (f"cs == '{ cs } '" )["has_points" ][0 ]: # noqa: SIM114
586- if params .elements is not None :
587- elements_to_be_rendered += (
588- [params .elements ] if isinstance (params .elements , str ) else params .elements
589- )
590- elif cmd == "render_labels" and cs_contents .query (f"cs == '{ cs } '" )["has_labels" ][0 ]: # noqa: SIM102
591- if params .elements is not None :
592- elements_to_be_rendered += (
593- [params .elements ] if isinstance (params .elements , str ) else params .elements
594- )
595-
596- extent = _get_extent (
597- sdata = sdata ,
598- has_images = "render_images" in render_cmds ,
599- has_labels = "render_labels" in render_cmds ,
600- has_points = "render_points" in render_cmds ,
601- has_shapes = "render_shapes" in render_cmds ,
602- elements = elements_to_be_rendered ,
603- coordinate_systems = coordinate_systems ,
604- )
605-
606- # Use extent to filter out coordinate system without the relevant elements
607- valid_cs = []
608- for cs in coordinate_systems :
609- if cs in extent :
610- valid_cs .append (cs )
611- else :
612- logg .info (f"Dropping coordinate system '{ cs } ' since it doesn't have relevant elements." )
613- coordinate_systems = valid_cs
614-
615579 # set up canvas
616580 fig_params , scalebar_params = _prepare_params_plot (
617581 num_panels = len (coordinate_systems ),
@@ -633,32 +597,25 @@ def show(
633597 colorbar = colorbar ,
634598 )
635599
600+ cs_contents = _get_cs_contents (sdata )
601+
636602 # go through tree
603+
637604 for i , cs in enumerate (coordinate_systems ):
638605 sdata = self ._copy ()
639- # properly transform all elements to the current coordinate system
640- members = cs_contents .query (f"cs == '{ cs } '" )
641-
642- if members ["has_images" ].values [0 ]:
643- for key in sdata .images :
644- sdata .images [key ] = _robust_transform (sdata .images [key ], cs )
645-
646- if members ["has_labels" ].values [0 ]:
647- for key in sdata .labels :
648- sdata .labels [key ] = _robust_transform (sdata .labels [key ], cs )
649-
650- if members ["has_points" ].values [0 ]:
651- for key in sdata .points :
652- sdata .points [key ] = _robust_transform (sdata .points [key ], cs )
653-
654- if members ["has_shapes" ].values [0 ]:
655- for key in sdata .shapes :
656- sdata .shapes [key ] = _robust_transform (sdata .shapes [key ], cs )
657-
606+ _ , has_images , has_labels , has_points , has_shapes = (
607+ cs_contents .query (f"cs == '{ cs } '" ).iloc [0 , :].values .tolist ()
608+ )
658609 ax = fig_params .ax if fig_params .axs is None else fig_params .axs [i ]
659610
611+ wants_images = False
612+ wants_labels = False
613+ wants_points = False
614+ wants_shapes = False
615+ wanted_elements = []
616+
660617 for cmd , params in render_cmds .items ():
661- if cmd == "render_images" and cs_contents . query ( f"cs == ' { cs } '" )[ " has_images" ][ 0 ] :
618+ if cmd == "render_images" and has_images :
662619 _render_images (
663620 sdata = sdata ,
664621 render_params = params ,
@@ -667,9 +624,18 @@ def show(
667624 fig_params = fig_params ,
668625 scalebar_params = scalebar_params ,
669626 legend_params = legend_params ,
670- # extent=extent[cs],
671627 )
672- elif cmd == "render_shapes" and cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ]:
628+ wants_images = True
629+ wanted_images = params .elements if params .elements is not None else list (sdata .images .keys ())
630+ wanted_elements .extend (
631+ [
632+ image
633+ for image in wanted_images
634+ if cs in set (get_transformation (sdata .images [image ], get_all = True ).keys ())
635+ ]
636+ )
637+
638+ elif cmd == "render_shapes" and has_shapes :
673639 _render_shapes (
674640 sdata = sdata ,
675641 render_params = params ,
@@ -679,8 +645,17 @@ def show(
679645 scalebar_params = scalebar_params ,
680646 legend_params = legend_params ,
681647 )
648+ wants_shapes = True
649+ wanted_shapes = params .elements if params .elements is not None else list (sdata .shapes .keys ())
650+ wanted_elements .extend (
651+ [
652+ shape
653+ for shape in wanted_shapes
654+ if cs in set (get_transformation (sdata .shapes [shape ], get_all = True ).keys ())
655+ ]
656+ )
682657
683- elif cmd == "render_points" and cs_contents . query ( f"cs == ' { cs } '" )[ " has_points" ][ 0 ] :
658+ elif cmd == "render_points" and has_points :
684659 _render_points (
685660 sdata = sdata ,
686661 render_params = params ,
@@ -690,8 +665,17 @@ def show(
690665 scalebar_params = scalebar_params ,
691666 legend_params = legend_params ,
692667 )
668+ wants_points = True
669+ wanted_points = params .elements if params .elements is not None else list (sdata .points .keys ())
670+ wanted_elements .extend (
671+ [
672+ point
673+ for point in wanted_points
674+ if cs in set (get_transformation (sdata .points [point ], get_all = True ).keys ())
675+ ]
676+ )
693677
694- elif cmd == "render_labels" and cs_contents . query ( f"cs == ' { cs } '" )[ " has_labels" ][ 0 ] :
678+ elif cmd == "render_labels" and has_labels :
695679 if sdata .table is not None and isinstance (params .color , str ):
696680 colors = sc .get .obs_df (sdata .table , params .color )
697681 if is_categorical_dtype (colors ):
@@ -710,33 +694,46 @@ def show(
710694 scalebar_params = scalebar_params ,
711695 legend_params = legend_params ,
712696 )
697+ wants_labels = True
698+ wanted_labels = params .elements if params .elements is not None else list (sdata .labels .keys ())
699+ wanted_elements .extend (
700+ [
701+ label
702+ for label in wanted_labels
703+ if cs in set (get_transformation (sdata .labels [label ], get_all = True ).keys ())
704+ ]
705+ )
713706
714- if title is not None :
715- if len (title ) == 1 :
716- t = title [0 ]
717- else :
718- try :
719- t = title [i ]
720- except IndexError as e :
721- raise IndexError ("The number of titles must match the number of coordinate systems." ) from e
722- else :
707+ if title is None :
723708 t = cs
709+ elif len (title ) == 1 :
710+ t = title [0 ]
711+ else :
712+ try :
713+ t = title [i ]
714+ except IndexError as e :
715+ raise IndexError ("The number of titles must match the number of coordinate systems." ) from e
724716 ax .set_title (t )
725717 ax .set_aspect ("equal" )
726718
727- if any (
728- [
729- cs_contents .query (f"cs == '{ cs } '" )["has_images" ][0 ],
730- cs_contents .query (f"cs == '{ cs } '" )["has_labels" ][0 ],
731- cs_contents .query (f"cs == '{ cs } '" )["has_points" ][0 ],
732- cs_contents .query (f"cs == '{ cs } '" )["has_shapes" ][0 ],
733- ]
734- ):
719+ extent = get_extent (
720+ sdata ,
721+ coordinate_system = cs ,
722+ has_images = has_images and wants_images ,
723+ has_labels = has_labels and wants_labels ,
724+ has_points = has_points and wants_points ,
725+ has_shapes = has_shapes and wants_shapes ,
726+ elements = wanted_elements ,
727+ )
728+ cs_x_min , cs_x_max = extent ["x" ]
729+ cs_y_min , cs_y_max = extent ["y" ]
730+
731+ if any ([has_images , has_labels , has_points , has_shapes ]):
735732 # If the axis already has limits, only expand them but not overwrite
736- x_min = min (x_min_orig , extent [ cs ][ 0 ] ) - pad_extent
737- x_max = max (x_max_orig , extent [ cs ][ 1 ] ) + pad_extent
738- y_min = min (y_min_orig , extent [ cs ][ 2 ] ) - pad_extent
739- y_max = max (y_max_orig , extent [ cs ][ 3 ] ) + pad_extent
733+ x_min = min (ax_x_min , cs_x_min ) - pad_extent
734+ x_max = max (ax_x_max , cs_x_max ) + pad_extent
735+ y_min = min (ax_y_min , cs_y_min ) - pad_extent
736+ y_max = max (ax_y_max , cs_y_max ) + pad_extent
740737 ax .set_xlim (x_min , x_max )
741738 ax .set_ylim (y_max , y_min ) # (0, 0) is top-left
742739
@@ -747,5 +744,4 @@ def show(
747744 # https://stackoverflow.com/a/64523765
748745 if not hasattr (sys , "ps1" ):
749746 plt .show ()
750-
751747 return (fig_params .ax if fig_params .axs is None else fig_params .axs ) if return_ax else None # shuts up ruff
0 commit comments