Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 63f0527

Browse files
committed
feat: add bigquery.ml.generate_embedding function
1 parent 248c8ea commit 63f0527

4 files changed

Lines changed: 136 additions & 0 deletions

File tree

bigframes/bigquery/_operations/ml.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,63 @@ def generate_text(
520520
return bpd.read_gbq_query(sql)
521521
else:
522522
return session.read_gbq_query(sql)
523+
524+
525+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def generate_embedding(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    input_: Union[pd.DataFrame, dataframe.DataFrame, str],
    *,
    flatten_json_output: Optional[bool] = None,
    task_type: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> dataframe.DataFrame:
    """
    Generates text embedding using a BigQuery ML model.

    The call is translated into a ``SELECT * FROM ML.GENERATE_EMBEDDING(...)``
    query and the result of running that query is returned.

    See the `BigQuery ML GENERATE_EMBEDDING function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
            The model to use for text embedding.
        input_ (Union[pandas.DataFrame, bigframes.pandas.DataFrame, str]):
            The DataFrame or query to use for text embedding.
        flatten_json_output (bool, optional):
            A BOOL value that determines the content of the generated JSON column.
        task_type (str, optional):
            A STRING value that specifies the intended downstream application task.
            Supported values are:
            - `RETRIEVAL_QUERY`
            - `RETRIEVAL_DOCUMENT`
            - `SEMANTIC_SIMILARITY`
            - `CLASSIFICATION`
            - `CLUSTERING`
            - `QUESTION_ANSWERING`
            - `FACT_VERIFICATION`
            - `CODE_RETRIEVAL_QUERY`
        output_dimensionality (int, optional):
            An INT64 value that specifies the size of the output embedding.

    Returns:
        bigframes.pandas.DataFrame:
            The generated text embedding.
    """
    import bigframes.pandas as bpd

    model_name, session = _get_model_name_and_session(model, input_)
    input_sql = _to_sql(input_)

    # Options left as None are passed through unchanged; the SQL builder
    # decides which of them end up in the statement.
    query = bigframes.core.sql.ml.generate_embedding(
        model_name=model_name,
        table=input_sql,
        flatten_json_output=flatten_json_output,
        task_type=task_type,
        output_dimensionality=output_dimensionality,
    )

    if session is not None:
        return session.read_gbq_query(query)

    # No session was resolved from the arguments, so use the module-level
    # bigframes.pandas reader instead.
    return bpd.read_gbq_query(query)

bigframes/core/sql/ml.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,31 @@ def generate_text(
296296
sql += _build_struct_sql(struct_options)
297297
sql += ")\n"
298298
return sql
299+
300+
301+
def generate_embedding(
    model_name: str,
    table: str,
    *,
    flatten_json_output: Optional[bool] = None,
    task_type: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> str:
    """Encode the ML.GENERATE_EMBEDDING statement.

    ``table`` is embedded inside parentheses as the input query, and only the
    options that are not ``None`` are handed to the struct builder.

    See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding for reference.
    """
    # Declared in the order the options must appear in the generated struct.
    requested: Dict[str, Optional[Union[str, int, bool]]] = {
        "flatten_json_output": flatten_json_output,
        "task_type": task_type,
        "output_dimensionality": output_dimensionality,
    }
    struct_options: Dict[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ] = {name: value for name, value in requested.items() if value is not None}

    head = f"SELECT * FROM ML.GENERATE_EMBEDDING(MODEL {googlesql.identifier(model_name)}, ({table})"
    return head + _build_struct_sql(struct_options) + ")\n"

tests/unit/bigquery/test_ml.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,32 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo
200200
assert "['a', 'b'] AS stop_sequences" in generated_sql
201201
assert "true AS ground_with_google_search" in generated_sql
202202
assert "'TYPE' AS request_type" in generated_sql
203+
204+
205+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
def test_generate_embedding_with_pandas_dataframe(
    read_pandas_mock, read_gbq_query_mock
):
    """Check the SQL sent to read_gbq_query when the input is a pandas DataFrame.

    Both bigframes.pandas entry points are mocked; the test asserts each is
    called exactly once and that the generated SQL contains the model, the
    input query, and every option that was passed.
    """
    local_df = pd.DataFrame({"col1": [1, 2, 3]})
    read_pandas_mock.return_value._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )

    ml_ops.generate_embedding(
        MODEL_SERIES,
        input_=local_df,
        flatten_json_output=True,
        task_type="RETRIEVAL_DOCUMENT",
        output_dimensionality=256,
    )

    read_pandas_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once()

    # The SQL text is the first positional argument of the mocked call.
    generated_sql = read_gbq_query_mock.call_args.args[0]
    expected_fragments = (
        "ML.GENERATE_EMBEDDING",
        f"MODEL `{MODEL_NAME}`",
        "(SELECT * FROM `pandas_df`)",
        "true AS flatten_json_output",
        "'RETRIEVAL_DOCUMENT' AS task_type",
        "256 AS output_dimensionality",
    )
    for fragment in expected_fragments:
        assert fragment in generated_sql

tests/unit/core/sql/test_ml.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,22 @@ def test_generate_text_model_with_options(snapshot):
201201
request_type="TYPE",
202202
)
203203
snapshot.assert_match(sql, "generate_text_model_with_options.sql")
204+
205+
206+
def test_generate_embedding_model_basic(snapshot):
    """Snapshot the SQL generated when only the model and table are given."""
    snapshot.assert_match(
        bigframes.core.sql.ml.generate_embedding(
            model_name="my_project.my_dataset.my_model",
            table="SELECT * FROM new_data",
        ),
        "generate_embedding_model_basic.sql",
    )
212+
213+
214+
def test_generate_embedding_model_with_options(snapshot):
    """Snapshot the SQL generated when every optional argument is supplied."""
    snapshot.assert_match(
        bigframes.core.sql.ml.generate_embedding(
            model_name="my_project.my_dataset.my_model",
            table="SELECT * FROM new_data",
            flatten_json_output=True,
            task_type="RETRIEVAL_DOCUMENT",
            output_dimensionality=256,
        ),
        "generate_embedding_model_with_options.sql",
    )

0 commit comments

Comments
 (0)