Skip to content

Commit 9d9c9f4

Browse files
committed
refactor
1 parent 17d3ac8 commit 9d9c9f4

4 files changed

Lines changed: 30 additions & 74 deletions

File tree

sqlglot/generators/duckdb.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,26 +2354,13 @@ def tobinary_sql(self, expression: exp.ToBinary) -> str:
23542354

23552355
@unsupported_args("format")
23562356
def tonumber_sql(self, expression: exp.ToNumber) -> str:
2357-
"""
2358-
Snowflake's TO_NUMBER without precision/scale defaults to NUMBER(38, 0),
2359-
which truncates decimals. The parser sets these defaults at parse time.
2360-
Always cast to DECIMAL(precision, scale) using the values from the AST.
2361-
2362-
Oracle's TO_NUMBER without precision/scale should convert to DOUBLE.
2363-
"""
23642357
precision = expression.args.get("precision")
23652358
scale = expression.args.get("scale")
23662359

2367-
# Build DECIMAL type with precision and scale from AST
23682360
if precision and scale:
2369-
# Snowflake parser ensures defaults (38, 0) are set when not specified
23702361
decimal_type = exp.DataType.build(f"DECIMAL({precision.name}, {scale.name})")
2371-
elif precision is None and scale is None:
2372-
# Oracle or other dialects that don't set defaults - convert to DOUBLE
2373-
decimal_type = exp.DataType.build("DOUBLE")
23742362
else:
2375-
# Fallback for partial specification
2376-
decimal_type = exp.DataType.build("DECIMAL(38, 0)")
2363+
decimal_type = exp.DataType.build("DOUBLE")
23772364

23782365
return self.sql(exp.cast(expression.this, decimal_type))
23792366

sqlglot/generators/snowflake.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -738,39 +738,6 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) ->
738738
if expression.is_type(exp.DType.GEOMETRY):
739739
return self.func("TO_GEOMETRY", expression.this)
740740

741-
# Convert CAST to DECIMAL/NUMERIC to TO_NUMBER only for string inputs
742-
# Don't convert TryCast - it's handled by trycast_sql
743-
if expression.is_type(exp.DType.DECIMAL) and not isinstance(expression, exp.TryCast):
744-
value = expression.this
745-
746-
# Annotate types if not already done
747-
if value.type is None:
748-
from sqlglot.optimizer.annotate_types import annotate_types
749-
750-
value = annotate_types(value, dialect=self.dialect)
751-
752-
# Only convert to TO_NUMBER for string inputs
753-
if value.is_string or value.is_type(*exp.DataType.TEXT_TYPES):
754-
# Extract precision and scale from DECIMAL(p, s)
755-
params = expression.to.expressions or []
756-
precision = (
757-
params[0].this
758-
if len(params) >= 1 and isinstance(params[0], exp.DataTypeParam)
759-
else None
760-
)
761-
scale = (
762-
params[1].this
763-
if len(params) >= 2 and isinstance(params[1], exp.DataTypeParam)
764-
else None
765-
)
766-
767-
to_number = exp.ToNumber(
768-
this=value,
769-
precision=precision,
770-
scale=scale,
771-
)
772-
return self.tonumber_sql(to_number)
773-
774741
return super().cast_sql(expression, safe_prefix=safe_prefix)
775742

776743
def trycast_sql(self, expression: exp.TryCast) -> str:

sqlglot/parsers/snowflake.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,26 @@ def _build_approx_top_k(args: t.List) -> exp.ApproxTopK:
3939
return exp.ApproxTopK.from_arg_list(args)
4040

4141

42+
def _build_to_number(args: t.List, safe: bool = False) -> exp.ToNumber:
43+
second_arg = seq_get(args, 1)
44+
if second_arg and second_arg.is_number:
45+
fmt = None
46+
precision = second_arg
47+
scale = seq_get(args, 2) or exp.Literal.number(0)
48+
else:
49+
fmt = second_arg
50+
precision = seq_get(args, 2) or exp.Literal.number(38)
51+
scale = seq_get(args, 3) or exp.Literal.number(0)
52+
53+
return exp.ToNumber(
54+
this=seq_get(args, 0),
55+
format=fmt,
56+
precision=precision,
57+
scale=scale,
58+
safe=safe,
59+
)
60+
61+
4262
def _build_date_from_parts(args: t.List) -> exp.DateFromParts:
4363
return exp.DateFromParts(
4464
year=seq_get(args, 0),
@@ -623,15 +643,7 @@ class SnowflakeParser(parser.Parser):
623643
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DType.DATE, safe=True),
624644
**dict.fromkeys(
625645
("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"),
626-
lambda args: exp.ToNumber(
627-
this=seq_get(args, 0),
628-
format=seq_get(args, 1) if len(args) in (2, 4) else None,
629-
precision=(seq_get(args, 2) if len(args) in (2, 4) else seq_get(args, 1))
630-
or exp.Literal.number(38),
631-
scale=(seq_get(args, 3) if len(args) in (2, 4) else seq_get(args, 2))
632-
or exp.Literal.number(0),
633-
safe=True,
634-
),
646+
lambda args: _build_to_number(args, safe=True),
635647
),
636648
"TRY_TO_DOUBLE": lambda args: exp.ToDouble(
637649
this=seq_get(args, 0), format=seq_get(args, 1), safe=True
@@ -654,14 +666,7 @@ class SnowflakeParser(parser.Parser):
654666
"TO_DATE": _build_datetime("TO_DATE", exp.DType.DATE),
655667
**dict.fromkeys(
656668
("TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"),
657-
lambda args: exp.ToNumber(
658-
this=seq_get(args, 0),
659-
format=seq_get(args, 1) if len(args) in (2, 4) else None,
660-
precision=(seq_get(args, 2) if len(args) in (2, 4) else seq_get(args, 1))
661-
or exp.Literal.number(38),
662-
scale=(seq_get(args, 3) if len(args) in (2, 4) else seq_get(args, 2))
663-
or exp.Literal.number(0),
664-
),
669+
lambda args: _build_to_number(args),
665670
),
666671
"TO_TIME": _build_datetime("TO_TIME", exp.DType.TIME),
667672
"TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DType.TIMESTAMP),

tests/dialects/test_duckdb.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -763,37 +763,34 @@ def test_duckdb(self):
763763
},
764764
)
765765

766-
# TO_NUMBER transpilation from Snowflake to DuckDB
767766
self.validate_all(
768767
"SELECT CAST('12.3456' AS DECIMAL(38, 0))",
769768
read={
770769
"snowflake": "SELECT TO_NUMBER('12.3456')",
771770
},
772771
write={
773772
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(38, 0))",
774-
"snowflake": "SELECT TO_NUMBER('12.3456')",
775773
},
776774
)
777775
self.validate_all(
778-
"SELECT CAST('12.3456' AS DECIMAL(10, 1))",
776+
"SELECT CAST('12.3456' AS DECIMAL(10, 0))",
779777
read={
780-
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 1)",
778+
"snowflake": "SELECT TO_NUMBER('12.3456', 10)",
781779
},
782780
write={
783-
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 1))",
784-
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 1)",
781+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 0))",
785782
},
786783
)
787784
self.validate_all(
788-
"SELECT CAST('3,741.72' AS DECIMAL(6, 2))",
785+
"SELECT CAST('12.3456' AS DECIMAL(10, 2))",
789786
read={
790-
"snowflake": "SELECT TO_DECIMAL('3,741.72', '9,999.99', 6, 2)",
787+
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 2)",
791788
},
792789
write={
793-
"duckdb": "SELECT CAST('3,741.72' AS DECIMAL(6, 2))",
794-
"snowflake": "SELECT TO_NUMBER('3,741.72', 6, 2)", # Format is lost during transpilation
790+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 2))",
795791
},
796792
)
793+
797794
self.validate_all(
798795
"VAR_POP(x)",
799796
read={

0 commit comments

Comments
 (0)