@@ -84,6 +84,12 @@ def _load_testproto(filename):
8484 if test_proto .WhichOneof ("test" ) == "listen"
8585]
8686
87+ _QUERY_TESTPROTOS = [
88+ test_proto
89+ for test_proto in ALL_TESTPROTOS
90+ if test_proto .WhichOneof ("test" ) == "query"
91+ ]
92+
8793
8894def _mock_firestore_api ():
8995 firestore_api = mock .Mock (spec = ["commit" ])
@@ -201,10 +207,23 @@ def test_delete_testprotos(test_proto):
201207
202208@pytest .mark .skip (reason = "Watch aka listen not yet implemented in Python." )
203209@pytest .mark .parametrize ("test_proto" , _LISTEN_TESTPROTOS )
204- def test_listen_paths_testprotos (test_proto ): # pragma: NO COVER
210+ def test_listen_testprotos (test_proto ): # pragma: NO COVER
205211 pass
206212
207213
214+ @pytest .mark .parametrize ("test_proto" , _QUERY_TESTPROTOS )
215+ def test_query_testprotos (test_proto ): # pragma: NO COVER
216+ testcase = test_proto .query
217+ if testcase .is_error :
218+ with pytest .raises (Exception ):
219+ query = parse_query (testcase )
220+ query ._to_protobuf ()
221+ else :
222+ query = parse_query (testcase )
223+ found = query ._to_protobuf ()
224+ assert found == testcase .query
225+
226+
208227def convert_data (v ):
209228 # Replace the strings 'ServerTimestamp' and 'Delete' with the corresponding
210229 # sentinels.
@@ -225,6 +244,8 @@ def convert_data(v):
225244 return [convert_data (e ) for e in v ]
226245 elif isinstance (v , dict ):
227246 return {k : convert_data (v2 ) for k , v2 in v .items ()}
247+ elif v == "NaN" :
248+ return float (v )
228249 else :
229250 return v
230251
@@ -249,3 +270,106 @@ def convert_precondition(precond):
249270
250271 assert precond .HasField ("update_time" )
251272 return Client .write_option (last_update_time = precond .update_time )
273+
274+
275+ def parse_query (testcase ):
276+ # 'query' testcase contains:
277+ # - 'coll_path': collection ref path.
278+ # - 'clauses': array of one or more 'Clause' elements
279+ # - 'query': the actual google.firestore.v1beta1.StructuredQuery message
280+ # to be constructed.
281+ # - 'is_error' (as other testcases).
282+ #
283+ # 'Clause' elements are unions of:
284+ # - 'select': [field paths]
285+ # - 'where': (field_path, op, json_value)
286+ # - 'order_by': (field_path, direction)
287+ # - 'offset': int
288+ # - 'limit': int
289+ # - 'start_at': 'Cursor'
290+ # - 'start_after': 'Cursor'
291+ # - 'end_at': 'Cursor'
292+ # - 'end_before': 'Cursor'
293+ #
294+ # 'Cursor' contains either:
295+ # - 'doc_snapshot': 'DocSnapshot'
296+ # - 'json_values': [string]
297+ #
298+ # 'DocSnapshot' contains:
299+ # 'path': str
300+ # 'json_data': str
301+ from google .auth .credentials import Credentials
302+ from google .cloud .firestore_v1beta1 import Client
303+ from google .cloud .firestore_v1beta1 import Query
304+
305+ _directions = {"asc" : Query .ASCENDING , "desc" : Query .DESCENDING }
306+
307+ credentials = mock .create_autospec (Credentials )
308+ client = Client ("projectID" , credentials )
309+ path = parse_path (testcase .coll_path )
310+ collection = client .collection (* path )
311+ query = collection
312+
313+ for clause in testcase .clauses :
314+ kind = clause .WhichOneof ("clause" )
315+
316+ if kind == "select" :
317+ field_paths = [
318+ "." .join (field_path .field ) for field_path in clause .select .fields
319+ ]
320+ query = query .select (field_paths )
321+ elif kind == "where" :
322+ path = "." .join (clause .where .path .field )
323+ value = convert_data (json .loads (clause .where .json_value ))
324+ query = query .where (path , clause .where .op , value )
325+ elif kind == "order_by" :
326+ path = "." .join (clause .order_by .path .field )
327+ direction = clause .order_by .direction
328+ direction = _directions .get (direction , direction )
329+ query = query .order_by (path , direction = direction )
330+ elif kind == "offset" :
331+ query = query .offset (clause .offset )
332+ elif kind == "limit" :
333+ query = query .limit (clause .limit )
334+ elif kind == "start_at" :
335+ cursor = parse_cursor (clause .start_at , client )
336+ query = query .start_at (cursor )
337+ elif kind == "start_after" :
338+ cursor = parse_cursor (clause .start_after , client )
339+ query = query .start_after (cursor )
340+ elif kind == "end_at" :
341+ cursor = parse_cursor (clause .end_at , client )
342+ query = query .end_at (cursor )
343+ elif kind == "end_before" :
344+ cursor = parse_cursor (clause .end_before , client )
345+ query = query .end_before (cursor )
346+ else : # pragma: NO COVER
347+ raise ValueError ("Unknown query clause: {}" .format (kind ))
348+
349+ return query
350+
351+
352+ def parse_path (path ):
353+ _ , relative = path .split ("documents/" )
354+ return relative .split ("/" )
355+
356+
357+ def parse_cursor (cursor , client ):
358+ from google .cloud .firestore_v1beta1 import DocumentReference
359+ from google .cloud .firestore_v1beta1 import DocumentSnapshot
360+
361+ if cursor .HasField ("doc_snapshot" ):
362+ path = parse_path (cursor .doc_snapshot .path )
363+ doc_ref = DocumentReference (* path , client = client )
364+
365+ return DocumentSnapshot (
366+ reference = doc_ref ,
367+ data = json .loads (cursor .doc_snapshot .json_data ),
368+ exists = True ,
369+ read_time = None ,
370+ create_time = None ,
371+ update_time = None ,
372+ )
373+
374+ values = [json .loads (value ) for value in cursor .json_values ]
375+ return convert_data (values )
0 commit comments