1+ # Copyright 2025 Google LLC
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ from __future__ import annotations
16+
17+ import json
18+
19+ from typing import List , Literal , Mapping , Sequence , Tuple , Any
20+
21+ from bigframes import series
22+ from bigframes .operations import ai_ops
23+
24+
25+ def ai_generate_bool (
26+ prompt : series .Series | Sequence [str | series .Series ],
27+ * ,
28+ connection_id : str | None = None ,
29+ endpoint : str | None = None ,
30+ request_type : Literal ["dedicated" , "shared" , "unspecified" ] = "unspecified" ,
31+ model_params : Mapping [Any , Any ] | None = None ,
32+ ) -> series .Series :
33+ """ """
34+
35+ if request_type not in ("dedicated" , "shared" , "unspecified" ):
36+ raise ValueError (f"Unsupported request type: { request_type } " )
37+
38+ if isinstance (prompt , series .Series ):
39+ prompt_context , series_list = _separate_context_and_series ([prompt ])
40+ elif isinstance (prompt , Sequence ):
41+ prompt_context , series_list = _separate_context_and_series (prompt )
42+ else :
43+ raise ValueError (f"Unsupported prompt type: { type (prompt )} " )
44+
45+ if not series_list :
46+ raise ValueError ("Please provide at least one Series in the prompt" )
47+
48+ operator = ai_ops .AIGenerateBool (
49+ tuple (prompt_context ),
50+ connection_id = connection_id or series_list [0 ]._session ._bq_connection ,
51+ endpoint = endpoint ,
52+ request_type = request_type ,
53+ model_params = json .dumps (model_params ) if model_params else None ,
54+ )
55+
56+ return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
57+
58+
59+ def _separate_context_and_series (
60+ prompt : Sequence [str | series .Series ],
61+ ) -> Tuple [List [str | None ], List [series .Series ]]:
62+ """
63+ Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
64+ in the prompt. The original item order is kept.
65+
66+ For example:
67+ Input: ("str1", series1, "str2", "str3", series2)
68+ Output: ["str1", None, "str2", "str3", None], [series1, series2]
69+ """
70+
71+ prompt_context : List [str | None ] = []
72+ series_list : List [series .Series ] = []
73+
74+ for item in prompt :
75+ if isinstance (item , str ):
76+ prompt_context .append (item )
77+ elif isinstance (item , series .Series ):
78+ prompt_context .append (None )
79+ else :
80+ raise ValueError (f"Unsupported type in prompt: { type (item )} " )
81+
82+ return prompt_context , series_list
0 commit comments