Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit fa41722

Browse files
committed
Port MTLS functionality to AsyncClient
1 parent 31f4559 commit fa41722

2 files changed

Lines changed: 81 additions & 6 deletions

File tree

gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
{% block content %}
44
from collections import OrderedDict
5+
import re
56
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
67
import pkg_resources
78

@@ -57,7 +58,40 @@ class {{ service.async_client_name }}Meta(type):
5758
class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}Meta):
5859
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""
5960

60-
DEFAULT_OPTIONS = ClientOptions.ClientOptions({% if service.host %}api_endpoint='{{ service.host }}'{% endif %})
61+
@staticmethod
62+
def _get_default_mtls_endpoint(api_endpoint):
63+
"""Convert api endpoint to mTLS endpoint.
64+
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
65+
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
66+
Args:
67+
api_endpoint (Optional[str]): the api endpoint to convert.
68+
Returns:
69+
str: converted mTLS api endpoint.
70+
"""
71+
if not api_endpoint:
72+
return api_endpoint
73+
74+
mtls_endpoint_re = re.compile(
75+
r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
76+
)
77+
78+
m = mtls_endpoint_re.match(api_endpoint)
79+
name, mtls, sandbox, googledomain = m.groups()
80+
if mtls or not googledomain:
81+
return api_endpoint
82+
83+
if sandbox:
84+
return api_endpoint.replace(
85+
"sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
86+
)
87+
88+
return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
89+
90+
DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
91+
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
92+
DEFAULT_ENDPOINT
93+
)
94+
DEFAULT_MTLS_TRANSPORT = {{ service.grpc_asyncio_transport_name }}
6195

6296
@classmethod
6397
def from_service_account_file(cls, filename: str, *args, **kwargs):
@@ -92,9 +126,9 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
92126
def __init__(self, *,
93127
credentials: credentials.Credentials = None,
94128
transport: Union[str, {{ service.name }}Transport] = None,
95-
client_options: ClientOptions = DEFAULT_OPTIONS,
129+
client_options: ClientOptions = None,
96130
) -> None:
97-
"""Instantiate the {{ (service.async_client_name|snake_case).replace('_', ' ') }}.
131+
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.
98132

99133
Args:
100134
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -106,6 +140,17 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
106140
transport to use. If set to None, a transport is chosen
107141
automatically.
108142
client_options (ClientOptions): Custom options for the client.
143+
(1) The ``api_endpoint`` property can be used to override the
144+
default endpoint provided by the client.
145+
(2) If ``transport`` argument is None, ``client_options`` can be
146+
used to create a mutual TLS transport. If ``client_cert_source``
147+
is provided, mutual TLS transport will be created with the given
148+
``api_endpoint`` or the default mTLS endpoint, and the client
149+
SSL credentials obtained from ``client_cert_source``.
150+
151+
Raises:
152+
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
153+
creation failed for any reason.
109154
"""
110155
if isinstance(client_options, dict):
111156
client_options = ClientOptions.from_dict(client_options)
@@ -114,15 +159,44 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
114159
# Ordinarily, we provide the transport, but allowing a custom transport
115160
# instance provides an extensibility point for unusual situations.
116161
if isinstance(transport, {{ service.name }}Transport):
162+
# transport is a {{ service.name }}Transport instance.
117163
if credentials:
118164
raise ValueError('When providing a transport instance, '
119165
'provide its credentials directly.')
120166
self._transport = transport
121-
else:
167+
elif client_options is None or (
168+
client_options.api_endpoint == None
169+
and client_options.client_cert_source is None
170+
):
171+
# Don't trigger mTLS if we get an empty ClientOptions.
122172
Transport = type(self).get_transport_class(transport)
123173
self._transport = Transport(
174+
credentials=credentials, host=self.DEFAULT_ENDPOINT
175+
)
176+
else:
177+
# We have a non-empty ClientOptions. If client_cert_source is
178+
# provided, trigger mTLS with user provided endpoint or the default
179+
# mTLS endpoint.
180+
if client_options.client_cert_source:
181+
api_mtls_endpoint = (
182+
client_options.api_endpoint
183+
if client_options.api_endpoint
184+
else self.DEFAULT_MTLS_ENDPOINT
185+
)
186+
else:
187+
api_mtls_endpoint = None
188+
189+
api_endpoint = (
190+
client_options.api_endpoint
191+
if client_options.api_endpoint
192+
else self.DEFAULT_ENDPOINT
193+
)
194+
195+
self._transport = self.DEFAULT_MTLS_TRANSPORT(
124196
credentials=credentials,
125-
host=client_options.api_endpoint{% if service.host %} or '{{ service.host }}'{% endif %},
197+
host=api_endpoint,
198+
api_mtls_endpoint=api_mtls_endpoint,
199+
client_cert_source=client_options.client_cert_source,
126200
)
127201

128202
{% for method in service.methods.values() -%}

gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
9191
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
9292
DEFAULT_ENDPOINT
9393
)
94+
DEFAULT_MTLS_TRANSPORT = {{ service.grpc_transport_name }}
9495

9596
@classmethod
9697
def from_service_account_file(cls, filename: str, *args, **kwargs):
@@ -191,7 +192,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
191192
else self.DEFAULT_ENDPOINT
192193
)
193194

194-
self._transport = {{ service.name }}GrpcTransport(
195+
self._transport = self.DEFAULT_MTLS_TRANSPORT(
195196
credentials=credentials,
196197
host=api_endpoint,
197198
api_mtls_endpoint=api_mtls_endpoint,

0 commit comments

Comments
 (0)