Skip to content

Commit f52ac52

Browse files
authored
Impose required semantics for snapshots as cursors: (#6837)
- Snapshot cursors imply on ordering on '__name__', if not already present. Implied ordering is added at the end of the list, matching the direction of the prior entry ('ASCENDING' if none exist). - Snapshots copy their document reference into the '__name__' field of their document values. - Disallow use of snapshots from foreign collections as query cursors. - In a query with one or more 'where' clauses using ordering operators, and including a snapshot cursor, we must add ordering on the field(s) used (IFF the field is not already in the query's 'order_by'). Closes #6665.
1 parent ca64dbe commit f52ac52

2 files changed

Lines changed: 190 additions & 30 deletions

File tree

packages/google-cloud-firestore/google/cloud/firestore_v1beta1/query.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,14 @@ def where(self, field_path, op_string, value):
262262
end_at=self._end_at,
263263
)
264264

265+
@staticmethod
266+
def _make_order(field_path, direction):
267+
"""Helper for :meth:`order_by`."""
268+
return query_pb2.StructuredQuery.Order(
269+
field=query_pb2.StructuredQuery.FieldReference(field_path=field_path),
270+
direction=_enum_from_direction(direction),
271+
)
272+
265273
def order_by(self, field_path, direction=ASCENDING):
266274
"""Modify the query to add an order clause on a specific field.
267275
@@ -291,10 +299,7 @@ def order_by(self, field_path, direction=ASCENDING):
291299
"""
292300
field_path_module.split_field_path(field_path) # raises
293301

294-
order_pb = query_pb2.StructuredQuery.Order(
295-
field=query_pb2.StructuredQuery.FieldReference(field_path=field_path),
296-
direction=_enum_from_direction(direction),
297-
)
302+
order_pb = self._make_order(field_path, direction)
298303

299304
new_orders = self._orders + (order_pb,)
300305
return self.__class__(
@@ -388,7 +393,10 @@ def _cursor_helper(self, document_fields, before, start):
388393
if isinstance(document_fields, tuple):
389394
document_fields = list(document_fields)
390395
elif isinstance(document_fields, document.DocumentSnapshot):
391-
document_fields = document_fields.to_dict()
396+
if document_fields.reference._path[:-1] != self._parent._path:
397+
raise ValueError(
398+
"Cannot use snapshot from another collection as a cursor."
399+
)
392400
else:
393401
# NOTE: We copy so that the caller can't modify after calling.
394402
document_fields = copy.deepcopy(document_fields)
@@ -564,6 +572,40 @@ def _normalize_projection(projection):
564572

565573
return projection
566574

575+
def _normalize_orders(self):
576+
"""Helper: adjust orders based on cursors, where clauses."""
577+
orders = list(self._orders)
578+
_has_snapshot_cursor = False
579+
580+
if self._start_at:
581+
if isinstance(self._start_at[0], document.DocumentSnapshot):
582+
_has_snapshot_cursor = True
583+
584+
if self._end_at:
585+
if isinstance(self._end_at[0], document.DocumentSnapshot):
586+
_has_snapshot_cursor = True
587+
588+
if _has_snapshot_cursor:
589+
should_order = [
590+
_enum_from_op_string(key)
591+
for key in _COMPARISON_OPERATORS
592+
if key not in (_EQ_OP, "array_contains")
593+
]
594+
order_keys = [order.field.field_path for order in orders]
595+
for filter_ in self._field_filters:
596+
field = filter_.field.field_path
597+
if filter_.op in should_order and field not in order_keys:
598+
orders.append(self._make_order(field, "ASCENDING"))
599+
if not orders:
600+
orders.append(self._make_order("__name__", "ASCENDING"))
601+
else:
602+
order_keys = [order.field.field_path for order in orders]
603+
if "__name__" not in order_keys:
604+
direction = orders[-1].direction # enum?
605+
orders.append(self._make_order("__name__", direction))
606+
607+
return orders
608+
567609
def _normalize_cursor(self, cursor, orders):
568610
"""Helper: convert cursor to a list of values based on orders."""
569611
if cursor is None:
@@ -576,6 +618,11 @@ def _normalize_cursor(self, cursor, orders):
576618

577619
order_keys = [order.field.field_path for order in orders]
578620

621+
if isinstance(document_fields, document.DocumentSnapshot):
622+
snapshot = document_fields
623+
document_fields = snapshot.to_dict()
624+
document_fields["__name__"] = snapshot.reference
625+
579626
if isinstance(document_fields, dict):
580627
# Transform to list using orders
581628
values = []
@@ -616,8 +663,9 @@ def _to_protobuf(self):
616663
query protobuf.
617664
"""
618665
projection = self._normalize_projection(self._projection)
619-
start_at = self._normalize_cursor(self._start_at, self._orders)
620-
end_at = self._normalize_cursor(self._end_at, self._orders)
666+
orders = self._normalize_orders()
667+
start_at = self._normalize_cursor(self._start_at, orders)
668+
end_at = self._normalize_cursor(self._end_at, orders)
621669

622670
query_kwargs = {
623671
"select": projection,
@@ -627,7 +675,7 @@ def _to_protobuf(self):
627675
)
628676
],
629677
"where": self._filters_pb(),
630-
"order_by": self._orders,
678+
"order_by": orders,
631679
"start_at": _cursor_pb(start_at),
632680
"end_at": _cursor_pb(end_at),
633681
}
@@ -825,6 +873,9 @@ def _enum_from_direction(direction):
825873
Raises:
826874
ValueError: If ``direction`` is not a valid direction.
827875
"""
876+
if isinstance(direction, int):
877+
return direction
878+
828879
if direction == Query.ASCENDING:
829880
return enums.StructuredQuery.Direction.ASCENDING
830881
elif direction == Query.DESCENDING:

packages/google-cloud-firestore/tests/unit/test_query.py

Lines changed: 131 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_constructor_defaults(self):
4646
self.assertIsNone(query._start_at)
4747
self.assertIsNone(query._end_at)
4848

49-
def _make_one_all_fields(self, limit=9876, offset=12, skip_fields=()):
49+
def _make_one_all_fields(self, limit=9876, offset=12, skip_fields=(), parent=None):
5050
kwargs = {
5151
"projection": mock.sentinel.projection,
5252
"field_filters": mock.sentinel.filters,
@@ -58,7 +58,9 @@ def _make_one_all_fields(self, limit=9876, offset=12, skip_fields=()):
5858
}
5959
for field in skip_fields:
6060
kwargs.pop(field)
61-
return self._make_one(mock.sentinel.parent, **kwargs)
61+
if parent is None:
62+
parent = mock.sentinel.parent
63+
return self._make_one(parent, **kwargs)
6264

6365
def test_constructor_explicit(self):
6466
limit = 234
@@ -289,10 +291,22 @@ def test_offset(self):
289291
self._compare_queries(query2, query3, "_offset")
290292

291293
@staticmethod
292-
def _make_snapshot(values):
293-
from google.cloud.firestore_v1beta1.document import DocumentSnapshot
294+
def _make_collection(*path, **kw):
295+
from google.cloud.firestore_v1beta1 import collection
296+
297+
return collection.CollectionReference(*path, **kw)
294298

295-
return DocumentSnapshot(None, values, True, None, None, None)
299+
@staticmethod
300+
def _make_docref(*path, **kw):
301+
from google.cloud.firestore_v1beta1 import document
302+
303+
return document.DocumentReference(*path, **kw)
304+
305+
@staticmethod
306+
def _make_snapshot(docref, values):
307+
from google.cloud.firestore_v1beta1 import document
308+
309+
return document.DocumentSnapshot(docref, values, True, None, None, None)
296310

297311
def test__cursor_helper_w_dict(self):
298312
values = {"a": 7, "b": "foo"}
@@ -349,15 +363,26 @@ def test__cursor_helper_w_list(self):
349363
self.assertIsNot(cursor, values)
350364
self.assertTrue(before)
351365

352-
def test__cursor_helper_w_snapshot(self):
366+
def test__cursor_helper_w_snapshot_wrong_collection(self):
367+
values = {"a": 7, "b": "foo"}
368+
docref = self._make_docref("there", "doc_id")
369+
snapshot = self._make_snapshot(docref, values)
370+
collection = self._make_collection("here")
371+
query = self._make_one(collection)
353372

373+
with self.assertRaises(ValueError):
374+
query._cursor_helper(snapshot, False, False)
375+
376+
def test__cursor_helper_w_snapshot(self):
354377
values = {"a": 7, "b": "foo"}
355-
snapshot = self._make_snapshot(values)
356-
query1 = self._make_one(mock.sentinel.parent)
378+
docref = self._make_docref("here", "doc_id")
379+
snapshot = self._make_snapshot(docref, values)
380+
collection = self._make_collection("here")
381+
query1 = self._make_one(collection)
357382

358383
query2 = query1._cursor_helper(snapshot, False, False)
359384

360-
self.assertIs(query2._parent, mock.sentinel.parent)
385+
self.assertIs(query2._parent, collection)
361386
self.assertIsNone(query2._projection)
362387
self.assertEqual(query2._field_filters, ())
363388
self.assertEqual(query2._orders, ())
@@ -367,11 +392,12 @@ def test__cursor_helper_w_snapshot(self):
367392

368393
cursor, before = query2._end_at
369394

370-
self.assertEqual(cursor, values)
395+
self.assertIs(cursor, snapshot)
371396
self.assertFalse(before)
372397

373398
def test_start_at(self):
374-
query1 = self._make_one_all_fields(skip_fields=("orders",))
399+
collection = self._make_collection("here")
400+
query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",))
375401
query2 = query1.order_by("hi")
376402

377403
document_fields3 = {"hi": "mom"}
@@ -384,15 +410,17 @@ def test_start_at(self):
384410
# Make sure it overrides.
385411
query4 = query3.order_by("bye")
386412
values5 = {"hi": "zap", "bye": 88}
387-
document_fields5 = self._make_snapshot(values5)
413+
docref = self._make_docref("here", "doc_id")
414+
document_fields5 = self._make_snapshot(docref, values5)
388415
query5 = query4.start_at(document_fields5)
389416
self.assertIsNot(query5, query4)
390417
self.assertIsInstance(query5, self._get_target_class())
391-
self.assertEqual(query5._start_at, (values5, True))
418+
self.assertEqual(query5._start_at, (document_fields5, True))
392419
self._compare_queries(query4, query5, "_start_at")
393420

394421
def test_start_after(self):
395-
query1 = self._make_one_all_fields(skip_fields=("orders",))
422+
collection = self._make_collection("here")
423+
query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",))
396424
query2 = query1.order_by("down")
397425

398426
document_fields3 = {"down": 99.75}
@@ -405,15 +433,17 @@ def test_start_after(self):
405433
# Make sure it overrides.
406434
query4 = query3.order_by("out")
407435
values5 = {"down": 100.25, "out": b"\x00\x01"}
408-
document_fields5 = self._make_snapshot(values5)
436+
docref = self._make_docref("here", "doc_id")
437+
document_fields5 = self._make_snapshot(docref, values5)
409438
query5 = query4.start_after(document_fields5)
410439
self.assertIsNot(query5, query4)
411440
self.assertIsInstance(query5, self._get_target_class())
412-
self.assertEqual(query5._start_at, (values5, False))
441+
self.assertEqual(query5._start_at, (document_fields5, False))
413442
self._compare_queries(query4, query5, "_start_at")
414443

415444
def test_end_before(self):
416-
query1 = self._make_one_all_fields(skip_fields=("orders",))
445+
collection = self._make_collection("here")
446+
query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",))
417447
query2 = query1.order_by("down")
418448

419449
document_fields3 = {"down": 99.75}
@@ -426,15 +456,18 @@ def test_end_before(self):
426456
# Make sure it overrides.
427457
query4 = query3.order_by("out")
428458
values5 = {"down": 100.25, "out": b"\x00\x01"}
429-
document_fields5 = self._make_snapshot(values5)
459+
docref = self._make_docref("here", "doc_id")
460+
document_fields5 = self._make_snapshot(docref, values5)
430461
query5 = query4.end_before(document_fields5)
431462
self.assertIsNot(query5, query4)
432463
self.assertIsInstance(query5, self._get_target_class())
433-
self.assertEqual(query5._end_at, (values5, True))
464+
self.assertEqual(query5._end_at, (document_fields5, True))
465+
self._compare_queries(query4, query5, "_end_at")
434466
self._compare_queries(query4, query5, "_end_at")
435467

436468
def test_end_at(self):
437-
query1 = self._make_one_all_fields(skip_fields=("orders",))
469+
collection = self._make_collection("here")
470+
query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",))
438471
query2 = query1.order_by("hi")
439472

440473
document_fields3 = {"hi": "mom"}
@@ -447,11 +480,12 @@ def test_end_at(self):
447480
# Make sure it overrides.
448481
query4 = query3.order_by("bye")
449482
values5 = {"hi": "zap", "bye": 88}
450-
document_fields5 = self._make_snapshot(values5)
483+
docref = self._make_docref("here", "doc_id")
484+
document_fields5 = self._make_snapshot(docref, values5)
451485
query5 = query4.end_at(document_fields5)
452486
self.assertIsNot(query5, query4)
453487
self.assertIsInstance(query5, self._get_target_class())
454-
self.assertEqual(query5._end_at, (values5, False))
488+
self.assertEqual(query5._end_at, (document_fields5, False))
455489
self._compare_queries(query4, query5, "_end_at")
456490

457491
def test__filters_pb_empty(self):
@@ -530,6 +564,67 @@ def test__normalize_projection_non_empty(self):
530564
query = self._make_one(mock.sentinel.parent)
531565
self.assertIs(query._normalize_projection(projection), projection)
532566

567+
def test__normalize_orders_wo_orders_wo_cursors(self):
568+
query = self._make_one(mock.sentinel.parent)
569+
expected = []
570+
self.assertEqual(query._normalize_orders(), expected)
571+
572+
def test__normalize_orders_w_orders_wo_cursors(self):
573+
query = self._make_one(mock.sentinel.parent).order_by("a")
574+
expected = [query._make_order("a", "ASCENDING")]
575+
self.assertEqual(query._normalize_orders(), expected)
576+
577+
def test__normalize_orders_wo_orders_w_snapshot_cursor(self):
578+
values = {"a": 7, "b": "foo"}
579+
docref = self._make_docref("here", "doc_id")
580+
snapshot = self._make_snapshot(docref, values)
581+
collection = self._make_collection("here")
582+
query = self._make_one(collection).start_at(snapshot)
583+
expected = [query._make_order("__name__", "ASCENDING")]
584+
self.assertEqual(query._normalize_orders(), expected)
585+
586+
def test__normalize_orders_w_name_orders_w_snapshot_cursor(self):
587+
values = {"a": 7, "b": "foo"}
588+
docref = self._make_docref("here", "doc_id")
589+
snapshot = self._make_snapshot(docref, values)
590+
collection = self._make_collection("here")
591+
query = (
592+
self._make_one(collection)
593+
.order_by("__name__", "DESCENDING")
594+
.start_at(snapshot)
595+
)
596+
expected = [query._make_order("__name__", "DESCENDING")]
597+
self.assertEqual(query._normalize_orders(), expected)
598+
599+
def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(self):
600+
values = {"a": 7, "b": "foo"}
601+
docref = self._make_docref("here", "doc_id")
602+
snapshot = self._make_snapshot(docref, values)
603+
collection = self._make_collection("here")
604+
query = (
605+
self._make_one(collection)
606+
.where("c", "<=", 20)
607+
.order_by("c", "DESCENDING")
608+
.start_at(snapshot)
609+
)
610+
expected = [
611+
query._make_order("c", "DESCENDING"),
612+
query._make_order("__name__", "DESCENDING"),
613+
]
614+
self.assertEqual(query._normalize_orders(), expected)
615+
616+
def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(self):
617+
values = {"a": 7, "b": "foo"}
618+
docref = self._make_docref("here", "doc_id")
619+
snapshot = self._make_snapshot(docref, values)
620+
collection = self._make_collection("here")
621+
query = self._make_one(collection).where("c", "<=", 20).end_at(snapshot)
622+
expected = [
623+
query._make_order("c", "ASCENDING"),
624+
query._make_order("__name__", "ASCENDING"),
625+
]
626+
self.assertEqual(query._normalize_orders(), expected)
627+
533628
def test__normalize_cursor_none(self):
534629
query = self._make_one(mock.sentinel.parent)
535630
self.assertIsNone(query._normalize_cursor(None, query._orders))
@@ -603,6 +698,16 @@ def test__normalize_cursor_as_dict_hit(self):
603698

604699
self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True))
605700

701+
def test__normalize_cursor_as_snapshot_hit(self):
702+
values = {"b": 1}
703+
docref = self._make_docref("here", "doc_id")
704+
snapshot = self._make_snapshot(docref, values)
705+
cursor = (snapshot, True)
706+
collection = self._make_collection("here")
707+
query = self._make_one(collection).order_by("b", "ASCENDING")
708+
709+
self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True))
710+
606711
def test__normalize_cursor_w___name___w_slash(self):
607712
db_string = "projects/my-project/database/(default)"
608713
client = mock.Mock(spec=["_database_string"])
@@ -1206,6 +1311,10 @@ def test_success(self):
12061311
self.assertEqual(self._call_fut(Query.ASCENDING), dir_class.ASCENDING)
12071312
self.assertEqual(self._call_fut(Query.DESCENDING), dir_class.DESCENDING)
12081313

1314+
# Ints pass through
1315+
self.assertEqual(self._call_fut(dir_class.ASCENDING), dir_class.ASCENDING)
1316+
self.assertEqual(self._call_fut(dir_class.DESCENDING), dir_class.DESCENDING)
1317+
12091318
def test_failure(self):
12101319
with self.assertRaises(ValueError):
12111320
self._call_fut("neither-ASCENDING-nor-DESCENDING")

0 commit comments

Comments
 (0)