33{% block content %}
44from typing import Awaitable, Callable, Dict, Sequence, Tuple
55
6- from google.api_core import grpc_helpers_async # type: ignore
6+ from google.api_core import grpc_helpers_async # type: ignore
77{% - if service .has_lro %}
8- from google.api_core import operations_v1 # type: ignore
8+ from google.api_core import operations_v1 # type: ignore
99{% - endif %}
10- from google.auth import credentials # type: ignore
10+ from google.auth import credentials # type: ignore
11+ from google.auth.transport.grpc import SslCredentials # type: ignore
1112
13+ import grpc # type: ignore
1214from grpc.experimental import aio # type: ignore
1315
1416{% filter sort_lines -%}
@@ -17,10 +19,11 @@ from grpc.experimental import aio # type: ignore
1719{{ method.output.ident.python_import }}
1820{% endfor -%}
1921{% endfilter %}
20- from .grpc_base import {{ service.name }}GrpcBaseTransport
22+ from .base import {{ service.name }}Transport
23+ from .grpc import {{ service.name }}GrpcTransport
2124
2225
23- class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransport[aio.Channel] ):
26+ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport ):
2427 """gRPC AsyncIO backend transport for {{ service.name }}.
2528
2629 {{ service.meta.doc|rst(width=72, indent=4) }}
@@ -33,6 +36,9 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransp
3336 top of HTTP/2); the ``grpcio`` package must be installed.
3437 """
3538
39+ _grpc_channel: aio.Channel
40+ _stubs: Dict[str, Callable] = {}
41+
3642 @classmethod
3743 def create_channel(cls,
3844 host: str{% if service .host %} = '{{ service.host }}'{% endif %} ,
@@ -64,6 +70,87 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransp
6470 **kwargs
6571 )
6672
73+ def __init__(self, *,
74+ host: str{% if service .host %} = '{{ service.host }}'{% endif %} ,
75+ credentials: credentials.Credentials = None,
76+ channel: aio.Channel = None,
77+ api_mtls_endpoint: str = None,
78+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
79+ """Instantiate the transport.
80+
81+ Args:
82+ host ({% if service .host %} Optional[str]{% else %} str{% endif %} ):
83+ {{- ' ' }}The hostname to connect to.
84+ credentials (Optional[google.auth.credentials.Credentials]): The
85+ authorization credentials to attach to requests. These
86+ credentials identify the application to the service; if none
87+ are specified, the client will attempt to ascertain the
88+ credentials from the environment.
89+ This argument is ignored if ``channel`` is provided.
90+ channel (Optional[aio.Channel]): A ``Channel`` instance through
91+ which to make calls.
92+ api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
93+ provided, it overrides the ``host`` argument and tries to create
94+ a mutual TLS channel with client SSL credentials from
95+ ``client_cert_source`` or applicatin default SSL credentials.
96+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
97+ callback to provide client SSL certificate bytes and private key
98+ bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
99+ is None.
100+
101+ Raises:
102+ google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
103+ creation failed for any reason.
104+ """
105+ if channel:
106+ # Sanity check: Ensure that channel and credentials are not both
107+ # provided.
108+ credentials = False
109+
110+ # If a channel was explicitly provided, set it.
111+ self._grpc_channel = channel
112+ elif api_mtls_endpoint:
113+ host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
114+
115+ # Create SSL credentials with client_cert_source or application
116+ # default SSL credentials.
117+ if client_cert_source:
118+ cert, key = client_cert_source()
119+ ssl_credentials = grpc.ssl_channel_credentials(
120+ certificate_chain=cert, private_key=key
121+ )
122+ else:
123+ ssl_credentials = SslCredentials().ssl_credentials
124+
125+ # create a new channel. The provided one is ignored.
126+ self._grpc_channel = type(self).create_channel(
127+ host,
128+ credentials=credentials,
129+ ssl_credentials=ssl_credentials,
130+ scopes=self.AUTH_SCOPES,
131+ )
132+
133+ # Run the base constructor.
134+ super().__init__(host=host, credentials=credentials)
135+ self._stubs = {}
136+
137+ @property
138+ def grpc_channel(self) -> aio.Channel:
139+ """Create the channel designed to connect to this service.
140+
141+ This property caches on the instance; repeated calls return
142+ the same channel.
143+ """
144+ # Sanity check: Only create a new channel if we do not already
145+ # have one.
146+ if not hasattr(self, '_grpc_channel'):
147+ self._grpc_channel = self.create_channel(
148+ self._host,
149+ credentials=self._credentials,
150+ )
151+
152+ # Return the channel from cache.
153+ return self._grpc_channel
67154 {% - if service .has_lro %}
68155
69156 @property
@@ -101,7 +188,17 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransp
101188 A function that, when called, will call the underlying RPC
102189 on the server.
103190 """
104- return super().{{ method.name|snake_case }}
191+ # Generate a "stub function" on-the-fly which will actually make
192+ # the request.
193+ # gRPC handles serialization and deserialization, so we just need
194+ # to pass in the functions for each.
195+ if '{{ method.name|snake_case }}' not in self._stubs:
196+ self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
197+ '/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
198+ request_serializer={{ method.input.ident }}.{% if method .input .ident .python_import .module .endswith ('_pb2' ) %} SerializeToString{% else %} serialize{% endif %} ,
199+ response_deserializer={{ method.output.ident }}.{% if method .output .ident .python_import .module .endswith ('_pb2' ) %} FromString{% else %} deserialize{% endif %} ,
200+ )
201+ return self._stubs['{{ method.name|snake_case }}']
105202 {% - endfor %}
106203
107204
0 commit comments