|
24 | 24 | import time |
25 | 25 | from unittest import mock |
26 | 26 |
|
27 | | -from google.cloud.spanner_v1 import RequestOptions |
| 27 | +from google.cloud.spanner_v1 import RequestOptions, Client |
28 | 28 |
|
29 | 29 | import sqlalchemy |
30 | 30 | from sqlalchemy import create_engine |
|
134 | 134 | UnicodeTextTest as _UnicodeTextTest, |
135 | 135 | _UnicodeFixture as __UnicodeFixture, |
136 | 136 | ) |
137 | | -from test._helpers import get_db_url |
| 137 | +from test._helpers import get_db_url, get_project |
138 | 138 |
|
139 | 139 | config.test_schema = "" |
140 | 140 |
|
@@ -2193,3 +2193,30 @@ def test_request_priority(self): |
2193 | 2193 | engine = create_engine("sqlite:///database") |
2194 | 2194 | with engine.connect() as connection: |
2195 | 2195 | pass |
| 2196 | + |
| 2197 | + |
| 2198 | +class CreateEngineWithClientObjectTest(fixtures.TestBase): |
| 2199 | + def test_create_engine_w_valid_client_object(self): |
| 2200 | + """ |
| 2201 | + SPANNER TEST: |
| 2202 | +
|
| 2203 | + Check that we can connect to SqlAlchemy |
| 2204 | + by passing custom Client object. |
| 2205 | + """ |
| 2206 | + client = Client(project=get_project()) |
| 2207 | + engine = create_engine(get_db_url(), connect_args={"client": client}) |
| 2208 | + with engine.connect() as connection: |
| 2209 | + assert connection.connection.instance._client == client |
| 2210 | + |
| 2211 | + def test_create_engine_w_invalid_client_object(self): |
| 2212 | + """ |
| 2213 | + SPANNER TEST: |
| 2214 | +
|
| 2215 | + Check that if project id in url and custom Client |
| 2216 | + Object passed to enginer mismatch, error is thrown. |
| 2217 | + """ |
| 2218 | + client = Client(project="project_id") |
| 2219 | + engine = create_engine(get_db_url(), connect_args={"client": client}) |
| 2220 | + |
| 2221 | + with pytest.raises(ValueError): |
| 2222 | + engine.connect() |
0 commit comments