Skip to content

Commit dace9de

Browse files
committed
feat(bigframes): update ai.score to match its SQL version
1 parent adbabae commit dace9de

7 files changed

Lines changed: 56 additions & 13 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,8 @@ def score(
941941
prompt: PROMPT_TYPE,
942942
*,
943943
connection_id: str | None = None,
944+
endpoint: str | None = None,
945+
max_error_ratio: float | None = None,
944946
) -> series.Series:
945947
"""
946948
Computes a score based on rubrics described in natural language. It will return a double value.
@@ -958,20 +960,21 @@ def score(
958960
2 3.0
959961
dtype: Float64
960962
961-
.. note::
962-
963-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
964-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
965-
and might have limited support. For more information, see the launch stage descriptions
966-
(https://cloud.google.com/products#product-launch-stages).
967-
968963
Args:
969964
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
970965
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
971966
or pandas Series.
972967
connection_id (str, optional):
973968
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
974969
If not provided, the query uses your end-user credential.
970+
endpoint (str, optional):
971+
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
972+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
973+
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model
974+
based on your query to have the best cost to quality tradeoff for the task.
975+
max_error_ratio (float, optional):
976+
A value between `0.0` and `1.0` that contains the maximum acceptable ratio of row-level inference failures to
977+
rows processed on this function. If this value is exceeded, then the query fails.
975978
976979
Returns:
977980
bigframes.series.Series: A new series of double (float) values.
@@ -983,6 +986,8 @@ def score(
983986
operator = ai_ops.AIScore(
984987
prompt_context=tuple(prompt_context),
985988
connection_id=connection_id,
989+
endpoint=endpoint,
990+
max_error_ratio=max_error_ratio,
986991
)
987992

988993
return series_list[0]._apply_nary_op(operator, series_list[1:])

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,8 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
20052005
return ai_ops.AIScore(
20062006
_construct_prompt(values, op.prompt_context), # type: ignore
20072007
op.connection_id, # type: ignore
2008+
op.endpoint, # type: ignore
2009+
op.max_error_ratio, # type: ignore
20082010
).to_expr()
20092011

20102012

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ class AIScore(base_ops.NaryOp):
172172

173173
prompt_context: Tuple[str | None, ...]
174174
connection_id: str | None
175+
endpoint: str | None
176+
max_error_ratio: float | None
175177

176178
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
177179
return dtypes.FLOAT_DTYPE

packages/bigframes/bigframes/pandas/io/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,8 @@ def from_glob_path(
654654
def _get_bqclient_and_project() -> Tuple[bigquery.Client, str]:
655655
# Address circular imports in doctest due to bigframes/session/__init__.py
656656
# containing a lot of logic and samples.
657-
from bigframes.session import clients
658657
import bigframes._config.auth
658+
from bigframes.session import clients
659659

660660
credentials, project = bigframes._config.auth.resolve_credentials_and_project(
661661
config.options.bigquery
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
AI.SCORE(
3+
prompt => (`string_col`, ' is the same as ', `string_col`),
4+
endpoint => 'gemini-2.5-flash',
5+
max_error_ratio => 0.5
6+
) AS `result`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,27 @@ def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id)
407407
op = ops.AIScore(
408408
prompt_context=(None, " is the same as ", None),
409409
connection_id=connection_id,
410+
endpoint=None,
411+
max_error_ratio=None,
412+
)
413+
414+
sql = utils._apply_ops_to_sql(
415+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
416+
)
417+
418+
snapshot.assert_match(sql, "out.sql")
419+
420+
421+
def test_ai_score_with_endpoint_and_max_error_ratio(
422+
scalar_types_df: dataframe.DataFrame, snapshot
423+
):
424+
col_name = "string_col"
425+
426+
op = ops.AIScore(
427+
prompt_context=(None, " is the same as ", None),
428+
connection_id=None,
429+
endpoint="gemini-2.5-flash",
430+
max_error_ratio=0.5,
410431
)
411432

412433
sql = utils._apply_ops_to_sql(

packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ class AIIf(Value):
138138

139139
prompt: Value
140140
connection_id: Optional[Value[dt.String]]
141-
endpoint: Optional[Value[dt.String]] = None
142-
optimization_mode: Optional[Value[dt.String]] = None
143-
max_error_ratio: Optional[Value[dt.Float64]] = None
141+
endpoint: Optional[Value[dt.String]]
142+
optimization_mode: Optional[Value[dt.String]]
143+
max_error_ratio: Optional[Value[dt.Float64]]
144144

145145
shape = rlz.shape_like("prompt")
146146

@@ -151,7 +151,7 @@ def dtype(self) -> dt.Struct:
151151

152152
@public
153153
class AIClassify(Value):
154-
"""Generate True/False based on the prompt"""
154+
"""Generate categories based on the prompt"""
155155

156156
input: Value
157157
categories: Value[dt.Array[dt.String]]
@@ -166,13 +166,19 @@ def dtype(self) -> dt.Struct:
166166

167167
@public
168168
class AIScore(Value):
169-
"""Generate doubles based on the prompt"""
169+
"""Generate scores based on the prompt"""
170170

171171
prompt: Value
172172
connection_id: Optional[Value[dt.String]]
173+
endpoint: Optional[Value[dt.String]]
174+
max_error_ratio: Optional[Value[dt.Float64]]
173175

174176
shape = rlz.shape_like("prompt")
175177

178+
@attribute
179+
def dtype(self) -> dt.Struct:
180+
return dt.float64
181+
176182

177183
@public
178184
class AISimilarity(Value):

0 commit comments

Comments
 (0)