|
15 | 15 | from typing import Callable |
16 | 16 | from unittest import mock |
17 | 17 |
|
| 18 | +from google.api_core import exceptions as api_core_exceptions |
18 | 19 | import pandas as pd |
19 | 20 | import pyarrow as pa |
20 | 21 | import pytest |
@@ -216,7 +217,9 @@ def test_gemini_text_generator_predict_output_schema_success( |
216 | 217 | llm_text_df: bpd.DataFrame, model_name, session, bq_connection |
217 | 218 | ): |
218 | 219 | gemini_text_generator_model = llm.GeminiTextGenerator( |
219 | | - model_name=model_name, connection_name=bq_connection, session=session |
| 220 | + model_name="gemini-2.0-flash-001", |
| 221 | + connection_name=bq_connection, |
| 222 | + session=session, |
220 | 223 | ) |
221 | 224 | output_schema = { |
222 | 225 | "bool_output": "bool", |
@@ -807,3 +810,122 @@ def test_text_embedding_generator_no_default_model_warning(model_class): |
807 | 810 | message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message." |
808 | 811 | with pytest.warns(FutureWarning, match=message): |
809 | 812 | model_class(model_name=None) |
| 813 | + |
| 814 | + |
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_struct_schema_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection
):
    """Predict with a single STRUCT output column and check its field names.

    Requests one ``struct<name string, age int64>`` column from a
    ``gemini-2.0-flash-001`` generator, asserts that the column's Arrow type
    exposes exactly the fields ``name`` and ``age``, then hands the pandas
    result to the shared schema/index helper.
    """
    schema = {"struct_output": "struct<name string, age int64>"}
    model = llm.GeminiTextGenerator(
        model_name="gemini-2.0-flash-001",
        connection_name=bq_connection,
        session=session,
    )

    result = model.predict(llm_text_df, output_schema=schema)

    # Only the field names are compared, as a set, so field order is ignored.
    struct_type = result["struct_output"].dtype.pyarrow_dtype
    field_names = {field.name for field in struct_type}
    assert field_names == {"name", "age"}

    materialized = result.to_pandas()
    # col_exact=False is passed because only the requested column plus these
    # three named columns are listed here.
    expected_columns = [*schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        materialized,
        columns=expected_columns,
        index=3,
        col_exact=False,
    )
| 840 | + |
| 841 | + |
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection
):
    """Predict with two flat scalar output columns and check their dtypes.

    Requests a ``string`` column named ``name`` and an ``int64`` column named
    ``age``, asserts that they come back as pyarrow-backed string and nullable
    Int64 dtypes respectively, then hands the pandas result to the shared
    schema/index helper.
    """
    schema = {
        "name": "string",
        "age": "int64",
    }
    model = llm.GeminiTextGenerator(
        model_name="gemini-2.0-flash-001",
        connection_name=bq_connection,
        session=session,
    )

    result = model.predict(llm_text_df, output_schema=schema)

    # Checked in the same order as the schema: "name" first, then "age".
    expected_dtypes = {
        "name": pd.StringDtype(storage="pyarrow"),
        "age": pd.Int64Dtype(),
    }
    for column, expected in expected_dtypes.items():
        assert result[column].dtype == expected

    materialized = result.to_pandas()
    expected_columns = [*schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        materialized,
        columns=expected_columns,
        index=3,
        col_exact=False,
    )
| 866 | + |
| 867 | + |
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_array_schema_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection
):
    """Predict with an ``array<string>`` output column and check its dtype.

    Asserts that the ``array_output`` column comes back as an Arrow list of
    strings, then hands the pandas result to the shared schema/index helper.
    """
    schema = {"array_output": "array<string>"}
    model = llm.GeminiTextGenerator(
        model_name="gemini-2.0-flash-001",
        connection_name=bq_connection,
        session=session,
    )

    result = model.predict(llm_text_df, output_schema=schema)

    actual_dtype = result["array_output"].dtype
    assert actual_dtype == pd.ArrowDtype(pa.list_(pa.string()))

    materialized = result.to_pandas()
    expected_columns = [*schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        materialized,
        columns=expected_columns,
        index=3,
        col_exact=False,
    )
| 890 | + |
| 891 | + |
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_array_struct_schema_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection
):
    """Predict with an array-of-struct output column and check element fields.

    Requests ``array<struct<name string, age int64>>`` and asserts that the
    element type of the resulting Arrow list exposes exactly the fields
    ``name`` and ``age``, then hands the pandas result to the shared
    schema/index helper.
    """
    schema = {"array_output": "array<struct<name string, age int64>>"}
    model = llm.GeminiTextGenerator(
        model_name="gemini-2.0-flash-001",
        connection_name=bq_connection,
        session=session,
    )

    result = model.predict(llm_text_df, output_schema=schema)

    # value_type is the list's element type; its fields are compared by name
    # only, as a set, so field order is ignored.
    element_type = result["array_output"].dtype.pyarrow_dtype.value_type
    field_names = {field.name for field in element_type}
    assert field_names == {"name", "age"}

    materialized = result.to_pandas()
    expected_columns = [*schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        materialized,
        columns=expected_columns,
        index=3,
        col_exact=False,
    )
| 916 | + |
| 917 | + |
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_predict_invalid_schema_fails(
    llm_text_df: bpd.DataFrame, session, bq_connection
):
    """Predict with an unrecognized output type and expect ``BadRequest``.

    Passes ``"invalid_type"`` as the type of the single requested column and
    asserts that ``predict`` raises ``api_core_exceptions.BadRequest``.
    """
    bad_schema = {"invalid_output": "invalid_type"}
    model = llm.GeminiTextGenerator(
        model_name="gemini-2.0-flash-001",
        connection_name=bq_connection,
        session=session,
    )

    # Only the predict call sits inside the context manager, so an error from
    # constructing the model cannot satisfy the expectation.
    with pytest.raises(api_core_exceptions.BadRequest):
        model.predict(llm_text_df, output_schema=bad_schema)
0 commit comments