|
24 | 24 | import time |
25 | 25 | from unittest import mock |
26 | 26 |
|
27 | | -from google.cloud.spanner_v1 import RequestOptions |
| 27 | +from google.cloud.spanner_v1 import RequestOptions, Client |
28 | 28 | import sqlalchemy |
29 | 29 | from sqlalchemy import create_engine |
30 | 30 | from sqlalchemy.engine import Inspector |
|
144 | 144 | UnicodeTextTest as _UnicodeTextTest, |
145 | 145 | _UnicodeFixture as __UnicodeFixture, |
146 | 146 | ) # noqa: F401, F403 |
147 | | -from test._helpers import get_db_url |
| 147 | +from test._helpers import get_db_url, get_project |
148 | 148 |
|
149 | 149 | config.test_schema = "" |
150 | 150 |
|
@@ -3000,3 +3000,44 @@ def test_request_priority(self): |
3000 | 3000 | engine = create_engine("sqlite:///database") |
3001 | 3001 | with engine.connect() as connection: |
3002 | 3002 | pass |
| 3003 | + |
| 3004 | + |
| 3005 | +class CreateEngineWithClientObjectTest(fixtures.TestBase): |
| 3006 | + def test_create_engine_w_valid_client_object(self): |
| 3007 | + """ |
| 3008 | + SPANNER TEST: |
| 3009 | +
|
| 3010 | + Check that we can connect to SqlAlchemy |
| 3011 | + by passing custom Client object. |
| 3012 | + """ |
| 3013 | + client = Client(project=get_project()) |
| 3014 | + engine = create_engine(get_db_url(), connect_args={"client": client}) |
| 3015 | + with engine.connect() as connection: |
| 3016 | + assert connection.connection.instance._client == client |
| 3017 | + |
| 3018 | + def test_create_engine_w_invalid_client_object(self): |
| 3019 | + """ |
| 3020 | + SPANNER TEST: |
| 3021 | +
|
| 3022 | + Check that if project id in url and custom Client |
| 3023 | + Object passed to enginer mismatch, error is thrown. |
| 3024 | + """ |
| 3025 | + client = Client(project="project_id") |
| 3026 | + engine = create_engine(get_db_url(), connect_args={"client": client}) |
| 3027 | + |
| 3028 | + with pytest.raises(ValueError): |
| 3029 | + engine.connect() |
| 3030 | + |
| 3031 | + |
| 3032 | +class CreateEngineWithoutDatabaseTest(fixtures.TestBase): |
| 3033 | + def test_create_engine_wo_database(self): |
| 3034 | + """ |
| 3035 | + SPANNER TEST: |
| 3036 | +
|
| 3037 | + Check that we can connect to SqlAlchemy |
| 3038 | + without passing database id in the |
| 3039 | + connection URL. |
| 3040 | + """ |
| 3041 | + engine = create_engine(get_db_url().split("/database")[0]) |
| 3042 | + with engine.connect() as connection: |
| 3043 | + assert connection.connection.database is None |
0 commit comments