@@ -46,7 +46,7 @@ def test_constructor_defaults(self):
4646 self .assertIsNone (query ._start_at )
4747 self .assertIsNone (query ._end_at )
4848
49- def _make_one_all_fields (self , limit = 9876 , offset = 12 , skip_fields = ()):
49+ def _make_one_all_fields (self , limit = 9876 , offset = 12 , skip_fields = (), parent = None ):
5050 kwargs = {
5151 "projection" : mock .sentinel .projection ,
5252 "field_filters" : mock .sentinel .filters ,
@@ -58,7 +58,9 @@ def _make_one_all_fields(self, limit=9876, offset=12, skip_fields=()):
5858 }
5959 for field in skip_fields :
6060 kwargs .pop (field )
61- return self ._make_one (mock .sentinel .parent , ** kwargs )
61+ if parent is None :
62+ parent = mock .sentinel .parent
63+ return self ._make_one (parent , ** kwargs )
6264
6365 def test_constructor_explicit (self ):
6466 limit = 234
@@ -289,10 +291,22 @@ def test_offset(self):
289291 self ._compare_queries (query2 , query3 , "_offset" )
290292
291293 @staticmethod
292- def _make_snapshot (values ):
293- from google .cloud .firestore_v1beta1 .document import DocumentSnapshot
294+ def _make_collection (* path , ** kw ):
295+ from google .cloud .firestore_v1beta1 import collection
296+
297+ return collection .CollectionReference (* path , ** kw )
294298
295- return DocumentSnapshot (None , values , True , None , None , None )
299+ @staticmethod
300+ def _make_docref (* path , ** kw ):
301+ from google .cloud .firestore_v1beta1 import document
302+
303+ return document .DocumentReference (* path , ** kw )
304+
305+ @staticmethod
306+ def _make_snapshot (docref , values ):
307+ from google .cloud .firestore_v1beta1 import document
308+
309+ return document .DocumentSnapshot (docref , values , True , None , None , None )
296310
297311 def test__cursor_helper_w_dict (self ):
298312 values = {"a" : 7 , "b" : "foo" }
@@ -349,15 +363,26 @@ def test__cursor_helper_w_list(self):
349363 self .assertIsNot (cursor , values )
350364 self .assertTrue (before )
351365
352- def test__cursor_helper_w_snapshot (self ):
366+ def test__cursor_helper_w_snapshot_wrong_collection (self ):
367+ values = {"a" : 7 , "b" : "foo" }
368+ docref = self ._make_docref ("there" , "doc_id" )
369+ snapshot = self ._make_snapshot (docref , values )
370+ collection = self ._make_collection ("here" )
371+ query = self ._make_one (collection )
353372
373+ with self .assertRaises (ValueError ):
374+ query ._cursor_helper (snapshot , False , False )
375+
376+ def test__cursor_helper_w_snapshot (self ):
354377 values = {"a" : 7 , "b" : "foo" }
355- snapshot = self ._make_snapshot (values )
356- query1 = self ._make_one (mock .sentinel .parent )
378+ docref = self ._make_docref ("here" , "doc_id" )
379+ snapshot = self ._make_snapshot (docref , values )
380+ collection = self ._make_collection ("here" )
381+ query1 = self ._make_one (collection )
357382
358383 query2 = query1 ._cursor_helper (snapshot , False , False )
359384
360- self .assertIs (query2 ._parent , mock . sentinel . parent )
385+ self .assertIs (query2 ._parent , collection )
361386 self .assertIsNone (query2 ._projection )
362387 self .assertEqual (query2 ._field_filters , ())
363388 self .assertEqual (query2 ._orders , ())
@@ -367,11 +392,12 @@ def test__cursor_helper_w_snapshot(self):
367392
368393 cursor , before = query2 ._end_at
369394
370- self .assertEqual (cursor , values )
395+ self .assertIs (cursor , snapshot )
371396 self .assertFalse (before )
372397
373398 def test_start_at (self ):
374- query1 = self ._make_one_all_fields (skip_fields = ("orders" ,))
399+ collection = self ._make_collection ("here" )
400+ query1 = self ._make_one_all_fields (parent = collection , skip_fields = ("orders" ,))
375401 query2 = query1 .order_by ("hi" )
376402
377403 document_fields3 = {"hi" : "mom" }
@@ -384,15 +410,17 @@ def test_start_at(self):
384410 # Make sure it overrides.
385411 query4 = query3 .order_by ("bye" )
386412 values5 = {"hi" : "zap" , "bye" : 88 }
387- document_fields5 = self ._make_snapshot (values5 )
413+ docref = self ._make_docref ("here" , "doc_id" )
414+ document_fields5 = self ._make_snapshot (docref , values5 )
388415 query5 = query4 .start_at (document_fields5 )
389416 self .assertIsNot (query5 , query4 )
390417 self .assertIsInstance (query5 , self ._get_target_class ())
391- self .assertEqual (query5 ._start_at , (values5 , True ))
418+ self .assertEqual (query5 ._start_at , (document_fields5 , True ))
392419 self ._compare_queries (query4 , query5 , "_start_at" )
393420
394421 def test_start_after (self ):
395- query1 = self ._make_one_all_fields (skip_fields = ("orders" ,))
422+ collection = self ._make_collection ("here" )
423+ query1 = self ._make_one_all_fields (parent = collection , skip_fields = ("orders" ,))
396424 query2 = query1 .order_by ("down" )
397425
398426 document_fields3 = {"down" : 99.75 }
@@ -405,15 +433,17 @@ def test_start_after(self):
405433 # Make sure it overrides.
406434 query4 = query3 .order_by ("out" )
407435 values5 = {"down" : 100.25 , "out" : b"\x00 \x01 " }
408- document_fields5 = self ._make_snapshot (values5 )
436+ docref = self ._make_docref ("here" , "doc_id" )
437+ document_fields5 = self ._make_snapshot (docref , values5 )
409438 query5 = query4 .start_after (document_fields5 )
410439 self .assertIsNot (query5 , query4 )
411440 self .assertIsInstance (query5 , self ._get_target_class ())
412- self .assertEqual (query5 ._start_at , (values5 , False ))
441+ self .assertEqual (query5 ._start_at , (document_fields5 , False ))
413442 self ._compare_queries (query4 , query5 , "_start_at" )
414443
415444 def test_end_before (self ):
416- query1 = self ._make_one_all_fields (skip_fields = ("orders" ,))
445+ collection = self ._make_collection ("here" )
446+ query1 = self ._make_one_all_fields (parent = collection , skip_fields = ("orders" ,))
417447 query2 = query1 .order_by ("down" )
418448
419449 document_fields3 = {"down" : 99.75 }
@@ -426,15 +456,18 @@ def test_end_before(self):
426456 # Make sure it overrides.
427457 query4 = query3 .order_by ("out" )
428458 values5 = {"down" : 100.25 , "out" : b"\x00 \x01 " }
429- document_fields5 = self ._make_snapshot (values5 )
459+ docref = self ._make_docref ("here" , "doc_id" )
460+ document_fields5 = self ._make_snapshot (docref , values5 )
430461 query5 = query4 .end_before (document_fields5 )
431462 self .assertIsNot (query5 , query4 )
432463 self .assertIsInstance (query5 , self ._get_target_class ())
433- self .assertEqual (query5 ._end_at , (values5 , True ))
464+ self .assertEqual (query5 ._end_at , (document_fields5 , True ))
465+ self ._compare_queries (query4 , query5 , "_end_at" )
434466 self ._compare_queries (query4 , query5 , "_end_at" )
435467
436468 def test_end_at (self ):
437- query1 = self ._make_one_all_fields (skip_fields = ("orders" ,))
469+ collection = self ._make_collection ("here" )
470+ query1 = self ._make_one_all_fields (parent = collection , skip_fields = ("orders" ,))
438471 query2 = query1 .order_by ("hi" )
439472
440473 document_fields3 = {"hi" : "mom" }
@@ -447,11 +480,12 @@ def test_end_at(self):
447480 # Make sure it overrides.
448481 query4 = query3 .order_by ("bye" )
449482 values5 = {"hi" : "zap" , "bye" : 88 }
450- document_fields5 = self ._make_snapshot (values5 )
483+ docref = self ._make_docref ("here" , "doc_id" )
484+ document_fields5 = self ._make_snapshot (docref , values5 )
451485 query5 = query4 .end_at (document_fields5 )
452486 self .assertIsNot (query5 , query4 )
453487 self .assertIsInstance (query5 , self ._get_target_class ())
454- self .assertEqual (query5 ._end_at , (values5 , False ))
488+ self .assertEqual (query5 ._end_at , (document_fields5 , False ))
455489 self ._compare_queries (query4 , query5 , "_end_at" )
456490
457491 def test__filters_pb_empty (self ):
@@ -530,6 +564,67 @@ def test__normalize_projection_non_empty(self):
530564 query = self ._make_one (mock .sentinel .parent )
531565 self .assertIs (query ._normalize_projection (projection ), projection )
532566
567+ def test__normalize_orders_wo_orders_wo_cursors (self ):
568+ query = self ._make_one (mock .sentinel .parent )
569+ expected = []
570+ self .assertEqual (query ._normalize_orders (), expected )
571+
572+ def test__normalize_orders_w_orders_wo_cursors (self ):
573+ query = self ._make_one (mock .sentinel .parent ).order_by ("a" )
574+ expected = [query ._make_order ("a" , "ASCENDING" )]
575+ self .assertEqual (query ._normalize_orders (), expected )
576+
577+ def test__normalize_orders_wo_orders_w_snapshot_cursor (self ):
578+ values = {"a" : 7 , "b" : "foo" }
579+ docref = self ._make_docref ("here" , "doc_id" )
580+ snapshot = self ._make_snapshot (docref , values )
581+ collection = self ._make_collection ("here" )
582+ query = self ._make_one (collection ).start_at (snapshot )
583+ expected = [query ._make_order ("__name__" , "ASCENDING" )]
584+ self .assertEqual (query ._normalize_orders (), expected )
585+
586+ def test__normalize_orders_w_name_orders_w_snapshot_cursor (self ):
587+ values = {"a" : 7 , "b" : "foo" }
588+ docref = self ._make_docref ("here" , "doc_id" )
589+ snapshot = self ._make_snapshot (docref , values )
590+ collection = self ._make_collection ("here" )
591+ query = (
592+ self ._make_one (collection )
593+ .order_by ("__name__" , "DESCENDING" )
594+ .start_at (snapshot )
595+ )
596+ expected = [query ._make_order ("__name__" , "DESCENDING" )]
597+ self .assertEqual (query ._normalize_orders (), expected )
598+
599+ def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists (self ):
600+ values = {"a" : 7 , "b" : "foo" }
601+ docref = self ._make_docref ("here" , "doc_id" )
602+ snapshot = self ._make_snapshot (docref , values )
603+ collection = self ._make_collection ("here" )
604+ query = (
605+ self ._make_one (collection )
606+ .where ("c" , "<=" , 20 )
607+ .order_by ("c" , "DESCENDING" )
608+ .start_at (snapshot )
609+ )
610+ expected = [
611+ query ._make_order ("c" , "DESCENDING" ),
612+ query ._make_order ("__name__" , "DESCENDING" ),
613+ ]
614+ self .assertEqual (query ._normalize_orders (), expected )
615+
616+ def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where (self ):
617+ values = {"a" : 7 , "b" : "foo" }
618+ docref = self ._make_docref ("here" , "doc_id" )
619+ snapshot = self ._make_snapshot (docref , values )
620+ collection = self ._make_collection ("here" )
621+ query = self ._make_one (collection ).where ("c" , "<=" , 20 ).end_at (snapshot )
622+ expected = [
623+ query ._make_order ("c" , "ASCENDING" ),
624+ query ._make_order ("__name__" , "ASCENDING" ),
625+ ]
626+ self .assertEqual (query ._normalize_orders (), expected )
627+
533628 def test__normalize_cursor_none (self ):
534629 query = self ._make_one (mock .sentinel .parent )
535630 self .assertIsNone (query ._normalize_cursor (None , query ._orders ))
@@ -603,6 +698,16 @@ def test__normalize_cursor_as_dict_hit(self):
603698
604699 self .assertEqual (query ._normalize_cursor (cursor , query ._orders ), ([1 ], True ))
605700
701+ def test__normalize_cursor_as_snapshot_hit (self ):
702+ values = {"b" : 1 }
703+ docref = self ._make_docref ("here" , "doc_id" )
704+ snapshot = self ._make_snapshot (docref , values )
705+ cursor = (snapshot , True )
706+ collection = self ._make_collection ("here" )
707+ query = self ._make_one (collection ).order_by ("b" , "ASCENDING" )
708+
709+ self .assertEqual (query ._normalize_cursor (cursor , query ._orders ), ([1 ], True ))
710+
606711 def test__normalize_cursor_w___name___w_slash (self ):
607712 db_string = "projects/my-project/database/(default)"
608713 client = mock .Mock (spec = ["_database_string" ])
@@ -1206,6 +1311,10 @@ def test_success(self):
12061311 self .assertEqual (self ._call_fut (Query .ASCENDING ), dir_class .ASCENDING )
12071312 self .assertEqual (self ._call_fut (Query .DESCENDING ), dir_class .DESCENDING )
12081313
1314+ # Ints pass through
1315+ self .assertEqual (self ._call_fut (dir_class .ASCENDING ), dir_class .ASCENDING )
1316+ self .assertEqual (self ._call_fut (dir_class .DESCENDING ), dir_class .DESCENDING )
1317+
12091318 def test_failure (self ):
12101319 with self .assertRaises (ValueError ):
12111320 self ._call_fut ("neither-ASCENDING-nor-DESCENDING" )
0 commit comments