Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 179f60d

Browse files
committed
support ai.generate
1 parent 9af7130 commit 179f60d

7 files changed

Lines changed: 182 additions & 3 deletions

File tree

bigframes/ml/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,21 @@ def generate_table(
217217

218218
generate_table_tvf = TvfDef(generate_table, "status")
219219

220+
def ai_generate(
221+
self,
222+
input_data: bpd.DataFrame,
223+
options: dict[str, Union[int, float, bool, Mapping]],
224+
) -> bpd.DataFrame:
225+
return self._apply_ml_tvf(
226+
input_data,
227+
lambda source_sql: self._sql_generator.ai_generate(
228+
source_sql=source_sql,
229+
struct_options=options,
230+
),
231+
)
232+
233+
ai_generate_tvf = TvfDef(ai_generate, "status")
234+
220235
def detect_anomalies(
221236
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
222237
) -> bpd.DataFrame:

bigframes/ml/llm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,9 @@ def predict(
734734
output_schema = {
735735
k: utils.standardize_type(v) for k, v in output_schema.items()
736736
}
737-
options["output_schema"] = output_schema
737+
options["output_schema"] = {
738+
k: utils.standardize_type(v) for k, v in output_schema.items()
739+
}
738740
return self._predict_and_retry(
739741
core.BqmlModel.generate_table_tvf,
740742
X,

bigframes/ml/sql.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,13 @@ def ai_generate_table(
435435
struct_options_sql = self.struct_options(**struct_options)
436436
return f"""SELECT * FROM AI.GENERATE_TABLE(MODEL {self._model_ref_sql()},
437437
({source_sql}), {struct_options_sql})"""
438+
439+
def ai_generate(
440+
self,
441+
source_sql: str,
442+
struct_options: Mapping[str, Union[int, float, bool, Mapping]],
443+
) -> str:
444+
"""Encode AI.GENERATE for BQML"""
445+
struct_options_sql = self.struct_options(**struct_options)
446+
return f"""SELECT * FROM AI.GENERATE(MODEL {self._model_ref_sql()},
447+
({source_sql}), {struct_options_sql})"""

bigframes/ml/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,16 @@ def combine_training_and_evaluation_data(
191191

192192

193193
def standardize_type(v: str, supported_dtypes: Optional[Iterable[str]] = None):
194+
"""Standardize type string to BQML supported type string."""
194195
t = v.lower()
195-
t = t.replace("boolean", "bool")
196+
if t == "boolean":
197+
t = "bool"
198+
elif t == "integer":
199+
t = "int64"
200+
elif t == "str":
201+
t = "string"
202+
elif t == "float":
203+
t = "float64"
196204

197205
if supported_dtypes:
198206
if t not in supported_dtypes:

bigframes/testing/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@
4949
"ml_generate_text_status",
5050
"prompt",
5151
]
52+
AI_GENERATE_OUTPUT = [
53+
"result",
54+
"full_response",
55+
"status",
56+
]
5257
ML_GENERATE_EMBEDDING_OUTPUT = [
5358
"ml_generate_embedding_result",
5459
"ml_generate_embedding_statistics",

tests/system/small/ml/test_llm.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Callable
1616
from unittest import mock
1717

18+
from google.api_core import exceptions as api_core_exceptions
1819
import pandas as pd
1920
import pyarrow as pa
2021
import pytest
@@ -216,7 +217,9 @@ def test_gemini_text_generator_predict_output_schema_success(
216217
llm_text_df: bpd.DataFrame, model_name, session, bq_connection
217218
):
218219
gemini_text_generator_model = llm.GeminiTextGenerator(
219-
model_name=model_name, connection_name=bq_connection, session=session
220+
model_name="gemini-2.0-flash-001",
221+
connection_name=bq_connection,
222+
session=session,
220223
)
221224
output_schema = {
222225
"bool_output": "bool",
@@ -807,3 +810,122 @@ def test_text_embedding_generator_no_default_model_warning(model_class):
807810
message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
808811
with pytest.warns(FutureWarning, match=message):
809812
model_class(model_name=None)
813+
814+
815+
@pytest.mark.flaky(retries=2)
816+
def test_gemini_text_generator_predict_struct_schema_succeeds(
817+
llm_text_df: bpd.DataFrame, session, bq_connection
818+
):
819+
gemini_text_generator_model = llm.GeminiTextGenerator(
820+
model_name="gemini-2.0-flash-001",
821+
connection_name=bq_connection,
822+
session=session,
823+
)
824+
output_schema = {
825+
"struct_output": "struct<name string, age int64>",
826+
}
827+
df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
828+
assert set(field.name for field in df["struct_output"].dtype.pyarrow_dtype) == {
829+
"name",
830+
"age",
831+
}
832+
833+
pd_df = df.to_pandas()
834+
utils.check_pandas_df_schema_and_index(
835+
pd_df,
836+
columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
837+
index=3,
838+
col_exact=False,
839+
)
840+
841+
842+
@pytest.mark.flaky(retries=2)
843+
def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
844+
llm_text_df: bpd.DataFrame, session, bq_connection
845+
):
846+
gemini_text_generator_model = llm.GeminiTextGenerator(
847+
model_name="gemini-2.0-flash-001",
848+
connection_name=bq_connection,
849+
session=session,
850+
)
851+
output_schema = {
852+
"name": "string",
853+
"age": "int64",
854+
}
855+
df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
856+
assert df["name"].dtype == pd.StringDtype(storage="pyarrow")
857+
assert df["age"].dtype == pd.Int64Dtype()
858+
859+
pd_df = df.to_pandas()
860+
utils.check_pandas_df_schema_and_index(
861+
pd_df,
862+
columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
863+
index=3,
864+
col_exact=False,
865+
)
866+
867+
868+
@pytest.mark.flaky(retries=2)
869+
def test_gemini_text_generator_predict_array_schema_succeeds(
870+
llm_text_df: bpd.DataFrame, session, bq_connection
871+
):
872+
gemini_text_generator_model = llm.GeminiTextGenerator(
873+
model_name="gemini-2.0-flash-001",
874+
connection_name=bq_connection,
875+
session=session,
876+
)
877+
output_schema = {
878+
"array_output": "array<string>",
879+
}
880+
df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
881+
assert df["array_output"].dtype == pd.ArrowDtype(pa.list_(pa.string()))
882+
883+
pd_df = df.to_pandas()
884+
utils.check_pandas_df_schema_and_index(
885+
pd_df,
886+
columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
887+
index=3,
888+
col_exact=False,
889+
)
890+
891+
892+
@pytest.mark.flaky(retries=2)
893+
def test_gemini_text_generator_predict_array_struct_schema_succeeds(
894+
llm_text_df: bpd.DataFrame, session, bq_connection
895+
):
896+
gemini_text_generator_model = llm.GeminiTextGenerator(
897+
model_name="gemini-2.0-flash-001",
898+
connection_name=bq_connection,
899+
session=session,
900+
)
901+
output_schema = {
902+
"array_output": "array<struct<name string, age int64>>",
903+
}
904+
df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
905+
assert set(
906+
field.name for field in df["array_output"].dtype.pyarrow_dtype.value_type
907+
) == {"name", "age"}
908+
909+
pd_df = df.to_pandas()
910+
utils.check_pandas_df_schema_and_index(
911+
pd_df,
912+
columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
913+
index=3,
914+
col_exact=False,
915+
)
916+
917+
918+
@pytest.mark.flaky(retries=2)
919+
def test_gemini_text_generator_predict_invalid_schema_fails(
920+
llm_text_df: bpd.DataFrame, session, bq_connection
921+
):
922+
gemini_text_generator_model = llm.GeminiTextGenerator(
923+
model_name="gemini-2.0-flash-001",
924+
connection_name=bq_connection,
925+
session=session,
926+
)
927+
output_schema = {
928+
"invalid_output": "invalid_type",
929+
}
930+
with pytest.raises(api_core_exceptions.BadRequest):
931+
gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)

tests/unit/ml/test_sql.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,20 @@ def test_ml_principal_component_info_correct(
529529
sql
530530
== """SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
531531
)
532+
533+
534+
def test_ai_generate_correct(
535+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
536+
mock_df: bpd.DataFrame,
537+
):
538+
sql = model_manipulation_sql_generator.ai_generate(
539+
source_sql=mock_df.sql,
540+
struct_options={"option_key1": 1, "option_key2": 2.2},
541+
)
542+
assert (
543+
sql
544+
== """SELECT * FROM AI.GENERATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
545+
(input_X_y_sql), STRUCT(
546+
1 AS `option_key1`,
547+
2.2 AS `option_key2`))"""
548+
)

0 commit comments

Comments
 (0)