Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 13a30e8

Browse files
committed
proof of concept
1 parent 8804ada commit 13a30e8

6 files changed

Lines changed: 167 additions & 1 deletion

File tree

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
19+
from typing import List, Literal, Mapping, Sequence, Tuple, Any
20+
21+
from bigframes import series
22+
from bigframes.operations import ai_ops
23+
24+
25+
def ai_generate_bool(
26+
prompt: series.Series | Sequence[str | series.Series],
27+
*,
28+
connection_id: str | None = None,
29+
endpoint: str | None = None,
30+
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
31+
model_params: Mapping[Any, Any] | None = None,
32+
) -> series.Series:
33+
""" """
34+
35+
if request_type not in ("dedicated", "shared", "unspecified"):
36+
raise ValueError(f"Unsupported request type: {request_type}")
37+
38+
if isinstance(prompt, series.Series):
39+
prompt_context, series_list = _separate_context_and_series([prompt])
40+
elif isinstance(prompt, Sequence):
41+
prompt_context, series_list = _separate_context_and_series(prompt)
42+
else:
43+
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
44+
45+
if not series_list:
46+
raise ValueError("Please provide at least one Series in the prompt")
47+
48+
operator = ai_ops.AIGenerateBool(
49+
tuple(prompt_context),
50+
connection_id=connection_id or series_list[0]._session._bq_connection,
51+
endpoint=endpoint,
52+
request_type=request_type,
53+
model_params=json.dumps(model_params) if model_params else None,
54+
)
55+
56+
return series_list[0]._apply_nary_op(operator, series_list[1:])
57+
58+
59+
def _separate_context_and_series(
60+
prompt: Sequence[str | series.Series],
61+
) -> Tuple[List[str | None], List[series.Series]]:
62+
"""
63+
Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
64+
in the prompt. The original item order is kept.
65+
66+
For example:
67+
Input: ("str1", series1, "str2", "str3", series2)
68+
Output: ["str1", None, "str2", "str3", None], [series1, series2]
69+
"""
70+
71+
prompt_context: List[str|None] = []
72+
series_list: List[series.Series] = []
73+
74+
for item in prompt:
75+
if isinstance(item, str):
76+
prompt_context.append(item)
77+
elif isinstance(item, series.Series):
78+
prompt_context.append(None)
79+
else:
80+
raise ValueError(f"Unsupported type in prompt: {type(item)}")
81+
82+
return prompt_context, series_list

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,11 @@ def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp):
12281228
op.aggregation, typing.cast(ibis_types.Column, arr_vals)
12291229
)
12301230
)
1231+
1232+
# AI Ops
1233+
@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
1234+
def ai_generate_bool(*values: ibis_types.Value, op: ops.AIGenerateBool):
1235+
12311236

12321237

12331238
# JSON Ops
@@ -2172,3 +2177,14 @@ def str_strip_op( # type: ignore[empty-body]
21722177
x: ibis_dtypes.String, to_strip: ibis_dtypes.String
21732178
) -> ibis_dtypes.String:
21742179
"""Remove leading and trailing characters."""
2180+
2181+
2182+
@ibis_udf.scalar.builtin(name="AI.GENERATE_BOOL", named_args=True, ignore_none_values=True)
2183+
def ai_generate_bool( # type: ignore[empty-body]
2184+
prompt: ibis_types.Value,
2185+
connection_id: str,
2186+
endpoint = None,
2187+
request_type = None,
2188+
model_params = None,
2189+
) -> ibis_dtypes.Value:
2190+
"""Call AI.GENERATE_BOOL with the prompt."""

bigframes/operations/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from bigframes.operations.ai_ops import AIGenerateBool
1718
from bigframes.operations.array_ops import (
1819
ArrayIndexOp,
1920
ArrayReduceOp,
@@ -408,6 +409,8 @@
408409
"geo_x_op",
409410
"geo_y_op",
410411
"GeoStDistanceOp",
412+
#AI ops
413+
"AIGenerateBool",
411414
# Numpy ops mapping
412415
"NUMPY_TO_BINOP",
413416
"NUMPY_TO_OP",

bigframes/operations/ai_ops.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import Tuple, Literal
19+
20+
import pandas as pd
21+
import pyarrow as pa
22+
23+
from bigframes import dtypes
24+
from bigframes.operations import base_ops
25+
26+
27+
@dataclasses.dataclass(frozen=True)
28+
class AIGenerateBool(base_ops.NaryOp):
29+
30+
# Prompt with column referneces replaced with "None" placeholder
31+
prompt_context: Tuple[str | None, ...]
32+
33+
connection_id: str
34+
endpoint: str | None
35+
request_type: Literal["dedicated", "shared", "unspecified"]
36+
model_params: str | None
37+
38+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
39+
return pd.ArrowDtype(
40+
pa.struct(
41+
(
42+
pa.field("result", pa.bool_()),
43+
pa.field("full_response", pa.string()),
44+
pa.field("status", pa.string()),
45+
)
46+
)
47+
)

third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1233,7 +1233,15 @@ def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str:
12331233
)
12341234

12351235
def visit_ScalarUDF(self, op, **kw):
1236-
return self.f[self.__sql_name__(op)](*kw.values())
1236+
if op.__config__.get("named_args"):
1237+
args = []
1238+
for name, value in kw.items():
1239+
if op.__config__.get("ignore_none_values") and isinstance(value, sge.Null):
1240+
continue
1241+
args.append(sge.Kwarg(this=sg.to_identifier(name), expression=value))
1242+
else:
1243+
args = list(kw.values())
1244+
return self.f[self.__sql_name__(op)](*args)
12371245

12381246
def visit_AggUDF(self, op, *, where, **kw):
12391247
return self.agg[self.__sql_name__(op)](*kw.values(), where=where)

third_party/bigframes_vendored/ibis/expr/operations/udf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def builtin(
198198
database: str | None = None,
199199
catalog: str | None = None,
200200
signature: tuple[tuple[Any, ...], Any] | None = None,
201+
named_args: bool = False,
202+
ignore_none_values = False,
201203
**kwargs: Any,
202204
) -> Callable[[Callable], Callable[..., ir.Value]]:
203205
...
@@ -212,6 +214,8 @@ def builtin(
212214
database=None,
213215
catalog=None,
214216
signature=None,
217+
named_args=False,
218+
ignore_none_values=False,
215219
**kwargs,
216220
):
217221
"""Construct a scalar user-defined function that is built-in to the backend.
@@ -235,6 +239,10 @@ def builtin(
235239
For **builtin** UDFs, only the **return type** annotation is required.
236240
See [the user guide](/how-to/extending/builtin.qmd#input-types) for
237241
more information.
242+
named_args
243+
Whether to compile the function with named arguments.
244+
ignore_none_values
245+
If true, named arguments whose value is None do no appear in the compiled SQL.
238246
kwargs
239247
Additional backend-specific configuration arguments for the UDF.
240248
@@ -258,6 +266,8 @@ def builtin(
258266
database=database,
259267
catalog=catalog,
260268
signature=signature,
269+
named_args=named_args,
270+
ignore_none_values=ignore_none_values,
261271
**kwargs,
262272
)
263273

0 commit comments

Comments
 (0)