Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit 3cc730c

Browse files
committed
Use composition to implement AsyncClient
1 parent d4ec67d commit 3cc730c

3 files changed

Lines changed: 40 additions & 150 deletions

File tree

gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 21 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
{% block content %}
44
from collections import OrderedDict
5+
import functools
56
import re
67
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
78
import pkg_resources
@@ -21,111 +22,31 @@ from google.oauth2 import service_account # type: ignore
2122
{% endfor -%}
2223
{% endfilter %}
2324
from .transports.base import {{ service.name }}Transport
24-
from .transports.grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport
25+
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
26+
from .client import {{ service.client_name }}
2527

2628

27-
class {{ service.async_client_name }}Meta(type):
28-
"""Metaclass for the {{ service.name }} client.
29-
30-
This provides class-level methods for building and retrieving
31-
support objects (e.g. transport) without polluting the client instance
32-
objects.
33-
"""
34-
_transport_registry: Dict[str, Type[{{ service.name }}Transport]] = OrderedDict()
35-
_transport_registry['grpc_asyncio'] = {{ service.name }}GrpcAsyncIOTransport
36-
37-
def get_transport_class(cls,
38-
label: str = None,
39-
) -> Type[{{ service.name }}Transport]:
40-
"""Return an appropriate transport class.
41-
42-
Args:
43-
label: The name of the desired transport. If none is
44-
provided, then the first transport in the registry is used.
45-
46-
Returns:
47-
The transport class to use.
48-
"""
49-
# If a specific transport is requested, return that one.
50-
if label:
51-
return cls._transport_registry[label]
52-
53-
# No transport is requested; return the default (that is, the first one
54-
# in the dictionary).
55-
return next(iter(cls._transport_registry.values()))
56-
57-
58-
class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}Meta):
29+
class {{ service.async_client_name }}:
5930
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""
6031

61-
@staticmethod
62-
def _get_default_mtls_endpoint(api_endpoint):
63-
"""Convert api endpoint to mTLS endpoint.
64-
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
65-
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
66-
Args:
67-
api_endpoint (Optional[str]): the api endpoint to convert.
68-
Returns:
69-
str: converted mTLS api endpoint.
70-
"""
71-
if not api_endpoint:
72-
return api_endpoint
73-
74-
mtls_endpoint_re = re.compile(
75-
r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
76-
)
77-
78-
m = mtls_endpoint_re.match(api_endpoint)
79-
name, mtls, sandbox, googledomain = m.groups()
80-
if mtls or not googledomain:
81-
return api_endpoint
82-
83-
if sandbox:
84-
return api_endpoint.replace(
85-
"sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
86-
)
32+
_client: {{ service.client_name }}
8733

88-
return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
34+
DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT
35+
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT
8936

90-
DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
91-
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
92-
DEFAULT_ENDPOINT
93-
)
94-
DEFAULT_MTLS_TRANSPORT = {{ service.grpc_asyncio_transport_name }}
95-
96-
@classmethod
97-
def from_service_account_file(cls, filename: str, *args, **kwargs):
98-
"""Creates an instance of this client using the provided credentials
99-
file.
100-
101-
Args:
102-
filename (str): The path to the service account private key json
103-
file.
104-
args: Additional arguments to pass to the constructor.
105-
kwargs: Additional arguments to pass to the constructor.
37+
{% for message in service.resource_messages -%}
38+
{{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path)
10639

107-
Returns:
108-
{@api.name}: The constructed client.
109-
"""
110-
credentials = service_account.Credentials.from_service_account_file(
111-
filename)
112-
kwargs['credentials'] = credentials
113-
return cls(*args, **kwargs)
40+
{% endfor %}
11441

42+
from_service_account_file = {{ service.client_name }}.from_service_account_file
11543
from_service_account_json = from_service_account_file
11644

117-
118-
{% for message in service.resource_messages -%}
119-
@staticmethod
120-
def {{ message.resource_type|snake_case }}_path({% for arg in message.resource_path_args %}{{ arg }}: str,{% endfor %}) -> str:
121-
"""Return a fully-qualified {{ message.resource_type|snake_case }} string."""
122-
return "{{ message.resource_path }}".format({% for arg in message.resource_path_args %}{{ arg }}={{ arg }}, {% endfor %})
123-
124-
{% endfor %}
45+
get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }}))
12546

12647
def __init__(self, *,
12748
credentials: credentials.Credentials = None,
128-
transport: Union[str, {{ service.name }}Transport] = None,
49+
transport: Union[str, {{ service.name }}Transport] = "grpc_asyncio",
12950
client_options: ClientOptions = None,
13051
) -> None:
13152
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.
@@ -152,52 +73,12 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
15273
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
15374
creation failed for any reason.
15475
"""
155-
if isinstance(client_options, dict):
156-
client_options = ClientOptions.from_dict(client_options)
157-
158-
# Save or instantiate the transport.
159-
# Ordinarily, we provide the transport, but allowing a custom transport
160-
# instance provides an extensibility point for unusual situations.
161-
if isinstance(transport, {{ service.name }}Transport):
162-
# transport is a {{ service.name }}Transport instance.
163-
if credentials:
164-
raise ValueError('When providing a transport instance, '
165-
'provide its credentials directly.')
166-
self._transport = transport
167-
elif client_options is None or (
168-
client_options.api_endpoint == None
169-
and client_options.client_cert_source is None
170-
):
171-
# Don't trigger mTLS if we get an empty ClientOptions.
172-
Transport = type(self).get_transport_class(transport)
173-
self._transport = Transport(
174-
credentials=credentials, host=self.DEFAULT_ENDPOINT
175-
)
176-
else:
177-
# We have a non-empty ClientOptions. If client_cert_source is
178-
# provided, trigger mTLS with user provided endpoint or the default
179-
# mTLS endpoint.
180-
if client_options.client_cert_source:
181-
api_mtls_endpoint = (
182-
client_options.api_endpoint
183-
if client_options.api_endpoint
184-
else self.DEFAULT_MTLS_ENDPOINT
185-
)
186-
else:
187-
api_mtls_endpoint = None
188-
189-
api_endpoint = (
190-
client_options.api_endpoint
191-
if client_options.api_endpoint
192-
else self.DEFAULT_ENDPOINT
193-
)
194-
195-
self._transport = self.DEFAULT_MTLS_TRANSPORT(
196-
credentials=credentials,
197-
host=api_endpoint,
198-
api_mtls_endpoint=api_mtls_endpoint,
199-
client_cert_source=client_options.client_cert_source,
200-
)
76+
{# NOTE(lidiz) Not using kwargs since we want the docstring and types. #}
77+
self._client = {{ service.client_name }}(
78+
credentials=credentials,
79+
transport=transport,
80+
client_options=client_options,
81+
)
20182

20283
{% for method in service.methods.values() -%}
20384
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self,
@@ -296,7 +177,7 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
296177
# Wrap the RPC method; this adds retry and timeout information,
297178
# and friendly error handling.
298179
rpc = gapic_v1.method_async.wrap_method(
299-
self._transport.{{ method.name|snake_case }},
180+
self._client._transport.{{ method.name|snake_case }},
300181
{%- if method.retry %}
301182
default_retry=retries.Retry(
302183
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
@@ -344,7 +225,7 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
344225
# Wrap the response in an operation future.
345226
response = operation_async.from_gapic(
346227
response,
347-
self._transport.operations_client,
228+
self._client._transport.operations_client,
348229
{{ method.lro.response_type.ident }},
349230
metadata_type={{ method.lro.metadata_type.ident }},
350231
)

gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ from google.oauth2 import service_account # type: ignore
2121
{% endfor -%}
2222
{% endfilter %}
2323
from .transports.base import {{ service.name }}Transport
24-
from .transports.grpc import {{ service.name }}GrpcTransport
24+
from .transports.grpc import {{ service.grpc_transport_name }}
25+
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
2526

2627

2728
class {{ service.client_name }}Meta(type):
@@ -32,16 +33,22 @@ class {{ service.client_name }}Meta(type):
3233
objects.
3334
"""
3435
_transport_registry = OrderedDict() # type: Dict[str, Type[{{ service.name }}Transport]]
35-
_transport_registry['grpc'] = {{ service.name }}GrpcTransport
36+
_transport_registry['grpc'] = {{ service.grpc_transport_name }}
37+
_transport_registry['grpc_asyncio'] = {{ service.grpc_asyncio_transport_name }}
38+
39+
_mtls_support = {'grpc', 'grpc_asyncio'}
3640

3741
def get_transport_class(cls,
3842
label: str = None,
43+
enable_mtls: bool = False,
3944
) -> Type[{{ service.name }}Transport]:
4045
"""Return an appropriate transport class.
4146

4247
Args:
4348
label: The name of the desired transport. If none is
4449
provided, then the first transport in the registry is used.
50+
enable_mtls: A bool indicates whether the transport needs to
51+
support MTLS or not.
4552

4653
Returns:
4754
The transport class to use.
@@ -91,7 +98,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
9198
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
9299
DEFAULT_ENDPOINT
93100
)
94-
DEFAULT_MTLS_TRANSPORT = {{ service.grpc_transport_name }}
95101

96102
@classmethod
97103
def from_service_account_file(cls, filename: str, *args, **kwargs):
@@ -192,7 +198,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
192198
else self.DEFAULT_ENDPOINT
193199
)
194200

195-
self._transport = self.DEFAULT_MTLS_TRANSPORT(
201+
self._transport = type(self).get_transport_class(
202+
label=transport,
203+
enable_mtls=True,
204+
)(
196205
credentials=credentials,
197206
host=api_endpoint,
198207
api_mtls_endpoint=api_mtls_endpoint,

gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(client_c
7171
])
7272
def test_{{ service.client_name|snake_case }}_client_options(client_class, transport_class):
7373
# Check that if channel is provided we won't create a new one.
74-
with mock.patch.object(client_class, 'get_transport_class') as gtc:
74+
with mock.patch.object({{ service.client_name }}, 'get_transport_class') as gtc:
7575
transport = transport_class(
7676
credentials=credentials.AnonymousCredentials()
7777
)
@@ -80,10 +80,10 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
8080

8181
# Check mTLS is not triggered with empty client options.
8282
options = client_options.ClientOptions()
83-
with mock.patch.object(client_class, 'get_transport_class') as gtc:
84-
transport = gtc.return_value = mock.MagicMock()
83+
with mock.patch.object(transport_class, '__init__') as patched:
84+
patched.return_value = None
8585
client = client_class(client_options=options)
86-
transport.assert_called_once_with(
86+
patched.assert_called_once_with(
8787
credentials=None,
8888
host=client.DEFAULT_ENDPOINT,
8989
)
@@ -427,7 +427,7 @@ async def test_{{ method.name|snake_case }}_async_pager():
427427

428428
# Mock the actual call within the gRPC stub, and fake the request.
429429
with mock.patch.object(
430-
type(client._transport.{{ method.name|snake_case }}),
430+
type(client._client._transport.{{ method.name|snake_case }}),
431431
'__call__', new_callable=mock.AsyncMock) as call:
432432
# Set the response to a series of pages.
433433
call.side_effect = (
@@ -475,7 +475,7 @@ async def test_{{ method.name|snake_case }}_async_pages():
475475

476476
# Mock the actual call within the gRPC stub, and fake the request.
477477
with mock.patch.object(
478-
type(client._transport.{{ method.name|snake_case }}),
478+
type(client._client._transport.{{ method.name|snake_case }}),
479479
'__call__', new_callable=mock.AsyncMock) as call:
480480
# Set the response to a series of pages.
481481
call.side_effect = (

0 commit comments

Comments
 (0)