Skip to content

Commit 58def1c

Browse files
authored
Merge branch 'main' into bugfix/issue172-add-tests-for-extent-of-plots
2 parents 84409ef + b82e7f5 commit 58def1c

17 files changed

Lines changed: 97 additions & 19 deletions

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ and this project adheres to [Semantic Versioning][].
1616
- Large images are automatically rasterized to speed up performance (#164)
1717
- Added better error message for mismatch in cs and ax number (#185)
1818
- Beter test coverage for correct plotting of elements after transformation (#198)
19+
- Can now stack render commands (#190, #192)
1920

2021
### Fixed
2122

2223
- Now dropping index when plotting shapes after spatial query (#177)
2324
- Points are now being correctly rotated (#198)
25+
- User can now pass Colormap objects to the cmap argument in render_images. When only one cmap is given for 3 channels, it is now applied to each channel (#188, #194)
2426

2527
## [0.0.6] - 2023-11-06
2628

src/spatialdata_plot/pl/basic.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,6 @@ def render_images(
362362
sdata = _verify_plotting_tree(sdata)
363363
n_steps = len(sdata.plotting_tree.keys())
364364

365-
if channel is None and cmap is None:
366-
cmap = "brg"
367-
368365
cmap_params: list[CmapParams] | CmapParams
369366
if isinstance(cmap, list):
370367
cmap_params = [
@@ -567,7 +564,7 @@ def show(
567564
]
568565

569566
# prepare rendering params
570-
render_cmds = OrderedDict()
567+
render_cmds = []
571568
for cmd, params in plotting_tree.items():
572569
# strip prefix from cmd and verify it's valid
573570
cmd = "_".join(cmd.split("_")[1:])
@@ -577,9 +574,9 @@ def show(
577574

578575
if "render" in cmd:
579576
# verify that rendering commands have been called before
580-
render_cmds[cmd] = params
577+
render_cmds.append((cmd, params))
581578

582-
if len(render_cmds.keys()) == 0:
579+
if len(render_cmds) == 0:
583580
raise TypeError("Please specify what to plot using the 'render_*' functions before calling 'imshow()'.")
584581

585582
if title is not None:
@@ -609,7 +606,7 @@ def show(
609606
# Check if user specified only certain elements to be plotted
610607
cs_contents = _get_cs_contents(sdata)
611608
elements_to_be_rendered = []
612-
for cmd, params in render_cmds.items():
609+
for cmd, params in render_cmds:
613610
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
614611
if params.elements is not None:
615612
elements_to_be_rendered += (
@@ -632,13 +629,14 @@ def show(
632629
)
633630

634631
# filter out cs without relevant elements
632+
cmds = [cmd for cmd, _ in render_cmds]
635633
coordinate_systems = _get_valid_cs(
636634
sdata=sdata,
637635
coordinate_systems=coordinate_systems,
638-
render_images="render_images" in render_cmds,
639-
render_labels="render_labels" in render_cmds,
640-
render_points="render_points" in render_cmds,
641-
render_shapes="render_shapes" in render_cmds,
636+
render_images="render_images" in cmds,
637+
render_labels="render_labels" in cmds,
638+
render_points="render_points" in cmds,
639+
render_shapes="render_shapes" in cmds,
642640
elements=elements_to_be_rendered,
643641
)
644642

@@ -689,7 +687,7 @@ def show(
689687
wants_shapes = False
690688
wanted_elements = []
691689

692-
for cmd, params in render_cmds.items():
690+
for cmd, params in render_cmds:
693691
if cmd == "render_images" and has_images:
694692
wants_images = True
695693
wanted_images = params.elements if params.elements is not None else list(sdata.images.keys())

src/spatialdata_plot/pl/render.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,10 +403,13 @@ def _render_images(
403403
else:
404404
cmap = _get_linear_colormap([render_params.palette], "k")[0]
405405

406+
# Overwrite alpha in cmap: https://stackoverflow.com/a/10127675
407+
cmap._init()
408+
cmap._lut[:, -1] = render_params.alpha
409+
406410
im = ax.imshow(
407-
layer, # get rid of the channel dimension
411+
layer,
408412
cmap=cmap,
409-
alpha=render_params.alpha,
410413
)
411414
im.set_transform(trans_data)
412415

@@ -431,10 +434,29 @@ def _render_images(
431434
if render_params.cmap_params[i].norm is not None:
432435
layers[c] = render_params.cmap_params[i].norm(layers[c])
433436

434-
# 2A) Image has 3 channels, no palette/cmap info -> use RGB
435-
if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps:
437+
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
438+
if n_channels == 3 and render_params.palette is None and not isinstance(render_params.cmap_params, list):
439+
if render_params.cmap_params.is_default: # -> use RGB
440+
stacked = np.stack([layers[c] for c in channels], axis=-1)
441+
else: # -> use given cmap for each channel
442+
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
443+
# Apply cmaps to each channel, add up and normalize to [0, 1]
444+
stacked = (
445+
np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) / n_channels
446+
)
447+
# Remove alpha channel so we can overwrite it from render_params.alpha
448+
stacked = stacked[:, :, :3]
449+
logger.warning(
450+
"One cmap was given for multiple channels and is now used for each channel. "
451+
"You're blending multiple cmaps. "
452+
"If the plot doesn't look like you expect, it might be because your "
453+
"cmaps go from a given color to 'white', and not to 'transparent'. "
454+
"Therefore, the 'white' of higher layers will overlay the lower layers. "
455+
"Consider using 'palette' instead."
456+
)
457+
436458
im = ax.imshow(
437-
np.stack([layers[c] for c in channels], axis=-1),
459+
stacked,
438460
alpha=render_params.alpha,
439461
)
440462
im.set_transform(trans_data)
@@ -511,6 +533,12 @@ def _render_labels(
511533
) -> None:
512534
elements = render_params.elements
513535

536+
if not isinstance(render_params.outline, bool):
537+
raise TypeError("Parameter 'outline' must be a boolean.")
538+
539+
if not isinstance(render_params.contour_px, int):
540+
raise TypeError("Parameter 'contour_px' must be an integer.")
541+
514542
if render_params.groups is not None:
515543
if isinstance(render_params.groups, str):
516544
render_params.groups = [render_params.groups]

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,13 @@ def _prepare_cmap_norm(
344344
**kwargs: Any,
345345
) -> CmapParams:
346346
is_default = cmap is None
347-
cmap = copy(matplotlib.colormaps[rcParams["image.cmap"] if cmap is None else cmap])
347+
if cmap is None:
348+
cmap = rcParams["image.cmap"]
349+
if isinstance(cmap, str):
350+
cmap = matplotlib.colormaps[cmap]
351+
352+
cmap = copy(cmap)
353+
348354
cmap.set_bad("lightgray" if na_color is None else na_color)
349355

350356
if isinstance(norm, Normalize) or not norm:
32 KB
Loading
31.8 KB
Loading
-37.6 KB
Binary file not shown.
32 KB
Loading
31.8 KB
Loading
18.5 KB
Loading

0 commit comments

Comments
 (0)