Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit bc92e7d

Browse files
feat: add GOOGLE_API_USE_MTLS support
1 parent 117d110 commit bc92e7d

6 files changed

Lines changed: 248 additions & 111 deletions

File tree

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
{% block content %}
44
from collections import OrderedDict
5+
import os
56
import re
67
from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
78
import pkg_resources
@@ -11,6 +12,8 @@ from google.api_core import exceptions # type: ignore
1112
from google.api_core import gapic_v1 # type: ignore
1213
from google.api_core import retry as retries # type: ignore
1314
from google.auth import credentials # type: ignore
15+
from google.auth.transport import mtls # type: ignore
16+
from google.auth.exceptions import MutualTLSChannelError # type: ignore
1417
from google.oauth2 import service_account # type: ignore
1518

1619
{% filter sort_lines -%}
@@ -154,11 +157,32 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
154157
SSL credentials obtained from ``client_cert_source``.
155158

156159
Raises:
157-
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
160+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
158161
creation failed for any reason.
159162
"""
160163
if isinstance(client_options, dict):
161164
client_options = ClientOptions.from_dict(client_options)
165+
if client_options is None:
166+
client_options = ClientOptions.ClientOptions()
167+
168+
if transport is None and client_options.api_endpoint is None:
169+
use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "Never")
170+
if use_mtls_env == "Never":
171+
client_options.api_endpoint = self.DEFAULT_ENDPOINT
172+
elif use_mtls_env == "Always":
173+
client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT
174+
elif use_mtls_env == "Auto":
175+
has_client_cert_source = (
176+
client_options.client_cert_source is not None
177+
or mtls.has_default_client_cert_source()
178+
)
179+
client_options.api_endpoint = (
180+
self.DEFAULT_MTLS_ENDPOINT if has_client_cert_source else self.DEFAULT_ENDPOINT
181+
)
182+
else:
183+
raise MutualTLSChannelError(
184+
"Unsupported GOOGLE_API_USE_MTLS value. Accepted values: Never, Auto, Always"
185+
)
162186

163187
# Save or instantiate the transport.
164188
# Ordinarily, we provide the transport, but allowing a custom transport
@@ -169,38 +193,16 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
169193
raise ValueError('When providing a transport instance, '
170194
'provide its credentials directly.')
171195
self._transport = transport
172-
elif client_options is None or (
173-
client_options.api_endpoint is None
174-
and client_options.client_cert_source is None
175-
):
176-
# Don't trigger mTLS if we get an empty ClientOptions.
196+
elif isinstance(transport, str):
177197
Transport = type(self).get_transport_class(transport)
178198
self._transport = Transport(
179199
credentials=credentials, host=self.DEFAULT_ENDPOINT
180200
)
181201
else:
182-
# We have a non-empty ClientOptions. If client_cert_source is
183-
# provided, trigger mTLS with user provided endpoint or the default
184-
# mTLS endpoint.
185-
if client_options.client_cert_source:
186-
api_mtls_endpoint = (
187-
client_options.api_endpoint
188-
if client_options.api_endpoint
189-
else self.DEFAULT_MTLS_ENDPOINT
190-
)
191-
else:
192-
api_mtls_endpoint = None
193-
194-
api_endpoint = (
195-
client_options.api_endpoint
196-
if client_options.api_endpoint
197-
else self.DEFAULT_ENDPOINT
198-
)
199-
200202
self._transport = {{ service.name }}GrpcTransport(
201203
credentials=credentials,
202-
host=api_endpoint,
203-
api_mtls_endpoint=api_mtls_endpoint,
204+
host=client_options.api_endpoint,
205+
api_mtls_endpoint=client_options.api_endpoint,
204206
client_cert_source=client_options.client_cert_source,
205207
)
206208

gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore
77
{%- if service.has_lro %}
88
from google.api_core import operations_v1 # type: ignore
99
{%- endif %}
10+
from google import auth # type: ignore
1011
from google.auth import credentials # type: ignore
1112
from google.auth.transport.grpc import SslCredentials # type: ignore
1213

@@ -63,7 +64,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
6364
is None.
6465

6566
Raises:
66-
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
67+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
6768
creation failed for any reason.
6869
"""
6970
if channel:
@@ -76,6 +77,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
7677
elif api_mtls_endpoint:
7778
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
7879

80+
if credentials is None:
81+
credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
82+
7983
# Create SSL credentials with client_cert_source or application
8084
# default SSL credentials.
8185
if client_cert_source:
@@ -96,7 +100,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
96100

97101
# Run the base constructor.
98102
super().__init__(host=host, credentials=credentials)
99-
self._stubs = {} # type: Dict[str, Callable]
103+
self._stubs = {} # type: Dict[str, Callable]
100104

101105

102106
@classmethod

gapic/ads-templates/tests/unit/%name_%version/%sub/test_%service.py.j2

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{% extends "_base.py.j2" %}
22

33
{% block content %}
4+
import os
45
from unittest import mock
56

67
import grpc
@@ -11,6 +12,7 @@ import pytest
1112
{% filter sort_lines -%}
1213
from google import auth
1314
from google.auth import credentials
15+
from google.auth.exceptions import MutualTLSChannelError
1416
from google.oauth2 import service_account
1517
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }}
1618
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports
@@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
6365
{% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}
6466

6567

68+
def test_{{ service.client_name|snake_case }}_get_transport_class():
69+
transport = {{ service.client_name }}.get_transport_class()
70+
assert transport == transports.{{ service.name }}GrpcTransport
71+
72+
transport = {{ service.client_name }}.get_transport_class("grpc")
73+
assert transport == transports.{{ service.name }}GrpcTransport
74+
75+
6676
def test_{{ service.client_name|snake_case }}_client_options():
6777
# Check that if channel is provided we won't create a new one.
6878
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
@@ -72,58 +82,99 @@ def test_{{ service.client_name|snake_case }}_client_options():
7282
client = {{ service.client_name }}(transport=transport)
7383
gtc.assert_not_called()
7484

75-
# Check mTLS is not triggered with empty client options.
76-
options = client_options.ClientOptions()
85+
# Check that if channel is provided via str we will create a new one.
7786
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
78-
transport = gtc.return_value = mock.MagicMock()
79-
client = {{ service.client_name }}(client_options=options)
80-
transport.assert_called_once_with(
81-
credentials=None,
82-
host=client.DEFAULT_ENDPOINT,
83-
)
87+
client = {{ service.client_name }}(transport="grpc")
88+
gtc.assert_called()
8489

85-
# Check mTLS is not triggered if api_endpoint is provided but
86-
# client_cert_source is None.
90+
# Check the case api_endpoint is provided.
8791
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
8892
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
8993
grpc_transport.return_value = None
9094
client = {{ service.client_name }}(client_options=options)
9195
grpc_transport.assert_called_once_with(
92-
api_mtls_endpoint=None,
96+
api_mtls_endpoint="squid.clam.whelk",
9397
client_cert_source=None,
9498
credentials=None,
9599
host="squid.clam.whelk",
96100
)
97101

98-
# Check mTLS is triggered if client_cert_source is provided.
99-
options = client_options.ClientOptions(
100-
client_cert_source=client_cert_source_callback
101-
)
102+
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
103+
# "Never".
104+
os.environ["GOOGLE_API_USE_MTLS"] = "Never"
102105
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
103106
grpc_transport.return_value = None
104-
client = {{ service.client_name }}(client_options=options)
107+
client = {{ service.client_name }}()
105108
grpc_transport.assert_called_once_with(
106-
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
107-
client_cert_source=client_cert_source_callback,
109+
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
110+
client_cert_source=None,
108111
credentials=None,
109112
host=client.DEFAULT_ENDPOINT,
110113
)
111114

112-
# Check mTLS is triggered if api_endpoint and client_cert_source are provided.
113-
options = client_options.ClientOptions(
114-
api_endpoint="squid.clam.whelk",
115-
client_cert_source=client_cert_source_callback
116-
)
115+
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
116+
# "Always".
117+
os.environ["GOOGLE_API_USE_MTLS"] = "Always"
118+
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
119+
grpc_transport.return_value = None
120+
client = {{ service.client_name }}()
121+
grpc_transport.assert_called_once_with(
122+
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
123+
client_cert_source=None,
124+
credentials=None,
125+
host=client.DEFAULT_MTLS_ENDPOINT,
126+
)
127+
128+
# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
129+
# "Auto", and client_cert_source is provided.
130+
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
131+
options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
117132
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
118133
grpc_transport.return_value = None
119134
client = {{ service.client_name }}(client_options=options)
120135
grpc_transport.assert_called_once_with(
121-
api_mtls_endpoint="squid.clam.whelk",
136+
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
122137
client_cert_source=client_cert_source_callback,
123138
credentials=None,
124-
host="squid.clam.whelk",
139+
host=client.DEFAULT_MTLS_ENDPOINT,
125140
)
126141

142+
# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
143+
# "Auto", and default_client_cert_source is provided.
144+
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
145+
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
146+
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
147+
grpc_transport.return_value = None
148+
client = {{ service.client_name }}()
149+
grpc_transport.assert_called_once_with(
150+
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
151+
client_cert_source=None,
152+
credentials=None,
153+
host=client.DEFAULT_MTLS_ENDPOINT,
154+
)
155+
156+
# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
157+
# "Auto", but client_cert_source and default_client_cert_source are None.
158+
os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
159+
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
160+
with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
161+
grpc_transport.return_value = None
162+
client = {{ service.client_name }}()
163+
grpc_transport.assert_called_once_with(
164+
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
165+
client_cert_source=None,
166+
credentials=None,
167+
host=client.DEFAULT_ENDPOINT,
168+
)
169+
170+
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
171+
# unsupported value.
172+
os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
173+
with pytest.raises(MutualTLSChannelError):
174+
client = {{ service.client_name }}()
175+
176+
del os.environ["GOOGLE_API_USE_MTLS"]
177+
127178

128179
def test_{{ service.client_name|snake_case }}_client_options_from_dict():
129180
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
@@ -132,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
132183
client_options={'api_endpoint': 'squid.clam.whelk'}
133184
)
134185
grpc_transport.assert_called_once_with(
135-
api_mtls_endpoint=None,
186+
api_mtls_endpoint="squid.clam.whelk",
136187
client_cert_source=None,
137188
credentials=None,
138189
host="squid.clam.whelk",
@@ -490,12 +541,24 @@ def test_{{ service.name|snake_case }}_auth_adc():
490541
))
491542

492543

544+
def test_{{ service.name|snake_case }}_transport_auth_adc():
545+
# If credentials and host are not provided, the transport class should use
546+
# ADC credentials.
547+
with mock.patch.object(auth, 'default') as adc:
548+
adc.return_value = (credentials.AnonymousCredentials(), None)
549+
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
550+
adc.assert_called_once_with(scopes=(
551+
{%- for scope in service.oauth_scopes %}
552+
'{{ scope }}',
553+
{%- endfor %}
554+
))
555+
556+
493557
def test_{{ service.name|snake_case }}_host_no_port():
494558
{% with host = (service.host|default('localhost', true)).split(':')[0] -%}
495559
client = {{ service.client_name }}(
496560
credentials=credentials.AnonymousCredentials(),
497561
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
498-
transport='grpc',
499562
)
500563
assert client._transport._host == '{{ host }}:443'
501564
{% endwith %}
@@ -506,7 +569,6 @@ def test_{{ service.name|snake_case }}_host_with_port():
506569
client = {{ service.client_name }}(
507570
credentials=credentials.AnonymousCredentials(),
508571
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
509-
transport='grpc',
510572
)
511573
assert client._transport._host == '{{ host }}:8000'
512574
{% endwith %}

0 commit comments

Comments
 (0)