66from functools import partial
77from typing import Any , Callable , Union
88
9+ import geopandas as gpd
910import matplotlib
1011import numpy as np
1112import pandas as pd
3536 _decorate_axs ,
3637 _get_colors_for_categorical_obs ,
3738 _get_linear_colormap ,
39+ _make_patch_from_multipolygon ,
3840 _map_color_seg ,
3941 _maybe_set_colors ,
4042 _normalize ,
@@ -119,18 +121,22 @@ def _get_collection_shape(
119121 outline_alpha : None | float = None ,
120122 ** kwargs : Any ,
121123 ) -> PatchCollection :
122- patches = []
123- for shape in shapes :
124- # remove empty points/polygons
125- shape = shape [shape ["geometry" ].apply (lambda geom : not geom .is_empty )]
126- # We assume that all elements in one collection are of the same type
127- if shape ["geometry" ].iloc [0 ].geom_type == "Polygon" :
128- patches += [Polygon (p .exterior .coords , closed = True ) for p in shape ["geometry" ]]
129- elif shape ["geometry" ].iloc [0 ].geom_type == "Point" :
130- patches += [
131- Circle ((circ .x , circ .y ), radius = r * s ) for circ , r in zip (shape ["geometry" ], shape ["radius" ])
132- ]
133-
124+ """
125+ Get a PatchCollection for rendering given geometries with specified colors and outlines.
126+
127+ Args:
128+ - shapes (list[GeoDataFrame]): List of geometrical shapes.
129+ - c: Color parameter.
130+ - s (float): Size of the shape.
131+ - norm: Normalization for the color map.
132+ - fill_alpha (float, optional): Opacity for the fill color.
133+ - outline_alpha (float, optional): Opacity for the outline.
134+ - **kwargs: Additional keyword arguments.
135+
136+ Returns
137+ -------
138+ - PatchCollection: Collection of patches for rendering.
139+ """
134140 cmap = kwargs ["cmap" ]
135141
136142 try :
@@ -149,16 +155,60 @@ def _get_collection_shape(
149155 if render_params .outline_params .outline :
150156 outline_c = ColorConverter ().to_rgba_array (render_params .outline_params .outline_color )
151157 outline_c [..., - 1 ] = render_params .outline_alpha
158+ outline_c = outline_c .tolist ()
152159 else :
153- outline_c = None
160+ outline_c = [None ]
161+ outline_c = outline_c * fill_c .shape [0 ]
162+
163+ shapes_df = pd .DataFrame (shapes , copy = True )
164+
165+ # remove empty points/polygons
166+ shapes_df = shapes_df [shapes_df ["geometry" ].apply (lambda geom : not geom .is_empty )]
167+
168+ rows = []
169+
170+ def assign_fill_and_outline_to_row (
171+ shapes : list [GeoDataFrame ], fill_c : list [Any ], outline_c : list [Any ], row : pd .Series , idx : int
172+ ) -> None :
173+ if len (shapes ) > 1 and len (fill_c ) == 1 :
174+ row ["fill_c" ] = fill_c
175+ row ["outline_c" ] = outline_c
176+ else :
177+ row ["fill_c" ] = fill_c [idx ]
178+ row ["outline_c" ] = outline_c [idx ]
179+
180+ # Match colors to the geometry, potentially expanding the row in case of
181+ # multipolygons
182+ for idx , row in shapes_df .iterrows ():
183+ geom = row ["geometry" ]
184+ if geom .geom_type == "Polygon" :
185+ row = row .to_dict ()
186+ row ["geometry" ] = Polygon (geom .exterior .coords , closed = True )
187+ assign_fill_and_outline_to_row (shapes , fill_c , outline_c , row , idx )
188+ rows .append (row )
189+
190+ elif geom .geom_type == "MultiPolygon" :
191+ mp = _make_patch_from_multipolygon (geom )
192+ for _ , m in enumerate (mp ):
193+ mp_copy = row .to_dict ()
194+ mp_copy ["geometry" ] = m
195+ assign_fill_and_outline_to_row (shapes , fill_c , outline_c , mp_copy , idx )
196+ rows .append (mp_copy )
197+
198+ elif geom .geom_type == "Point" :
199+ row = row .to_dict ()
200+ row ["geometry" ] = Circle ((geom .x , geom .y ), radius = row ["radius" ])
201+ assign_fill_and_outline_to_row (shapes , fill_c , outline_c , row , idx )
202+ rows .append (row )
203+
204+ patches = pd .DataFrame (rows )
154205
155206 return PatchCollection (
156- patches ,
207+ patches [ "geometry" ]. values . tolist () ,
157208 snap = False ,
158- # zorder=4,
159209 lw = render_params .outline_params .linewidth ,
160- facecolor = fill_c ,
161- edgecolor = outline_c ,
210+ facecolor = patches [ " fill_c" ] ,
211+ edgecolor = None if all ( outline is None for outline in outline_c ) else outline_c ,
162212 ** kwargs ,
163213 )
164214
@@ -167,6 +217,8 @@ def _get_collection_shape(
167217 if len (color_vector ) == 0 :
168218 color_vector = [render_params .cmap_params .na_color ]
169219
220+ shapes = pd .concat (shapes , ignore_index = True )
221+ shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
170222 _cax = _get_collection_shape (
171223 shapes = shapes ,
172224 s = render_params .size ,
@@ -178,6 +230,7 @@ def _get_collection_shape(
178230 outline_alpha = render_params .outline_alpha
179231 # **kwargs,
180232 )
233+
181234 cax = ax .add_collection (_cax )
182235
183236 # Using dict.fromkeys here since set returns in arbitrary order
0 commit comments